"""Constraint DAG evaluation and STARK constraint verification.
Evaluates the SymbolicExpressionDag at the OOD (out-of-domain) point and checks
that folded_constraints * inv_zeroifier == quotient for each AIR.
Reference:
stark-backend/src/verifier/constraints.rs -- verify_single_rap_constraints
stark-backend/src/verifier/folder.rs -- GenericVerifierConstraintFolder
stark-backend/src/air_builders/symbolic/symbolic_expression.rs -- SymbolicEvaluator::eval_nodes
stark-backend/src/air_builders/symbolic/dag.rs -- SymbolicExpressionNode
"""
from __future__ import annotations
import numpy as np
from primitives.field import (
FF4,
FF,
Fe,
)
from protocol.domain import (
DomainSelectors,
TwoAdicMultiplicativeCoset,
)
from protocol.proof import (
AdjacentOpenedValues,
EntryType,
SymbolicExpressionDag,
SymbolicNodeKind,
SymbolicVariable,
)
# ---------------------------------------------------------------------------
# Unflatten: reconstitute extension field elements from flattened base field
# ---------------------------------------------------------------------------
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
def unflatten_ext_values(flattened: list) -> list[FF4]:
    """Regroup base-field-flattened challenge values into proper FF4 elements.

    The Rust verifier stores after_challenge (permutation) trace openings as
    "flattened" extension field elements: one FF4 element becomes 4 consecutive
    Challenge (== FF4) values, each carrying only its base component. Every
    group of 4 values [c0, c1, c2, c3] is recombined on the monomial basis:

        result = c0 * x^0 + c1 * x^1 + c2 * x^2 + c3 * x^3

    where monomial(e) is the FF4 element whose only nonzero coefficient is a
    1 in position e (for BabyBear degree-4 extension with irreducible
    x^4 - 11, these are [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]).
    Each c_i is treated as a full FF4 element, so monomial(e) * c_i uses
    extension field multiplication, matching the Rust verifier exactly.

    Reference:
        stark-backend/src/verifier/constraints.rs lines 65-75 (unflatten closure)
    """
    degree = 4  # extension degree of FF4 over the base field
    assert len(flattened) % degree == 0, (
        f"Flattened length {len(flattened)} not divisible by ext degree {degree}"
    )
    regrouped: list[FF4] = []
    for start in range(0, len(flattened), degree):
        group = flattened[start : start + degree]
        total = FF4.zero()
        for pos in range(degree):
            # Monomial basis element x^pos: a single 1 in coefficient `pos`.
            basis = [1 if k == pos else 0 for k in range(degree)]
            total = total + FF4(group[pos]) * FF4(basis)
        regrouped.append(total)
    return regrouped
# ---------------------------------------------------------------------------
# DAG evaluation
# ---------------------------------------------------------------------------
# <doc-anchor id="eval-symbolic-dag">
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
def eval_symbolic_dag(
    dag: SymbolicExpressionDag,
    selectors: DomainSelectors,
    preprocessed_local: list,
    preprocessed_next: list,
    partitioned_main_values: list[AdjacentOpenedValues],
    after_challenge_values: list[AdjacentOpenedValues],
    challenges: list[list],
    public_values: list[Fe],
    exposed_values_after_challenge: list[list],
) -> list[FF4]:
    """Evaluate the symbolic expression DAG and return constraint values.

    Walks dag.nodes in order (nodes are already topologically sorted, so a
    binary/unary node's operands are always evaluated before it) and computes
    an FF4 value per node. Returns the values at the constraint output
    indices (dag.constraint_idx).

    Reference:
        stark-backend/src/air_builders/symbolic/symbolic_expression.rs
            SymbolicEvaluator::eval_nodes (lines 364-396)
        stark-backend/src/verifier/folder.rs
            GenericVerifierConstraintFolder impl of SymbolicEvaluator (lines 74-123)
    """
    evaluated: list[FF4] = []
    for node in dag.nodes:
        kind = node.kind
        if kind == SymbolicNodeKind.ADD:
            out = evaluated[node.left_idx] + evaluated[node.right_idx]
        elif kind == SymbolicNodeKind.SUB:
            out = evaluated[node.left_idx] - evaluated[node.right_idx]
        elif kind == SymbolicNodeKind.MUL:
            out = evaluated[node.left_idx] * evaluated[node.right_idx]
        elif kind == SymbolicNodeKind.NEG:
            out = -evaluated[node.idx]
        elif kind == SymbolicNodeKind.CONSTANT:
            out = FF4(node.constant_value)
        elif kind == SymbolicNodeKind.IS_FIRST_ROW:
            out = selectors.is_first_row
        elif kind == SymbolicNodeKind.IS_LAST_ROW:
            out = selectors.is_last_row
        elif kind == SymbolicNodeKind.IS_TRANSITION:
            out = selectors.is_transition
        elif kind == SymbolicNodeKind.VARIABLE:
            out = _lookup_variable(
                node.variable,
                preprocessed_local,
                preprocessed_next,
                partitioned_main_values,
                after_challenge_values,
                challenges,
                public_values,
                exposed_values_after_challenge,
            )
        else:
            raise ValueError(f"Unknown SymbolicNodeKind: {kind}")
        evaluated.append(out)
    return [evaluated[c] for c in dag.constraint_idx]
