Skip to content

Commit bd9af8e

Browse files
committed
table-aware constraint system
This is the last PR in the series that implements [CRY-363: Table-Aware Constraint System](https://linear.app/irreducible/issue/CRY-363/table-aware-constraint-system) In this PR we introduce a structure called SymbolicMultilinearOracleSet. This structure is very similar to MultilinearOracleSet with the exception that it does not have any information about the sizes (n_vars) of the declared multilinears. This symbolic representation is what defines the ConstraintSystem, instead of the concrete one. That, and other changes in this PR, make constraint system serialization independent of the sizes of tables. This PR avoids making large changes to the proving path and to achieve that I decided to take the strategy in which the multilinears for zero-sized tables are not created at all.
1 parent b7cb7b7 commit bd9af8e

File tree

14 files changed

+801
-182
lines changed

14 files changed

+801
-182
lines changed

crates/core/src/constraint_system/channel.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,8 @@ where
129129
ref selectors,
130130
multiplicity,
131131
table_id,
132-
log_values_per_row,
132+
log_values_per_row: _,
133133
} = flush;
134-
let _ = log_values_per_row;
135134

136135
if channel_id > max_channel_id {
137136
return Err(Error::ChannelIdOutOfRange {
@@ -141,6 +140,9 @@ where
141140
}
142141

143142
let table_size = table_sizes[table_id];
143+
if table_size == 0 {
144+
continue;
145+
}
144146

145147
// We check the variables only of OracleOrConst::Oracle variant oracles being the same.
146148
let non_const_polys = oracles

crates/core/src/constraint_system/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use exp::Exp;
2020
pub use prove::prove;
2121
pub use verify::verify;
2222

23-
use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId};
23+
use crate::oracle::{ConstraintSet, OracleId, SymbolicMultilinearOracleSet};
2424

2525
/// Contains the 3 things that place constraints on witness data in Binius
2626
/// - virtual oracles
@@ -32,7 +32,7 @@ use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId};
3232
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
3333
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
3434
pub struct ConstraintSystem<F: TowerField> {
35-
pub oracles: MultilinearOracleSet<F>,
35+
pub oracles: SymbolicMultilinearOracleSet<F>,
3636
pub table_constraints: Vec<ConstraintSet<F>>,
3737
pub non_zero_oracle_ids: Vec<OracleId>,
3838
pub flushes: Vec<Flush<F>>,

crates/core/src/constraint_system/prove.rs

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ where
108108
let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();
109109

110110
let ConstraintSystem {
111-
mut oracles,
111+
oracles,
112112
table_constraints,
113113
mut flushes,
114114
mut exponents,
115-
non_zero_oracle_ids,
115+
mut non_zero_oracle_ids,
116116
channel_count,
117117
table_size_specs,
118118
} = constraint_system.clone();
@@ -155,6 +155,33 @@ where
155155
let mut writer = transcript.message();
156156
writer.write_slice(table_sizes);
157157

158+
let mut oracles = oracles.instantiate(table_sizes)?;
159+
160+
// Prepare the constraint system for proving:
161+
//
162+
// - Trim all the zero sized oracles.
163+
// - Canonicalize the ordering.
164+
165+
flushes.retain(|flush| table_sizes[flush.table_id] > 0);
166+
flushes.sort_by_key(|flush| flush.channel_id);
167+
168+
non_zero_oracle_ids.retain(|oracle| !oracles.is_zero_sized(*oracle));
169+
exponents.retain(|exp| !oracles.is_zero_sized(exp.exp_result_id));
170+
171+
let mut table_constraints = table_constraints
172+
.into_iter()
173+
.filter_map(|u| {
174+
if table_sizes[u.table_id] == 0 {
175+
None
176+
} else {
177+
let n_vars = u.log_values_per_row + log2_ceil_usize(table_sizes[u.table_id]);
178+
Some(SizedConstraintSet::new(n_vars, u))
179+
}
180+
})
181+
.collect::<Vec<_>>();
182+
// Stable sort constraint sets in ascending order by number of variables.
183+
table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
184+
158185
reorder_exponents(&mut exponents, &oracles);
159186

160187
let witness_span = tracing::info_span!(
@@ -177,20 +204,6 @@ where
177204

178205
drop(witness_span);
179206

180-
let mut table_constraints = table_constraints
181-
.into_iter()
182-
.filter_map(|u| {
183-
if table_sizes[u.table_id] == 0 {
184-
None
185-
} else {
186-
let n_vars = u.log_values_per_row + log2_ceil_usize(table_sizes[u.table_id]);
187-
Some(SizedConstraintSet::new(n_vars, u))
188-
}
189-
})
190-
.collect::<Vec<_>>();
191-
// Stable sort constraint sets in ascending order by number of variables.
192-
table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
193-
194207
// Commit polynomials
195208
let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
196209
let merkle_scheme = merkle_prover.scheme();

crates/core/src/constraint_system/validate.rs

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use super::{
1212
error::Error,
1313
};
1414
use crate::{
15+
constraint_system::TableSizeSpec,
1516
oracle::{
1617
ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
1718
ShiftVariant,
@@ -31,8 +32,58 @@ where
3132
P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
3233
F: TowerField,
3334
{
35+
let ConstraintSystem {
36+
oracles,
37+
table_constraints,
38+
non_zero_oracle_ids,
39+
flushes,
40+
channel_count,
41+
table_size_specs,
42+
exponents: _,
43+
} = constraint_system;
44+
45+
if table_sizes.len() != table_size_specs.len() {
46+
return Err(Error::TableSizesLenMismatch {
47+
expected: table_size_specs.len(),
48+
got: table_sizes.len(),
49+
});
50+
}
51+
for (table_id, (&table_size, table_size_spec)) in
52+
table_sizes.iter().zip(table_size_specs.iter()).enumerate()
53+
{
54+
match table_size_spec {
55+
TableSizeSpec::PowerOfTwo => {
56+
if !table_size.is_power_of_two() {
57+
return Err(Error::TableSizePowerOfTwoRequired {
58+
table_id,
59+
size: table_size,
60+
});
61+
}
62+
}
63+
TableSizeSpec::Fixed { log_size } => {
64+
if table_size != 1 << log_size {
65+
return Err(Error::TableSizeFixedRequired {
66+
table_id,
67+
size: table_size,
68+
});
69+
}
70+
}
71+
TableSizeSpec::Arbitrary => (),
72+
}
73+
}
74+
75+
let unsized_oracles = oracles;
76+
let oracles = unsized_oracles.instantiate(table_sizes)?;
77+
3478
// Check the constraint sets
35-
for constraint_set in &constraint_system.table_constraints {
79+
for constraint_set in table_constraints {
80+
if table_sizes[constraint_set.table_id] == 0 {
81+
continue;
82+
}
83+
if constraint_set.oracle_ids.is_empty() {
84+
continue;
85+
}
86+
3687
let multilinears = constraint_set
3788
.oracle_ids
3889
.iter()
@@ -56,25 +107,15 @@ where
56107
}
57108

58109
// Check that nonzero oracles are non-zero over the entire hypercube
59-
nonzerocheck::validate_witness(
60-
witness,
61-
&constraint_system.oracles,
62-
&constraint_system.non_zero_oracle_ids,
63-
)?;
110+
nonzerocheck::validate_witness(witness, &oracles, non_zero_oracle_ids)?;
64111

65112
// Check that the channels balance with flushes and boundaries
66-
channel::validate_witness(
67-
witness,
68-
&constraint_system.flushes,
69-
boundaries,
70-
table_sizes,
71-
constraint_system.channel_count,
72-
)?;
113+
channel::validate_witness(witness, flushes, boundaries, table_sizes, *channel_count)?;
73114

74115
// Check consistency of virtual oracle witnesses (eg. that shift polynomials are actually
75116
// shifts).
76-
for oracle in constraint_system.oracles.polys() {
77-
validate_virtual_oracle_witness(oracle, &constraint_system.oracles, witness)?;
117+
for oracle in oracles.polys() {
118+
validate_virtual_oracle_witness(oracle, &oracles, witness)?;
78119
}
79120

80121
Ok(())
@@ -321,6 +362,9 @@ pub mod nonzerocheck {
321362
F: TowerField,
322363
{
323364
oracle_ids.into_par_iter().try_for_each(|id| {
365+
if oracles.is_zero_sized(*id) {
366+
return Ok(());
367+
}
324368
let multilinear = witness.get_multilin_poly(*id)?;
325369
(0..(1 << multilinear.n_vars()))
326370
.into_par_iter()

crates/core/src/constraint_system/verify.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,11 @@ where
5959
Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
6060
Challenger_: Challenger + Default,
6161
{
62-
let _ = constraint_system_digest;
6362
let ConstraintSystem {
64-
mut oracles,
63+
oracles,
6564
table_constraints,
6665
mut flushes,
67-
non_zero_oracle_ids,
66+
mut non_zero_oracle_ids,
6867
channel_count,
6968
mut exponents,
7069
table_size_specs,
@@ -107,6 +106,14 @@ where
107106
}
108107
}
109108

109+
let mut oracles = oracles.instantiate(&table_sizes)?;
110+
111+
flushes.retain(|flush| table_sizes[flush.table_id] > 0);
112+
flushes.sort_by_key(|flush| flush.channel_id);
113+
114+
non_zero_oracle_ids.retain(|oracle| !oracles.is_zero_sized(*oracle));
115+
exponents.retain(|exp| !oracles.is_zero_sized(exp.exp_result_id));
116+
110117
let mut table_constraints = table_constraints
111118
.into_iter()
112119
.filter_map(|u| {

crates/core/src/oracle/constraint.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ impl<F: Field> SizedConstraintSet<F> {
6363
pub struct ConstraintSet<F: Field> {
6464
pub table_id: TableId,
6565
pub log_values_per_row: usize,
66-
pub n_vars: usize,
6766
pub oracle_ids: Vec<OracleId>,
6867
pub constraints: Vec<Constraint<F>>,
6968
}

crates/core/src/oracle/error.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// Copyright 2024-2025 Irreducible Inc.
22

3-
use crate::oracle::OracleId;
3+
use crate::{constraint_system::TableId, oracle::OracleId};
44

55
#[derive(Debug, thiserror::Error)]
66
pub enum Error {
@@ -32,4 +32,6 @@ pub enum Error {
3232
"expected constraint set to contain only constraints with n_vars={expected}, but found n_vars={got}"
3333
)]
3434
ConstraintSetNvarsMismatch { got: usize, expected: usize },
35+
#[error("table size for table {table_id} is missing")]
36+
TableSizeMissing { table_id: TableId },
3537
}

crates/core/src/oracle/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ mod constraint;
1111
mod error;
1212
mod multilinear;
1313
mod oracle_id;
14+
mod symbolic;
1415

1516
pub use composite::*;
1617
pub use constraint::*;
1718
pub use error::Error;
1819
pub use multilinear::*;
1920
pub use oracle_id::*;
21+
pub use symbolic::*;

crates/core/src/oracle/multilinear.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -353,32 +353,42 @@ impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
353353
#[derive(Default, Debug, Clone, SerializeBytes, DeserializeBytes)]
354354
#[deserialize_bytes(eval_generics(F = BinaryField128b))]
355355
pub struct MultilinearOracleSet<F: TowerField> {
356-
oracles: Vec<MultilinearPolyOracle<F>>,
356+
/// The vector of oracles.
357+
///
358+
/// During sizing, an oracle could be skipped. In which case, the entry corresponding to the
359+
/// oracle will be `None`.
360+
oracles: Vec<Option<MultilinearPolyOracle<F>>>,
361+
/// The number of non-`None` entries in `oracles`.
362+
size: usize,
357363
}
358364

359365
impl<F: TowerField> MultilinearOracleSet<F> {
360366
pub const fn new() -> Self {
361367
Self {
362368
oracles: Vec::new(),
369+
size: 0,
363370
}
364371
}
365372

366373
pub fn size(&self) -> usize {
367-
self.oracles.len()
374+
self.size
368375
}
369376

370377
pub fn polys(&self) -> impl Iterator<Item = &MultilinearPolyOracle<F>> + '_ {
371-
(0..self.oracles.len()).map(|index| &self[OracleId::from_index(index)])
378+
self.iter().map(|(_, poly)| poly)
372379
}
373380

374381
pub fn ids(&self) -> impl Iterator<Item = OracleId> {
375-
(0..self.oracles.len()).map(OracleId::from_index)
382+
self.iter().map(|(id, _)| id)
376383
}
377384

378385
pub fn iter(&self) -> impl Iterator<Item = (OracleId, &MultilinearPolyOracle<F>)> + '_ {
379-
(0..self.oracles.len()).map(|index| {
380-
let oracle_id = OracleId::from_index(index);
381-
(oracle_id, &self[oracle_id])
386+
(0..self.oracles.len()).filter_map(|index| match self.oracles[index] {
387+
Some(ref oracle) => {
388+
let oracle_id = OracleId::from_index(index);
389+
Some((oracle_id, oracle))
390+
}
391+
None => None,
382392
})
383393
}
384394

@@ -400,15 +410,23 @@ impl<F: TowerField> MultilinearOracleSet<F> {
400410
id.index() < self.oracles.len()
401411
}
402412

403-
fn add_to_set(
413+
pub(crate) fn add_to_set(
404414
&mut self,
405415
oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
406416
) -> OracleId {
407417
let id = OracleId::from_index(self.oracles.len());
408-
self.oracles.push(oracle(id));
418+
self.oracles.push(Some(oracle(id)));
419+
self.size += 1;
409420
id
410421
}
411422

423+
/// Instead of adding a concrete oracle adds a skip mark, essentially just reserving an
424+
/// [`OracleId`]. Accessing such oracle, with some exceptions, is an error.
425+
pub(crate) fn add_skip(&mut self) {
426+
self.oracles.push(None);
427+
// don't increment size
428+
}
429+
412430
pub fn add_transparent(
413431
&mut self,
414432
poly: impl MultivariatePoly<F> + 'static,
@@ -515,13 +533,21 @@ impl<F: TowerField> MultilinearOracleSet<F> {
515533
pub fn tower_level(&self, id: OracleId) -> usize {
516534
self[id].binary_tower_level()
517535
}
536+
537+
/// Returns `true` if the given [`OracleId`] refers to an oracle that was skipped during the
538+
/// instantiation of the symbolic multilinear oracle set.
539+
pub fn is_zero_sized(&self, id: OracleId) -> bool {
540+
self.oracles[id.index()].is_none()
541+
}
518542
}
519543

520544
impl<F: TowerField> std::ops::Index<OracleId> for MultilinearOracleSet<F> {
521545
type Output = MultilinearPolyOracle<F>;
522546

523547
fn index(&self, id: OracleId) -> &Self::Output {
524-
&self.oracles[id.index()]
548+
self.oracles[id.index()]
549+
.as_ref()
550+
.expect("tried to access skipped oracle")
525551
}
526552
}
527553

0 commit comments

Comments
 (0)