1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use super::{check_carry_to_zero, CRTInteger, OverflowInteger};
use halo2_base::{
gates::{GateInstructions, RangeInstructions},
utils::{decompose_bigint, BigPrimeField},
AssignedValue, Context,
QuantumCell::{Constant, Existing, Witness},
};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Signed, Zero};
use std::{cmp::max, iter};
pub fn crt<F: BigPrimeField>(
range: &impl RangeInstructions<F>,
ctx: &mut Context<F>,
a: CRTInteger<F>,
k_bits: usize, modulus: &BigInt,
mod_vec: &[F],
mod_native: F,
limb_bits: usize,
limb_bases: &[F],
limb_base_big: &BigInt,
) {
let n = limb_bits;
let k = a.truncation.limbs.len();
let trunc_len = n * k;
debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2);
let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize);
assert!(quot_max_bits < trunc_len);
let quot_last_limb_bits = quot_max_bits - n * (k - 1);
let (quot_val, _out_val) = a.value.div_mod_floor(modulus);
debug_assert_eq!(_out_val, BigInt::zero());
debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits));
let quot_vec = decompose_bigint::<F>("_val, k, n);
debug_assert!(modulus < &(BigInt::one() << (n * k)));
let mut quot_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
let mut check_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
for (i, (a_limb, quot_v)) in a.truncation.limbs.into_iter().zip(quot_vec).enumerate() {
let (prod, new_quot_cell) = range.gate().inner_product_left_last(
ctx,
quot_assigned.iter().map(|x| Existing(*x)).chain(iter::once(Witness(quot_v))),
mod_vec[0..=i].iter().rev().map(|c| Constant(*c)),
);
let check_val = *prod.value() - a_limb.value();
let check_cell = ctx
.assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]);
quot_assigned.push(new_quot_cell);
check_assigned.push(check_cell);
}
for (q_index, quot_cell) in quot_assigned.iter().enumerate() {
let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n };
let limb_base =
if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] };
let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base));
range.range_check(ctx, quot_shift, limb_bits + 1);
}
let check_overflow_int =
OverflowInteger::new(check_assigned, max(a.truncation.max_limb_bits, 2 * n + k_bits));
check_carry_to_zero::truncate::<F>(
range,
ctx,
check_overflow_int,
limb_bits,
limb_bases[1],
limb_base_big,
);
let quot_native =
OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases);
ctx.assign_region(
[Constant(F::zero()), Constant(mod_native), Existing(quot_native), Existing(a.native)],
[0],
);
}