1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#![allow(non_snake_case)]
use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip};
use crate::ecc::{ec_sub_strict, load_random_point};
use crate::fields::{FieldChip, PrimeField, Selectable};
use group::Curve;
use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder};
use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context};
use itertools::Itertools;
use rayon::prelude::*;
use std::cmp::min;

/// Computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant)
/// - `scalar` is represented as a non-empty `Vec` of `AssignedValue`s
/// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
/// - a `Vec` of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
/// - `window_bits` is the number of scalar bits consumed per window; for every window a table
///   of up to `2^window_bits` constant multiples of `P` is loaded
///
/// If `point` is the identity, a point whose two coordinates are both the constant `0` is
/// returned immediately and none of the other arguments are inspected.
///
/// # Assumptions
/// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits)
/// - `scalar > 0`
/// - `max_bits <= modulus::<F>.bits()`
///
/// # Panics
/// For a non-identity `point`, panics if `scalar` is empty, if `max_bits > F::NUM_BITS`, or if
/// `window_bits == 0`.
pub fn scalar_multiply<F, FC, C>(
    chip: &FC,
    ctx: &mut Context<F>,
    point: &C,
    scalar: Vec<AssignedValue<F>>,
    max_bits: usize,
    window_bits: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    F: PrimeField,
    C: CurveAffineExt,
    FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
{
    // [k] * identity = identity, which is represented here as (0, 0)
    if point.is_identity().into() {
        let zero = chip.load_constant(ctx, C::Base::zero());
        return EcPoint::new(zero.clone(), zero);
    }
    assert!(!scalar.is_empty());
    assert!((max_bits as u32) <= F::NUM_BITS);
    // without this check the window count below fails with an opaque "divide by zero" panic
    assert!(window_bits > 0, "fixed_base::scalar_multiply requires window_bits > 0");

    let total_bits = max_bits * scalar.len();
    // ceil(total_bits / window_bits)
    let num_windows = (total_bits + window_bits - 1) / window_bits;

    // Jacobian coordinate
    let base_pt = point.to_curve();
    // cached_points[i * 2^w + j] holds `[j * 2^(i * w)] * point` for j in {1, ..., 2^w - 1}
    // EXCEPT that index j = 0 of every window holds the dummy value `[2^(i * w)] * point`
    // (the last window may be shorter when `window_bits` does not divide `total_bits`)

    // first we compute all cached points in Jacobian coordinates since it's fastest
    // `increment` is `[2^(i * w)] * point` at the start of window `i`
    let mut increment = base_pt;
    let cached_points_jacobian = (0..num_windows)
        .flat_map(|i| {
            let mut curr = increment;
            // start with increment at index 0 instead of identity just as a dummy value to avoid divide by 0 issues
            let cache_vec = std::iter::once(increment)
                .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(|_| {
                    let prev = curr;
                    curr += increment;
                    prev
                }))
                .collect::<Vec<_>>();
            // after the loop `curr` is `[2^(bits in this window)] * increment`, i.e. the base of the next window
            increment = curr;
            cache_vec
        })
        .collect::<Vec<_>>();
    // for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive
    // initialize to all 0s
    let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
    C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

    // TODO: do not assign and use select_from_bits on Constant(_) QuantumCells
    let cached_points = cached_points_affine
        .into_iter()
        .map(|cached| {
            let (x, y) = cached.into_coordinates();
            let [x, y] = [x, y].map(|coord| chip.load_constant(ctx, coord));
            EcPoint::new(x, y)
        })
        .collect_vec();

    // decompose every scalar chunk into `max_bits` bits and concatenate them in chunk order
    let bits = scalar
        .into_iter()
        .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits))
        .collect::<Vec<_>>();

    let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();
    let bit_window_rev = bits.chunks(window_bits).rev();
    // the accumulator starts at a randomly loaded point instead of the identity; that point is
    // subtracted again at the very end
    let any_point = load_random_point::<F, FC, C>(chip, ctx);
    let mut curr_point = any_point.clone();
    for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
        let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied());
        // are we just adding a window of all 0s? if so, skip
        let is_zero_window = chip.gate().is_zero(ctx, bit_sum);
        curr_point = {
            let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window);
            let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true);
            // keep the old accumulator for an all-zero window (the table entry at index 0 is only a dummy)
            ec_select(chip, ctx, curr_point, sum, is_zero_window)
        };
    }
    // remove the random offset the accumulator was started with
    ec_sub_strict(chip, ctx, curr_point, any_point)
}

// `msm_par` basically just adds up individual `fixed_base::scalar_multiply` results, except that the cached points of all base points are batch-normalized at once to further save inversion time during witness generation.
// It also uses the random accumulator for some extra efficiency (this would also work in the single scalar multiply case, but that is a TODO).

