Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
357 changes: 357 additions & 0 deletions extensions/riscv/circuit/src/base_alu_w/core.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,357 @@
use std::{
array,
borrow::{Borrow, BorrowMut},
iter::zip,
};

use openvm_circuit::{
arch::*,
system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
utils::not,
AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_riscv_transpiler::BaseAluOpcode;
use openvm_stark_backend::{
interaction::InteractionBuilder,
p3_air::{AirBuilder, BaseAir},
p3_field::{Field, FieldAlgebra, PrimeField32},
rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

/// Column layout for the base ALU core AIR (one row per instruction).
///
/// All limb values are little-endian with limb `0` least significant.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct BaseAluCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs, `a = op(b, c)`, each constrained to `[0, 2^LIMB_BITS)`.
    pub a: [T; NUM_LIMBS],
    /// First operand limbs.
    pub b: [T; NUM_LIMBS],
    /// Second operand limbs.
    pub c: [T; NUM_LIMBS],

    // One-hot opcode selectors: `eval` constrains each flag to be boolean and
    // their sum to be boolean, so at most one is set; all-zero marks a padding row.
    pub opcode_add_flag: T,
    pub opcode_sub_flag: T,
    pub opcode_xor_flag: T,
    pub opcode_or_flag: T,
    pub opcode_and_flag: T,
}

/// Core AIR for the base ALU instruction class (ADD/SUB/XOR/OR/AND) over
/// `NUM_LIMBS` limbs of `LIMB_BITS` bits each.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BaseAluCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bus to the shared bitwise-operation lookup, used both for XOR lookups
    /// and (per the comments in `eval`) for range-checking result limbs.
    pub bus: BitwiseOperationLookupBus,
    /// Global opcode offset of this instruction class; subtracted from the
    /// instruction opcode to recover the local `BaseAluOpcode`.
    pub offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// Trace width of the core AIR: one column per field of `BaseAluCoreCols`.
    fn width(&self) -> usize {
        BaseAluCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// Marker impl: this AIR exposes no public values (default trait methods suffice).
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains `a = op(b, c)` limb-wise for the five base ALU operations,
    /// where `op` is selected by the one-hot opcode flags.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_add_flag,
            cols.opcode_sub_flag,
            cols.opcode_xor_flag,
            cols.opcode_or_flag,
            cols.opcode_and_flag,
        ];

        // Each flag is boolean and their sum is boolean, so at most one flag is
        // set; the sum doubles as the row's is_valid indicator (0 on padding rows).
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;

        // For ADD, define carry[i] = (b[i] + c[i] + carry[i - 1] - a[i]) / 2^LIMB_BITS. If
        // each carry[i] is boolean and 0 <= a[i] < 2^LIMB_BITS, it can be proven that
        // a[i] = (b[i] + c[i]) % 2^LIMB_BITS as necessary. The same holds for SUB when
        // carry[i] is (a[i] + c[i] - b[i] + carry[i - 1]) / 2^LIMB_BITS.
        let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        // Division by 2^LIMB_BITS is expressed as multiplication by its field inverse.
        let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();

        for i in 0..NUM_LIMBS {
            // We explicitly separate the constraints for ADD and SUB in order to keep degree
            // cubic. Because we constrain that the carry (which is arbitrary) is bool, if
            // carry has degree larger than 1 the max-degree constrain could be at least 4.
            carry_add[i] = AB::Expr::from(carry_divide)
                * (b[i] + c[i] - a[i]
                    + if i > 0 {
                        carry_add[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_add_flag)
                .assert_bool(carry_add[i].clone());
            carry_sub[i] = AB::Expr::from(carry_divide)
                * (a[i] + c[i] - b[i]
                    + if i > 0 {
                        carry_sub[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_sub_flag)
                .assert_bool(carry_sub[i].clone());
        }

        // Interaction with BitwiseOperationLookup to range check a for ADD and SUB, and
        // constrain a's correctness for XOR, OR, and AND.
        //
        // When bitwise = 0 (ADD/SUB): x = y = a[i] and x_xor_y = 0, which is consistent
        // (a ^ a = 0) and the lookup range-checks a[i].
        // When bitwise = 1: x = b[i], y = c[i], and x_xor_y encodes b ^ c using the
        // identity b + c = (b ^ c) + 2*(b & c):
        //   XOR: b ^ c = a;   OR: b ^ c = 2a - b - c;   AND: b ^ c = b + c - 2a.
        let bitwise = cols.opcode_xor_flag + cols.opcode_or_flag + cols.opcode_and_flag;
        for i in 0..NUM_LIMBS {
            let x = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * b[i];
            let y = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * c[i];
            let x_xor_y = cols.opcode_xor_flag * a[i]
                + cols.opcode_or_flag * ((AB::Expr::from_canonical_u32(2) * a[i]) - b[i] - c[i])
                + cols.opcode_and_flag * (b[i] + c[i] - (AB::Expr::from_canonical_u32(2) * a[i]));
            self.bus
                .send_xor(x, y, x_xor_y)
                .eval(builder, is_valid.clone());
        }

        // Recover the local opcode as a linear combination of the one-hot flags
        // and lift it to the global opcode space.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags.iter().zip(BaseAluOpcode::iter()).fold(
                AB::Expr::ZERO,
                |acc, (flag, local_opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
                },
            ),
        );

        AdapterAirContext {
            // No pc override: these ops fall through to the next instruction
            // (the executor advances pc by DEFAULT_PC_STEP).
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Global opcode offset of this instruction class.
    fn start_offset(&self) -> usize {
        self.offset
    }
}

/// Minimal per-instruction record captured during execution; the result `a` is
/// not stored because `fill_trace_row` recomputes it from `b`, `c`, and the opcode.
#[repr(C, align(4))]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BaseAluCoreRecord<const NUM_LIMBS: usize> {
    /// First operand limbs (little-endian bytes).
    pub b: [u8; NUM_LIMBS],
    /// Second operand limbs (little-endian bytes).
    pub c: [u8; NUM_LIMBS],
    // Use u8 instead of usize for better packing
    pub local_opcode: u8,
}

/// Preflight executor for base ALU instructions: performs the reads/writes via
/// the adapter `A` and records operands for later trace filling.
#[derive(Clone, Copy, derive_new::new)]
pub struct BaseAluExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Adapter handling the memory interactions for this instruction format.
    adapter: A,
    /// Global opcode offset of this instruction class.
    pub offset: usize,
}

/// Trace filler that expands `BaseAluCoreRecord`s into `BaseAluCoreCols` rows
/// and issues the bitwise-lookup requests matching the AIR's interactions.
#[derive(derive_new::new)]
pub struct BaseAluFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Adapter filling the adapter portion of each trace row.
    adapter: A,
    /// Shared lookup chip receiving XOR requests (also used to range check
    /// ADD/SUB results, mirroring `BaseAluCoreAir::eval`).
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    /// Global opcode offset of this instruction class.
    pub offset: usize,
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for BaseAluExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (A::RecordMut<'buf>, &'buf mut BaseAluCoreRecord<NUM_LIMBS>),
    >,
{
    /// Debug name of the local opcode, e.g. `"ADD"`.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset))
    }

    /// Executes one base ALU instruction: reads operands `b` and `c` through the
    /// adapter, computes the result, writes it back, records `(b, c, local_opcode)`
    /// in the arena, and advances the pc by `DEFAULT_PC_STEP`.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        // Allocate the (adapter, core) record pair up front; its fields are filled
        // in as execution proceeds.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // Operand reads go straight into the core record, avoiding a copy.
        [core_record.b, core_record.c] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        let rd = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &core_record.b, &core_record.c);

        core_record.local_opcode = local_opcode as u8;

        self.adapter
            .write(state.memory, instruction, [rd].into(), &mut adapter_record);

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for BaseAluFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Expands the packed `BaseAluCoreRecord` stored at the start of the core
    /// portion of `row_slice` into a full `BaseAluCoreCols` row of field elements,
    /// in place, and issues the bitwise-lookup requests the AIR expects.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // BaseAluCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid BaseAluCoreRecord written by the executor
        // during trace generation
        let record: &BaseAluCoreRecord<NUM_LIMBS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut BaseAluCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
        // SAFETY: the following is highly unsafe. We are going to cast `core_row` to a record
        // buffer, and then do an _overlapping_ write to the `core_row` as a row of field elements.
        // This requires:
        // - Cols and Record structs should be repr(C) and we write in reverse order (to ensure
        //   non-overlapping)
        // - Do not overwrite any reference in `record` before it has already been used or moved
        // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic
        //   otherwise)

        let local_opcode = BaseAluOpcode::from_usize(record.local_opcode as usize);
        // Recompute the result from the recorded operands (it is not stored in the record).
        let a = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &record.b, &record.c);
        // PERF: needless conversion
        // Flags are the last fields of the Cols struct, so writing them first keeps the
        // record bytes at the front of the buffer intact (reverse-order writes, see above).
        core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND);
        core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR);
        core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR);
        core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB);
        core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD);

        // Mirror the AIR's lookup interactions: ADD/SUB range-check the result limbs
        // via a self-XOR; XOR/OR/AND look up b ^ c.
        if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB {
            for a_val in a {
                self.bitwise_lookup_chip
                    .request_xor(a_val as u32, a_val as u32);
            }
        } else {
            for (b_val, c_val) in zip(record.b, record.c) {
                self.bitwise_lookup_chip
                    .request_xor(b_val as u32, c_val as u32);
            }
        }
        // Record fully consumed above; now safe to overwrite the remaining (overlapping)
        // limb columns, still in reverse field order: c, then b, then a.
        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = a.map(F::from_canonical_u8);
    }
}