def _lookup_variable(
    var: SymbolicVariable,
    preprocessed_local: list,
    preprocessed_next: list,
    partitioned_main_values: list[AdjacentOpenedValues],
    after_challenge_values: list[AdjacentOpenedValues],
    challenges: list[list],
    public_values: list[Fe],
    exposed_values_after_challenge: list[list],
) -> FF4:
    """Look up a symbolic variable's value from the opened proof values.

    Maps each Entry type to the appropriate opened values slice:
    - Preprocessed{offset}: preprocessed local (offset=0) or next (offset=1)
    - Main{part_index, offset}: partitioned_main[part_index].local/next[index]
    - Public: public_values[index] embedded into FF4
    - Permutation{offset}: after_challenge[0].local/next[index] (phase 0 only)
    - Challenge: challenges[0][index] (phase 0 only)
    - Exposed: exposed_values_after_challenge[0][index] (phase 0 only)

    Raises:
        ValueError: on an unrecognized entry kind.

    Reference:
        stark-backend/src/verifier/folder.rs lines 95-121
        (GenericVerifierConstraintFolder::eval_var)
    """
    entry = var.entry
    idx = var.index
    kind = entry.kind
    if kind == EntryType.PREPROCESSED:
        # offset 0 selects the current row, anything else the next row.
        row = preprocessed_local if entry.offset == 0 else preprocessed_next
        return FF4(row[idx])
    if kind == EntryType.MAIN:
        opened = partitioned_main_values[entry.part_index]
        row = opened.local if entry.offset == 0 else opened.next
        return FF4(row[idx])
    if kind == EntryType.PUBLIC:
        return FF4(public_values[idx])
    if kind == EntryType.PERMUTATION:
        # Only challenge phase 0 exists in this verifier.
        opened = after_challenge_values[0]
        row = opened.local if entry.offset == 0 else opened.next
        return FF4(row[idx])
    if kind == EntryType.CHALLENGE:
        return FF4(challenges[0][idx])
    if kind == EntryType.EXPOSED:
        return FF4(exposed_values_after_challenge[0][idx])
    raise ValueError(f"Unknown EntryType: {kind}")
# ---------------------------------------------------------------------------
# Constraint folding
# ---------------------------------------------------------------------------
# <doc-anchor id="fold-constraints">
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
def fold_constraints(
    constraint_evals: list[FF4],
    alpha: FF4,
) -> FF4:
    """Fold constraint evaluations into one value via Horner's method.

    Starting from zero, each constraint C_i updates the accumulator as
    acc = acc * alpha + C_i, which yields the random linear combination

        C_0 * alpha^{n-1} + C_1 * alpha^{n-2} + ... + C_{n-1}

    An empty input returns FF4.zero().

    Reference:
        stark-backend/src/verifier/folder.rs lines 56-72
        (GenericVerifierConstraintFolder::eval_constraints + assert_zero)
    """
    acc = FF4.zero()
    for term in constraint_evals:
        acc = acc * alpha + term
    return acc