/// Computes the multi-scalar multiplication `sum_i [scalars[i]] * points[i]` for fixed
/// (constant) `points`, running the per-point windowed scalar multiplications through
/// `parallelize_in`.
///
/// - `scalars[i]` is a `Vec` of `AssignedValue`s; each entry is decomposed into
///   `max_scalar_bits_per_cell` bits with `num_to_bits`
/// - `window_bits` is the number of scalar bits consumed per window; for every point and every
///   window a table of up to `2^window_bits` constant multiples is assigned
/// - `phase` selects the builder phase that all assignments are made in
///
/// If `points` is empty, the constant identity point is assigned and returned and no other
/// argument is inspected.
///
/// # Assumptions
/// * `points.len() = scalars.len()`
/// * `scalars[i].len() = scalars[j].len()` for all `i,j`
/// * `points` are all on the curve
/// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand
/// * The integer value of `scalars[i]` is less than the order of `points[i]`
/// * Output may be point at infinity, in which case (0, 0) is returned
///
/// # Panics
/// For non-empty `points`, panics if `max_scalar_bits_per_cell > F::NUM_BITS`, if
/// `points.len() != scalars.len()`, if `window_bits == 0`, or if some `scalars[i]` has a
/// different length than `scalars[0]`.
pub fn msm_par<F, FC, C>(
    chip: &EccChip<F, FC>,
    builder: &mut GateThreadBuilder<F>,
    points: &[C],
    scalars: Vec<Vec<AssignedValue<F>>>,
    max_scalar_bits_per_cell: usize,
    window_bits: usize,
    phase: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    F: PrimeField,
    C: CurveAffineExt,
    FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
{
    // an empty sum is the identity; from here on `points` is known to be non-empty
    if points.is_empty() {
        return chip.assign_constant_point(builder.main(phase), C::identity());
    }
    assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS);
    assert_eq!(points.len(), scalars.len());
    // without this check the window count below fails with an opaque "divide by zero" panic
    assert!(window_bits > 0, "fixed_base::msm_par requires window_bits > 0");
    let scalar_len = scalars[0].len();
    let total_bits = max_scalar_bits_per_cell * scalar_len;
    // ceil(total_bits / window_bits)
    let num_windows = (total_bits + window_bits - 1) / window_bits;

    // `cached_points` is a flattened 2d vector
    // first we compute all cached points in Jacobian coordinates since it's fastest
    let cached_points_jacobian = points
        .par_iter()
        .flat_map(|point| -> Vec<_> {
            let base_pt = point.to_curve();
            // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {1, ..., 2^w - 1}
            // EXCEPT that index j = 0 of every window holds the dummy value `[2^(i * w)] * points[idx]`
            // `increment` is `[2^(i * w)] * points[idx]` at the start of window `i`
            let mut increment = base_pt;
            (0..num_windows)
                .flat_map(|i| {
                    let mut curr = increment;
                    let cache_vec = std::iter::once(increment)
                        .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(
                            |_| {
                                let prev = curr;
                                curr += increment;
                                prev
                            },
                        ))
                        .collect::<Vec<_>>();
                    // after the loop `curr` is the base of the next window
                    increment = curr;
                    cache_vec
                })
                .collect()
        })
        .collect::<Vec<_>>();
    // for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive
    // initialize to all 0s
    let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
    C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

    let field_chip = chip.field_chip();
    let ctx = builder.main(phase);
    // every per-point accumulator below starts at this same randomly loaded point
    let any_point = chip.load_random_point::<C>(ctx);

    let scalar_mults = parallelize_in(
        phase,
        builder,
        // every point contributes the same number of cached points, so equal-sized chunks
        // split the flattened table back into one table per point
        cached_points_affine
            .chunks(cached_points_affine.len() / points.len())
            .zip_eq(scalars)
            .collect(),
        |ctx, (cached_points, scalar)| {
            let cached_points = cached_points
                .iter()
                .map(|point| chip.assign_constant_point(ctx, *point))
                .collect_vec();
            let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();

            assert_eq!(scalar.len(), scalar_len);
            let bits = scalar
                .into_iter()
                .flat_map(|scalar_chunk| {
                    field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell)
                })
                .collect::<Vec<_>>();
            let bit_window_rev = bits.chunks(window_bits).rev();
            let mut curr_point = any_point.clone();
            for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
                // are we just adding a window of all 0s? if so, skip
                let is_zero_window = {
                    let sum = field_chip.gate().sum(ctx, bit_window.iter().copied());
                    field_chip.gate().is_zero(ctx, sum)
                };
                curr_point = {
                    let add_point =
                        ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window);
                    let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true);
                    ec_select(field_chip, ctx, curr_point, sum, is_zero_window)
                };
            }
            // the result is still offset by `any_point`; the offset is removed in the final sum below
            curr_point
        },
    );
    let ctx = builder.main(phase);
    // sum `scalar_mults` but take into account possibility of identity points
    // the running sum is offset by a second random point, which is subtracted at the very end
    let any_point2 = chip.load_random_point::<C>(ctx);
    let mut acc = any_point2.clone();
    for point in scalar_mults {
        let new_acc = chip.add_unequal(ctx, &acc, point, true);
        // each `point` started from `any_point`, so remove that offset after adding it
        acc = chip.sub_unequal(ctx, new_acc, &any_point, true);
    }
    ec_sub_strict(field_chip, ctx, acc, any_point2)
}