Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 71 additions & 27 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_impl_generic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,34 +771,78 @@ template <class T> constexpr field<T> field<T>::montgomery_mul(const field& othe
auto left = wasm_convert(data);
auto right = wasm_convert(other.data);
constexpr uint64_t mask = 0x1fffffff;
uint64_t temp_0 = 0;
uint64_t temp_1 = 0;
uint64_t temp_2 = 0;
uint64_t temp_3 = 0;
uint64_t temp_4 = 0;
uint64_t temp_5 = 0;
uint64_t temp_6 = 0;
uint64_t temp_7 = 0;
uint64_t temp_8 = 0;
uint64_t temp_9 = 0;
uint64_t temp_10 = 0;
uint64_t temp_11 = 0;
uint64_t temp_12 = 0;
uint64_t temp_13 = 0;
uint64_t temp_14 = 0;
uint64_t temp_15 = 0;
uint64_t temp_16 = 0;

// Perform a series of mul-adds and then reductions
wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
// Karatsuba multiplication: split 9 limbs into 5 (lo) + 4 (hi).
// P_lo = left[0..4] * right[0..4] (25 muls)
// P_hi = left[5..8] * right[5..8] (16 muls)
// P_cross = (left_lo + left_hi) * (right_lo + right_hi) (25 muls)
// P_mid = P_cross - P_lo - P_hi
// Total: 66 muls vs 81 for schoolbook 9x9.

// P_lo = left[0..4] * right[0..4] — 5x5 schoolbook
uint64_t pl0 = left[0] * right[0];
uint64_t pl1 = left[0] * right[1] + left[1] * right[0];
uint64_t pl2 = left[0] * right[2] + left[1] * right[1] + left[2] * right[0];
uint64_t pl3 = left[0] * right[3] + left[1] * right[2] + left[2] * right[1] + left[3] * right[0];
uint64_t pl4 =
left[0] * right[4] + left[1] * right[3] + left[2] * right[2] + left[3] * right[1] + left[4] * right[0];
uint64_t pl5 = left[1] * right[4] + left[2] * right[3] + left[3] * right[2] + left[4] * right[1];
uint64_t pl6 = left[2] * right[4] + left[3] * right[3] + left[4] * right[2];
uint64_t pl7 = left[3] * right[4] + left[4] * right[3];
uint64_t pl8 = left[4] * right[4];

// P_hi = left[5..8] * right[5..8] — 4x4 schoolbook
uint64_t ph0 = left[5] * right[5];
uint64_t ph1 = left[5] * right[6] + left[6] * right[5];
uint64_t ph2 = left[5] * right[7] + left[6] * right[6] + left[7] * right[5];
uint64_t ph3 = left[5] * right[8] + left[6] * right[7] + left[7] * right[6] + left[8] * right[5];
uint64_t ph4 = left[6] * right[8] + left[7] * right[7] + left[8] * right[6];
uint64_t ph5 = left[7] * right[8] + left[8] * right[7];
uint64_t ph6 = left[8] * right[8];

// Sums for the cross product (left_lo + left_hi, right_lo + right_hi)
uint64_t sl0 = left[0] + left[5];
uint64_t sl1 = left[1] + left[6];
uint64_t sl2 = left[2] + left[7];
uint64_t sl3 = left[3] + left[8];
uint64_t sl4 = left[4];
uint64_t sr0 = right[0] + right[5];
uint64_t sr1 = right[1] + right[6];
uint64_t sr2 = right[2] + right[7];
uint64_t sr3 = right[3] + right[8];
uint64_t sr4 = right[4];

// P_cross = sum_left * sum_right — 5x5 schoolbook
uint64_t pc0 = sl0 * sr0;
uint64_t pc1 = sl0 * sr1 + sl1 * sr0;
uint64_t pc2 = sl0 * sr2 + sl1 * sr1 + sl2 * sr0;
uint64_t pc3 = sl0 * sr3 + sl1 * sr2 + sl2 * sr1 + sl3 * sr0;
uint64_t pc4 = sl0 * sr4 + sl1 * sr3 + sl2 * sr2 + sl3 * sr1 + sl4 * sr0;
uint64_t pc5 = sl1 * sr4 + sl2 * sr3 + sl3 * sr2 + sl4 * sr1;
uint64_t pc6 = sl2 * sr4 + sl3 * sr3 + sl4 * sr2;
uint64_t pc7 = sl3 * sr4 + sl4 * sr3;
uint64_t pc8 = sl4 * sr4;

// Combine: temp[k] = P_lo[k] + P_mid[k-5] + P_hi[k-10]
// where P_mid = P_cross - P_lo - P_hi
uint64_t temp_0 = pl0;
uint64_t temp_1 = pl1;
uint64_t temp_2 = pl2;
uint64_t temp_3 = pl3;
uint64_t temp_4 = pl4;
uint64_t temp_5 = pl5 + (pc0 - pl0 - ph0);
uint64_t temp_6 = pl6 + (pc1 - pl1 - ph1);
uint64_t temp_7 = pl7 + (pc2 - pl2 - ph2);
uint64_t temp_8 = pl8 + (pc3 - pl3 - ph3);
uint64_t temp_9 = pc4 - pl4 - ph4;
uint64_t temp_10 = (pc5 - pl5 - ph5) + ph0;
uint64_t temp_11 = (pc6 - pl6 - ph6) + ph1;
uint64_t temp_12 = (pc7 - pl7) + ph2;
uint64_t temp_13 = (pc8 - pl8) + ph3;
uint64_t temp_14 = ph4;
uint64_t temp_15 = ph5;
uint64_t temp_16 = ph6;

// At this point, the value aR * bR is contained in \sum_{i=0}^16 temp_{i}*2^{29*i}. Note that this value is no
// greater than 4p^2 as aR and bR are both less than 2p.
wasm_reduce_yuval(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
Expand Down
Loading