1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
use super::{CRTInteger, OverflowInteger};
use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context};
use itertools::Itertools;
use std::cmp::max;

/// Selects between two [`OverflowInteger`]s limb by limb: the result takes `a`'s limbs when
/// `sel` is 1 and `b`'s limbs when `sel` is 0.
///
/// Each output limb is `gate.select(ctx, a_limb, b_limb, sel)`. Because every output limb is
/// one of the corresponding input limbs (for boolean `sel`), the result's `max_limb_bits` is
/// the larger of the two inputs' bounds.
///
/// # Assumptions
/// * `a, b` have same number of limbs
/// * Number of limbs is nonzero
/// * `sel` is boolean (0 or 1); this function itself adds no constraint enforcing that
///
/// # Panics
/// Panics if `a` and `b` have different numbers of limbs (`zip_eq` checks this in every build).
pub fn assign<F: ScalarField>(
    gate: &impl GateInstructions<F>,
    ctx: &mut Context<F>,
    a: OverflowInteger<F>,
    b: OverflowInteger<F>,
    sel: AssignedValue<F>,
) -> OverflowInteger<F> {
    let out_limbs = a
        .limbs
        .into_iter()
        .zip_eq(b.limbs)
        .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel))
        .collect();

    OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits))
}

/// Selects between two [`CRTInteger`]s: the result corresponds to `a` when `sel` is 1 and to
/// `b` when `sel` is 0.
///
/// The truncation limbs are selected with [`assign`] and then the native representation with
/// one more `gate.select`, in that order. The out-of-circuit `value` is chosen by inspecting
/// the witness of `sel`.
///
/// # Assumptions
/// * Same as [`assign`]: equal, nonzero limb counts in `a.truncation` and `b.truncation`.
/// * `sel` must be boolean. `value` is taken from `a` for *any* nonzero `sel`, so a
///   non-boolean `sel` would make `value` inconsistent with the selected limbs/native value.
///
/// # Panics
/// Panics if the truncations have different numbers of limbs.
pub fn crt<F: ScalarField>(
    gate: &impl GateInstructions<F>,
    ctx: &mut Context<F>,
    a: CRTInteger<F>,
    b: CRTInteger<F>,
    sel: AssignedValue<F>,
) -> CRTInteger<F> {
    // Kept for a clearer message in debug builds; `assign`'s `zip_eq` also panics on mismatch.
    debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len());
    // Reuse the limb-wise selection instead of duplicating it; limbs are assigned before the
    // native value, exactly as before, so the circuit layout is unchanged.
    let out_trunc = assign(gate, ctx, a.truncation, b.truncation, sel);

    let out_native = gate.select(ctx, a.native, b.native, sel);
    // Witness-side bookkeeping only: mirror `sel ? a : b` on the big-integer value.
    let out_val = if sel.value().is_zero_vartime() { b.value } else { a.value };
    CRTInteger::new(out_trunc, out_native, out_val)
}