# ---------------------------------------------------------------------------
# Quotient reconstruction
# ---------------------------------------------------------------------------
# <doc-anchor id="reconstruct-quotient">
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
def reconstruct_quotient(
    quotient_chunks: list[list],
    qc_domains: list[TwoAdicMultiplicativeCoset],
    zeta: FF4,
) -> FF4:
    """Recompute the full quotient polynomial value at zeta from chunks.

    For each chunk domain D_i, compute:

        zps[i] = prod_{j != i} Z_{D_j}(zeta) / Z_{D_j}(first_point(D_i))

    Then the full quotient at zeta is:

        Q(zeta) = sum_i zps[i] * sum_e (monomial(e) * chunk[i][e])

    Each quotient_chunks[i] has D=4 elements: the coefficients of a single
    FF4 element in the extension field basis, recombined via the monomial
    basis below.

    Args:
        quotient_chunks: per-chunk flattened FF4 coefficients (4 each).
        qc_domains: the coset domain associated with each chunk.
        zeta: the out-of-domain evaluation point.

    Returns:
        Q(zeta) as a single FF4 element.

    Reference:
        stark-backend/src/verifier/constraints.rs lines 38-63
    """
    num_chunks = len(qc_domains)
    assert len(quotient_chunks) == num_chunks
    # Hoisted loop invariants: Z_{D_j}(zeta) depends only on j and
    # first_point(D_i) only on i, but the original double loop recomputed
    # both inside the inner loop (O(n^2) redundant evaluations). Field
    # arithmetic is exact, so hoisting cannot change the result.
    zp_at_zeta = [d.vanishing_poly_at_point(zeta) for d in qc_domains]
    first_points = [FF4(d.first_point()) for d in qc_domains]
    zps: list[FF4] = []
    for i in range(num_chunks):
        prod = FF4.one()
        for j in range(num_chunks):
            if j != i:
                zp_at_first = qc_domains[j].vanishing_poly_at_point(first_points[i])
                prod = prod * (zp_at_zeta[j] / zp_at_first)
        zps.append(prod)
    result = FF4.zero()
    for ch_i in range(num_chunks):
        chunk = quotient_chunks[ch_i]
        chunk_sum = FF4.zero()
        for e_i, c in enumerate(chunk):
            # Monomial basis element x^e_i.
            monomial = [0, 0, 0, 0]
            monomial[e_i] = 1
            chunk_sum = chunk_sum + FF4(c) * FF4(monomial)
        result = result + zps[ch_i] * chunk_sum
    return result
# ---------------------------------------------------------------------------
# Vectorized DAG evaluation (all rows at once, base field)
# ---------------------------------------------------------------------------
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
def eval_dag_all_rows(
    dag: SymbolicExpressionDag,
    partitioned_main: list[list[list[Fe]]],
    preprocessed: list[list[Fe]] | None,
    public_values: list[Fe],
    height: int,
) -> list[FF | None]:
    """Evaluate the full DAG at ALL rows simultaneously using FF column arrays.

    Each DAG node evaluates to a base-field column vector of length `height`.
    Row selectors and entries with no base-field counterpart evaluate to
    all-zero columns. A next-row reference (offset != 0) is realized as a
    cyclic shift of the column via np.roll.

    Args:
        dag: SymbolicExpressionDag
        partitioned_main: [part_index][rows][cols] — list of trace matrices
        preprocessed: [rows][cols] or None
        public_values: list of Fe
        height: trace height

    Returns:
        List of field column arrays, one per DAG node.
    """

    def columns_of(matrix: list[list[Fe]]) -> list[FF]:
        # Transpose a row-major trace matrix into per-column FF arrays.
        width = len(matrix[0]) if matrix else 0
        return [
            FF([matrix[r][c] for r in range(height)])
            for c in range(width)
        ]

    main_cols = [columns_of(part) for part in partitioned_main]
    prep_cols = columns_of(preprocessed) if preprocessed is not None else None

    node_count = len(dag.nodes)
    out: list[FF | None] = [None] * node_count
    for i, node in enumerate(dag.nodes):
        kind = node.kind
        if kind == SymbolicNodeKind.VARIABLE:
            var = node.variable
            entry = var.entry
            if entry.kind == EntryType.MAIN:
                col = main_cols[entry.part_index][var.index]
                # offset 0 is the current row; otherwise rotate rows upward.
                out[i] = col if entry.offset == 0 else np.roll(col, -entry.offset)
            elif entry.kind == EntryType.PREPROCESSED:
                col = prep_cols[var.index]
                out[i] = col if entry.offset == 0 else np.roll(col, -entry.offset)
            elif entry.kind == EntryType.PUBLIC:
                out[i] = FF(np.full(height, int(public_values[var.index])))
            else:
                # Other entry kinds (e.g. permutation/challenge) are treated
                # as zero columns in this base-field pass.
                out[i] = FF.Zeros(height)
        elif kind == SymbolicNodeKind.CONSTANT:
            out[i] = FF(np.full(height, int(node.constant_value)))
        elif kind in (SymbolicNodeKind.IS_FIRST_ROW,
                      SymbolicNodeKind.IS_LAST_ROW,
                      SymbolicNodeKind.IS_TRANSITION):
            out[i] = FF.Zeros(height)
        elif kind == SymbolicNodeKind.ADD:
            out[i] = out[node.left_idx] + out[node.right_idx]
        elif kind == SymbolicNodeKind.SUB:
            out[i] = out[node.left_idx] - out[node.right_idx]
        elif kind == SymbolicNodeKind.MUL:
            out[i] = out[node.left_idx] * out[node.right_idx]
        elif kind == SymbolicNodeKind.NEG:
            out[i] = -out[node.idx]
        else:
            raise ValueError(f"Unknown node kind: {kind}")
    return out