/// Computes `op(x, y)` limb-wise over little-endian limbs (limb 0 is least
/// significant, as established by the carry direction in `run_add`/`run_subtract`).
///
/// Limbs are stored as `u8`, hence the `LIMB_BITS <= 8` debug assertion.
#[inline(always)]
pub(super) fn run_alu<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    opcode: BaseAluOpcode,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    debug_assert!(LIMB_BITS <= 8, "specialize for bytes");
    match opcode {
        BaseAluOpcode::ADD => run_add::<NUM_LIMBS, LIMB_BITS>(x, y),
        BaseAluOpcode::SUB => run_subtract::<NUM_LIMBS, LIMB_BITS>(x, y),
        BaseAluOpcode::XOR => run_xor::<NUM_LIMBS>(x, y),
        BaseAluOpcode::OR => run_or::<NUM_LIMBS>(x, y),
        BaseAluOpcode::AND => run_and::<NUM_LIMBS>(x, y),
    }
}

/// Limb-wise addition modulo `2^(NUM_LIMBS * LIMB_BITS)`, little-endian.
///
/// Each output limb is `(x[i] + y[i] + carry_in) % 2^LIMB_BITS` with the carry
/// propagating to the next limb; the final carry out is discarded (wrapping add).
///
/// Only the previous limb's carry is ever read, so a scalar accumulator replaces
/// the original `[u8; NUM_LIMBS]` carry array.
#[inline(always)]
fn run_add<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    let limb_mask = (1u16 << LIMB_BITS) - 1;
    let mut z = [0u8; NUM_LIMBS];
    let mut carry = 0u16; // carry into the current limb, always 0 or 1
    for i in 0..NUM_LIMBS {
        let sum = x[i] as u16 + y[i] as u16 + carry;
        carry = sum >> LIMB_BITS;
        z[i] = (sum & limb_mask) as u8;
    }
    z
}

