1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
use super::OverflowInteger;
use halo2_base::{
gates::{GateInstructions, RangeInstructions},
utils::{bigint_to_fe, fe_to_bigint, BigPrimeField},
Context,
QuantumCell::{Constant, Existing, Witness},
};
use num_bigint::BigInt;
/// Constrains the overflow integer `a` to carry to `0 mod 2^{n * k}`, where `n = limb_bits`
/// and `k = a.limbs.len()`.
///
/// Writing `a_i` for the limbs of `a`, the check asserts that carries `c_i` exist with
///
/// ```text
/// a_0           = c_0 * 2^n
/// a_i + c_{i-1} = c_i * 2^n      for i = 1..=k-1
/// ```
///
/// and `c_i` in `[-2^{m - n + EPSILON}, 2^{m - n + EPSILON}]` for `i = 0..=k-1`, where
/// `m = a.max_limb_bits` and `EPSILON >= 1` is picked so that the range-checked bit width
/// rounds up to the next multiple of the range check table size.
///
/// The cells that are actually witnessed hold `d_i = -c_i`, so the constraints become
///
/// ```text
/// a_0 + d_0 * 2^n = 0
/// a_i + d_i * 2^n = d_{i-1}      for i = 1..=k-1
/// ```
///
/// # Arguments
/// * `range` - chip supplying the gate, the lookup table width and the range check.
/// * `ctx` - context the cells are assigned into.
/// * `a` - overflow integer to check; it is consumed.
/// * `limb_bits` - the limb width `n`.
/// * `limb_base` - field element used as the multiplier in every gate; expected to be `2^n`.
/// * `limb_base_big` - integer divisor used to derive the carry witnesses; expected to be the
///   same `2^n` (this function does not verify that the two agree).
///
/// # Panics
/// The unsigned subtraction `a.max_limb_bits - limb_bits` assumes
/// `a.max_limb_bits >= limb_bits`; it overflows otherwise.
///
/// # Windowing ("aztec optimization", not enabled)
/// `a_i + c_{i-1} = c_i * 2^n` can be expanded to
/// `a_i * 2^{n*w} + a_{i-1} * 2^{n*(w-1)} + ... + a_{i-w} + c_{i-w-1} = c_i * 2^{n*(w+1)}`,
/// which is valid as long as
/// `(m - n + EPSILON) + n * (w+1) < native_modulus::<F>().bits() - 1`. With that expansion only
/// every `(w + 1)`-th carry (starting at `i = w`) would need a range check. The implementation
/// currently always uses a window of 1, i.e. every carry is range checked.
pub fn truncate<F: BigPrimeField>(
    range: &impl RangeInstructions<F>,
    ctx: &mut Context<F>,
    a: OverflowInteger<F>,
    limb_bits: usize,
    limb_base: F,
    limb_base_big: &BigInt,
) {
    const EPSILON: usize = 1;

    // Witness generation: c_0 = a_0 / 2^n and c_i = (a_i + c_{i-1}) / 2^n.
    // Division is used on purpose; `>>` on a negative integer produces an undesired result.
    let mut carries: Vec<BigInt> = Vec::with_capacity(a.limbs.len());
    for limb in &a.limbs {
        let limb_value = fe_to_bigint(limb.value());
        let numerator = match carries.last() {
            Some(prev_carry) => limb_value + prev_carry,
            None => limb_value,
        };
        carries.push(numerator / limb_base_big);
    }

    // Round `max_limb_bits - limb_bits + EPSILON + 1` up to the next multiple of the lookup
    // table width; `range_bits` is that rounded width minus one.
    let lookup_bits = range.lookup_bits();
    let unrounded_bits = a.max_limb_bits - limb_bits + EPSILON;
    let range_bits = ((unrounded_bits + lookup_bits) / lookup_bits) * lookup_bits - 1;

    // Offset added to each negated carry before it is range checked to `range_bits + 1` bits.
    let shift_val = range.gate().pow_of_two()[range_bits];

    // Negated carry cell of the previous limb; `None` for the first limb, which uses constant 0.
    let mut prev_neg_carry = None;
    for (limb, carry) in a.limbs.into_iter().zip(carries) {
        let carry_in = match prev_neg_carry {
            Some(cell) => Existing(cell),
            None => Constant(F::zero()),
        };
        // One gate row group: a_i + d_i * 2^n = d_{i-1} (or 0 for the first limb).
        ctx.assign_region(
            [Existing(limb), Witness(bigint_to_fe(&-carry)), Constant(limb_base), carry_in],
            [0],
        );
        // The freshly witnessed d_i sits three cells from the end of the region just assigned.
        let neg_carry = ctx.get(-3);

        // Shift the signed carry by `shift_val` and range check the result to
        // `range_bits + 1` bits.
        let shifted_carry = range.gate().add(ctx, neg_carry, Constant(shift_val));
        range.range_check(ctx, shifted_carry, range_bits + 1);

        prev_neg_carry = Some(neg_carry);
    }
}