# ---------------------------------------------------------------------------
# Verification error
# ---------------------------------------------------------------------------
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
class VerificationError(Exception):
    """Base exception raised when STARK constraint verification fails.

    Reference:
        stark-backend/src/verifier/error.rs (enum VerificationError)
    """
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
class OodEvaluationMismatch(VerificationError):
    """Out-of-domain evaluation mismatch.

    Raised when constraints(zeta) != quotient(zeta) * Z_H(zeta).

    Reference:
        stark-backend/src/verifier/error.rs (VerificationError::OodEvaluationMismatch)
    """
# ---------------------------------------------------------------------------
# Main per-AIR verification function
# ---------------------------------------------------------------------------
# <doc-anchor id="verify-single-rap">
# [docs]  (Sphinx export artifact; kept as comment so the module imports)
def verify_single_rap_constraints(
    constraints: SymbolicExpressionDag,
    preprocessed_values: AdjacentOpenedValues | None,
    partitioned_main_values: list[AdjacentOpenedValues],
    after_challenge_values: list[AdjacentOpenedValues],
    quotient_chunks: list[list],
    domain: TwoAdicMultiplicativeCoset,
    qc_domains: list[TwoAdicMultiplicativeCoset],
    zeta: FF4,
    alpha: FF4,
    challenges: list[list],
    public_values: list[Fe],
    exposed_values_after_challenge: list[list],
) -> None:
    """Verify constraints for a single RAP (AIR with interactions).

    Steps:
        1. Compute domain selectors at zeta
           (is_first_row, is_last_row, is_transition, inv_zeroifier).
        2. Split preprocessed openings into local/next rows (empty if absent).
        3. Unflatten after_challenge values from base-field-flattened form
           into FF4 elements.
        4. Evaluate the constraint DAG at the opened values.
        5. Fold constraints with alpha (random linear combination).
        6. Reconstruct the quotient value from its chunks.
        7. Assert folded_constraints * inv_zeroifier == quotient_value.

    Raises:
        OodEvaluationMismatch: if the out-of-domain check in step 7 fails.

    Reference:
        stark-backend/src/verifier/constraints.rs lines 21-140
        (verify_single_rap_constraints)
    """
    # Step 1: selectors at the out-of-domain point.
    sels = domain.selectors_at_point(zeta)

    # Step 2: preprocessed trace rows; an AIR without a preprocessed trace
    # contributes empty slices.
    if preprocessed_values is None:
        prep_local: list = []
        prep_next: list = []
    else:
        prep_local = preprocessed_values.local
        prep_next = preprocessed_values.next

    # Step 3: after_challenge openings arrive flattened (4 consecutive
    # Challenge values per FF4 element, as in the Rust verifier); regroup
    # them before DAG evaluation.
    regrouped_after_challenge = [
        AdjacentOpenedValues(
            local=unflatten_ext_values(vals.local),
            next=unflatten_ext_values(vals.next),
        )
        for vals in after_challenge_values
    ]

    # Step 4: evaluate every constraint at the opened values.
    constraint_evals = eval_symbolic_dag(
        dag=constraints,
        selectors=sels,
        preprocessed_local=prep_local,
        preprocessed_next=prep_next,
        partitioned_main_values=partitioned_main_values,
        after_challenge_values=regrouped_after_challenge,
        challenges=challenges,
        public_values=public_values,
        exposed_values_after_challenge=exposed_values_after_challenge,
    )

    # Step 5: random linear combination via Horner's rule in alpha.
    folded = fold_constraints(constraint_evals, alpha)

    # Step 6: Q(zeta) recombined from the opened quotient chunks.
    quotient_value = reconstruct_quotient(quotient_chunks, qc_domains, zeta)

    # Step 7: folded_constraints / Z_H(zeta) must equal Q(zeta).
    scaled = folded * sels.inv_zeroifier
    if scaled != quotient_value:
        raise OodEvaluationMismatch(
            f"OOD evaluation mismatch: "
            f"folded_constraints * inv_zeroifier = {scaled} != quotient = {quotient_value}"
        )