/// Limb-wise subtraction modulo `2^(NUM_LIMBS * LIMB_BITS)`, little-endian.
///
/// Each output limb is `(x[i] - y[i] - borrow_in) mod 2^LIMB_BITS` with the
/// borrow propagating to the next limb; the final borrow is discarded
/// (wrapping subtract), so `x < y` yields the two's-complement result.
///
/// Only the previous limb's borrow is ever read, so a scalar accumulator
/// replaces the original `[u8; NUM_LIMBS]` carry array; keeping the
/// subtrahend in `u16` also avoids the narrowing `rhs as u8` cast.
#[inline(always)]
fn run_subtract<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    let mut z = [0u8; NUM_LIMBS];
    let mut borrow = 0u16; // borrow into the current limb, always 0 or 1
    for i in 0..NUM_LIMBS {
        let lhs = x[i] as u16;
        let rhs = y[i] as u16 + borrow;
        if lhs >= rhs {
            z[i] = (lhs - rhs) as u8;
            borrow = 0;
        } else {
            // Borrow one unit (2^LIMB_BITS) from the next limb.
            z[i] = (lhs + (1u16 << LIMB_BITS) - rhs) as u8;
            borrow = 1;
        }
    }
    z
}

/// Limb-wise bitwise XOR of two little-endian byte vectors.
#[inline(always)]
fn run_xor<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    let mut out = [0u8; NUM_LIMBS];
    for (dst, (&a, &b)) in out.iter_mut().zip(zip(x, y)) {
        *dst = a ^ b;
    }
    out
}

/// Limb-wise bitwise OR of two little-endian byte vectors.
#[inline(always)]
fn run_or<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    let mut out = *x;
    for (i, limb) in out.iter_mut().enumerate() {
        *limb |= y[i];
    }
    out
}

/// Limb-wise bitwise AND of two little-endian byte vectors.
#[inline(always)]
fn run_and<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    let mut z = [0u8; NUM_LIMBS];
    for i in 0..NUM_LIMBS {
        z[i] = x[i] & y[i];
    }
    z
}
Loading
Loading