1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use super::{CRTInteger, OverflowInteger};
use halo2_base::{
gates::GateInstructions,
utils::{log2_ceil, ScalarField},
Context,
QuantumCell::Constant,
};
use itertools::Itertools;
use std::cmp::max;
/// Computes `a * c + b` limb by limb, without carrying.
///
/// Every output limb is `a_limb * c_f + b_limb`, produced by one
/// `gate.mul_add` call in which `c_f` is supplied as a constant cell.
///
/// # Arguments
/// * `c_f` - the scalar multiplier, already expressed as a field element.
/// * `c_log2_ceil` - bit-length bound on `|c|`; it only affects the
///   `max_limb_bits` recorded on the result (presumably `log2_ceil(|c|)`,
///   not verified here).
///
/// # Panics
/// Panics (via `zip_eq`) if `a` and `b` do not have the same number of limbs.
pub fn assign<F: ScalarField>(
    gate: &impl GateInstructions<F>,
    ctx: &mut Context<F>,
    a: OverflowInteger<F>,
    b: OverflowInteger<F>,
    c_f: F,
    c_log2_ceil: usize,
) -> OverflowInteger<F> {
    // |a_limb * c| < 2^(a_bits + c_log2_ceil) and |b_limb| < 2^b_bits, so the
    // sum needs at most one bit more than the larger of the two bounds.
    let scaled_bits = a.max_limb_bits + c_log2_ceil;
    let result_bits = max(scaled_bits, b.max_limb_bits) + 1;

    let mut limbs = Vec::new();
    for (lhs, rhs) in a.limbs.into_iter().zip_eq(b.limbs) {
        limbs.push(gate.mul_add(ctx, lhs, Constant(c_f), rhs));
    }

    OverflowInteger::new(limbs, result_bits)
}
/// Computes `a * c + b` for CRT integers, without carrying.
///
/// The three CRT components are updated independently:
/// * the truncation limbs via [`assign`],
/// * the native field element via a single `gate.mul_add`,
/// * the tracked big-integer `value` directly as `a.value * c + b.value`.
///
/// # Arguments
/// * `c` - the signed scalar multiplier; every `i64` is accepted, including
///   `i64::MIN`.
///
/// # Panics
/// Panics if `a` and `b` have different numbers of truncation limbs
/// (`debug_assert_eq!` in debug builds; `zip_eq` inside [`assign`] always).
pub fn crt<F: ScalarField>(
    gate: &impl GateInstructions<F>,
    ctx: &mut Context<F>,
    a: CRTInteger<F>,
    b: CRTInteger<F>,
    c: i64,
) -> CRTInteger<F> {
    debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len());

    // `unsigned_abs` is total on i64; computing `-c` first would overflow
    // (and panic) for `c == i64::MIN`.
    let c_abs = c.unsigned_abs();
    // Embed `c` into the field: |c| as a field element, negated when c < 0.
    let c_f = if c < 0 { -F::from(c_abs) } else { F::from(c_abs) };

    let out_trunc = assign(gate, ctx, a.truncation, b.truncation, c_f, log2_ceil(c_abs));
    let out_native = gate.mul_add(ctx, a.native, Constant(c_f), b.native);
    let out_val = a.value * c + b.value;
    CRTInteger::new(out_trunc, out_native, out_val)
}