Source code for protocol.constraints

"""Constraint DAG evaluation and STARK constraint verification.

Evaluates the SymbolicExpressionDag at the OOD (out-of-domain) point and checks
that folded_constraints * inv_zeroifier == quotient for each AIR.

Reference:
    stark-backend/src/verifier/constraints.rs   -- verify_single_rap_constraints
    stark-backend/src/verifier/folder.rs         -- GenericVerifierConstraintFolder
    stark-backend/src/air_builders/symbolic/symbolic_expression.rs -- SymbolicEvaluator::eval_nodes
    stark-backend/src/air_builders/symbolic/dag.rs -- SymbolicExpressionNode
"""

from __future__ import annotations

import numpy as np

from primitives.field import (
    FF4,
    FF,
    Fe,
)
from protocol.domain import (
    DomainSelectors,
    TwoAdicMultiplicativeCoset,
)
from protocol.proof import (
    AdjacentOpenedValues,
    EntryType,
    SymbolicExpressionDag,
    SymbolicNodeKind,
    SymbolicVariable,
)


# ---------------------------------------------------------------------------
# Unflatten: reconstitute extension field elements from flattened base field
# ---------------------------------------------------------------------------


def unflatten_ext_values(flattened: list) -> list[FF4]:
    """Reconstitute extension field elements from flattened challenge values.

    In the Rust verifier, after_challenge (permutation) trace values are
    stored as "flattened" extension field elements: each FF4 element is stored
    as 4 consecutive Challenge (== FF4) values where only the base field
    component is meaningful. This function groups them back into proper FF4
    elements.

    Given D=4 (extension degree), each group of 4 consecutive flattened values
    [c0, c1, c2, c3] is combined using the monomial basis:

        result = c0 * 1 + c1 * x + c2 * x^2 + c3 * x^3

    Each c_i is converted with ``FF4(c_i)`` and multiplied by the basis element
    ``FF4(monomial(e_i))`` using extension field multiplication, where

        monomial(0) = [1,0,0,0]
        monomial(1) = [0,1,0,0]
        monomial(2) = [0,0,1,0]
        monomial(3) = [0,0,0,1]

    Args:
        flattened: Flattened values; its length must be a multiple of 4.

    Returns:
        One FF4 element per group of 4 consecutive input values, in order.

    Raises:
        ValueError: If ``len(flattened)`` is not a multiple of the extension
            degree. This is raised explicitly (rather than via ``assert``) so
            the check survives ``python -O``.

    Reference:
        stark-backend/src/verifier/constraints.rs lines 65-75 (unflatten closure)
    """
    ext_degree = 4
    if len(flattened) % ext_degree != 0:
        raise ValueError(
            f"Flattened length {len(flattened)} not divisible by ext degree {ext_degree}"
        )

    # The monomial basis elements are identical for every chunk, so build them
    # once per call instead of once per flattened value.
    basis = []
    for e_i in range(ext_degree):
        monomial = [0] * ext_degree
        monomial[e_i] = 1
        basis.append(FF4(monomial))

    result = []
    for start in range(0, len(flattened), ext_degree):
        chunk = flattened[start : start + ext_degree]
        acc = FF4.zero()
        for e_i in range(ext_degree):
            acc = acc + FF4(chunk[e_i]) * basis[e_i]
        result.append(acc)
    return result
# --------------------------------------------------------------------------- # DAG evaluation # --------------------------------------------------------------------------- # <doc-anchor id="eval-symbolic-dag">
def eval_symbolic_dag(
    dag: SymbolicExpressionDag,
    selectors: DomainSelectors,
    preprocessed_local: list,
    preprocessed_next: list,
    partitioned_main_values: list[AdjacentOpenedValues],
    after_challenge_values: list[AdjacentOpenedValues],
    challenges: list[list],
    public_values: list[Fe],
    exposed_values_after_challenge: list[list],
) -> list[FF4]:
    """Evaluate every node of the symbolic DAG and return the constraint values.

    Nodes are visited in list order; an ADD/SUB/MUL/NEG node reads the values
    of earlier nodes through its stored indices, so ``dag.nodes`` is expected
    to be topologically sorted already.

    Args:
        dag: DAG providing ``nodes`` and ``constraint_idx``.
        selectors: Source of ``is_first_row``, ``is_last_row`` and
            ``is_transition`` for the selector node kinds.
        preprocessed_local, preprocessed_next, partitioned_main_values,
        after_challenge_values, challenges, public_values,
        exposed_values_after_challenge: Forwarded unchanged to
            ``_lookup_variable`` for VARIABLE nodes.

    Returns:
        The evaluated node values at the indices listed in
        ``dag.constraint_idx``, in that order.

    Raises:
        ValueError: If a node has a kind this evaluator does not handle.

    Reference:
        stark-backend/src/air_builders/symbolic/symbolic_expression.rs
            SymbolicEvaluator::eval_nodes (lines 364-396)
        stark-backend/src/verifier/folder.rs
            GenericVerifierConstraintFolder impl of SymbolicEvaluator (lines 74-123)
    """
    evaluated: list[FF4] = []

    def evaluate(node) -> FF4:
        # Leaf nodes: opened values, constants and row selectors.
        if node.kind == SymbolicNodeKind.VARIABLE:
            return _lookup_variable(
                node.variable,
                preprocessed_local,
                preprocessed_next,
                partitioned_main_values,
                after_challenge_values,
                challenges,
                public_values,
                exposed_values_after_challenge,
            )
        if node.kind == SymbolicNodeKind.CONSTANT:
            return FF4(node.constant_value)
        if node.kind == SymbolicNodeKind.IS_FIRST_ROW:
            return selectors.is_first_row
        if node.kind == SymbolicNodeKind.IS_LAST_ROW:
            return selectors.is_last_row
        if node.kind == SymbolicNodeKind.IS_TRANSITION:
            return selectors.is_transition

        # Interior nodes: combine already-evaluated operands.
        if node.kind == SymbolicNodeKind.ADD:
            return evaluated[node.left_idx] + evaluated[node.right_idx]
        if node.kind == SymbolicNodeKind.SUB:
            return evaluated[node.left_idx] - evaluated[node.right_idx]
        if node.kind == SymbolicNodeKind.MUL:
            return evaluated[node.left_idx] * evaluated[node.right_idx]
        if node.kind == SymbolicNodeKind.NEG:
            return -evaluated[node.idx]

        raise ValueError(f"Unknown SymbolicNodeKind: {node.kind}")

    for dag_node in dag.nodes:
        evaluated.append(evaluate(dag_node))

    return [evaluated[position] for position in dag.constraint_idx]
def _lookup_variable(
    var: SymbolicVariable,
    preprocessed_local: list,
    preprocessed_next: list,
    partitioned_main_values: list[AdjacentOpenedValues],
    after_challenge_values: list[AdjacentOpenedValues],
    challenges: list[list],
    public_values: list[Fe],
    exposed_values_after_challenge: list[list],
) -> FF4:
    """Resolve a symbolic variable to its opened value, wrapped in ``FF4``.

    The entry kind selects the source; ``var.index`` selects the element:

    - PREPROCESSED -> ``preprocessed_local`` when ``offset == 0``, otherwise
      ``preprocessed_next``
    - MAIN         -> ``partitioned_main_values[part_index]``, ``.local`` when
      ``offset == 0``, otherwise ``.next``
    - PUBLIC       -> ``public_values``
    - PERMUTATION  -> ``after_challenge_values[0]`` (always phase 0), ``.local``
      when ``offset == 0``, otherwise ``.next``
    - CHALLENGE    -> ``challenges[0]`` (always phase 0)
    - EXPOSED      -> ``exposed_values_after_challenge[0]`` (always phase 0)

    Raises:
        ValueError: If the entry kind is none of the above.

    Reference:
        stark-backend/src/verifier/folder.rs lines 95-121
        (GenericVerifierConstraintFolder::eval_var)
    """
    entry = var.entry
    index = var.index
    kind = entry.kind

    if kind == EntryType.PREPROCESSED:
        row = preprocessed_local if entry.offset == 0 else preprocessed_next
        return FF4(row[index])

    if kind == EntryType.MAIN:
        opened = partitioned_main_values[entry.part_index]
        row = opened.local if entry.offset == 0 else opened.next
        return FF4(row[index])

    if kind == EntryType.PUBLIC:
        return FF4(public_values[index])

    if kind == EntryType.PERMUTATION:
        opened = after_challenge_values[0]
        row = opened.local if entry.offset == 0 else opened.next
        return FF4(row[index])

    if kind == EntryType.CHALLENGE:
        return FF4(challenges[0][index])

    if kind == EntryType.EXPOSED:
        return FF4(exposed_values_after_challenge[0][index])

    raise ValueError(f"Unknown EntryType: {entry.kind}")


# ---------------------------------------------------------------------------
# Constraint folding
# ---------------------------------------------------------------------------
# <doc-anchor id="fold-constraints">
def fold_constraints(
    constraint_evals: list[FF4],
    alpha: FF4,
) -> FF4:
    """Fold constraint evaluations into one value using powers of ``alpha``.

    Horner's method, matching the Rust verifier: start from zero and, for each
    constraint value C_i in order, multiply the running value by ``alpha`` and
    then add C_i. For n constraints this yields

        C_0 * alpha^{n-1} + C_1 * alpha^{n-2} + ... + C_{n-1}

    An empty ``constraint_evals`` returns ``FF4.zero()``.

    Reference:
        stark-backend/src/verifier/folder.rs lines 56-72
        (GenericVerifierConstraintFolder::eval_constraints + assert_zero)
    """
    folded = FF4.zero()
    for evaluation in constraint_evals:
        # One Horner step: shift by alpha, then absorb the next constraint.
        shifted = folded * alpha
        folded = shifted + evaluation
    return folded
# --------------------------------------------------------------------------- # Quotient reconstruction # --------------------------------------------------------------------------- # <doc-anchor id="reconstruct-quotient">
def reconstruct_quotient(
    quotient_chunks: list[list],
    qc_domains: list[TwoAdicMultiplicativeCoset],
    zeta: FF4,
) -> FF4:
    """Recompute the full quotient polynomial value at zeta from chunks.

    For each chunk domain D_i, compute:

        zps[i] = prod_{j != i} Z_{D_j}(zeta) / Z_{D_j}(first_point(D_i))

    Then the full quotient at zeta is:

        Q(zeta) = sum_i zps[i] * sum_e (monomial(e) * chunk[i][e])

    Each ``quotient_chunks[i]`` holds the coefficients of one extension field
    element in the monomial basis (D=4), so it is collapsed into a single FF4
    value before being weighted by ``zps[i]``.

    Args:
        quotient_chunks: One coefficient list per chunk domain.
        qc_domains: The quotient chunk domains, parallel to ``quotient_chunks``.
        zeta: The out-of-domain evaluation point.

    Returns:
        The reconstructed quotient value at ``zeta``.

    Raises:
        ValueError: If the number of chunks differs from the number of domains.
            Raised explicitly (rather than via ``assert``) so the check
            survives ``python -O``.
        IndexError: If a chunk has more than 4 coefficients.

    Reference:
        stark-backend/src/verifier/constraints.rs lines 38-63
    """
    num_chunks = len(qc_domains)
    if len(quotient_chunks) != num_chunks:
        raise ValueError(
            f"Expected {num_chunks} quotient chunks (one per domain), "
            f"got {len(quotient_chunks)}"
        )

    # Z_{D_j}(zeta) depends only on j, so evaluate it once per domain instead
    # of once per (i, j) pair.
    zp_at_zeta = [dom.vanishing_poly_at_point(zeta) for dom in qc_domains]

    zps: list[FF4] = []
    for i, domain_i in enumerate(qc_domains):
        # first_point(D_i) depends only on i; lift it out of the inner loop.
        first_pt = FF4(domain_i.first_point())
        prod = FF4.one()
        for j, domain_j in enumerate(qc_domains):
            if j != i:
                zp_at_first = domain_j.vanishing_poly_at_point(first_pt)
                prod = prod * (zp_at_zeta[j] / zp_at_first)
        zps.append(prod)

    # Monomial basis elements 1, x, x^2, x^3, shared by every chunk.
    ext_degree = 4
    basis = []
    for e_i in range(ext_degree):
        monomial = [0] * ext_degree
        monomial[e_i] = 1
        basis.append(FF4(monomial))

    result = FF4.zero()
    for zp, chunk in zip(zps, quotient_chunks):
        chunk_sum = FF4.zero()
        for e_i, c in enumerate(chunk):
            chunk_sum = chunk_sum + FF4(c) * basis[e_i]
        result = result + zp * chunk_sum
    return result
# --------------------------------------------------------------------------- # Vectorized DAG evaluation (all rows at once, base field) # ---------------------------------------------------------------------------
def eval_dag_all_rows(
    dag: SymbolicExpressionDag,
    partitioned_main: list[list[list[Fe]]],
    preprocessed: list[list[Fe]] | None,
    public_values: list[Fe],
    height: int,
) -> list[FF | None]:
    """Evaluate the whole DAG over all trace rows at once, in the base field.

    Every node evaluates to an ``FF`` column of length ``height``:

    - MAIN / PREPROCESSED variables read the matching trace column; a non-zero
      offset is applied with ``np.roll(col, -offset)``, so it wraps around at
      the end of the column.
    - PUBLIC variables and CONSTANT nodes become constant columns.
    - All other variable kinds and the three row-selector kinds evaluate to an
      all-zero column.
    - ADD / SUB / MUL / NEG combine previously computed columns element-wise.

    NOTE(review): a PREPROCESSED variable with ``preprocessed=None`` is not
    handled explicitly and fails when the missing columns are indexed.

    Args:
        dag: SymbolicExpressionDag
        partitioned_main: [part_index][rows][cols] -- list of trace matrices
        preprocessed: [rows][cols] or None
        public_values: list of Fe
        height: trace height

    Returns:
        List of field column arrays, one per DAG node.

    Raises:
        ValueError: If a node has a kind this evaluator does not handle.
    """

    def to_columns(matrix):
        # Transpose a [rows][cols] matrix into one FF array per column.
        width = len(matrix[0]) if matrix else 0
        return [
            FF([matrix[row][col] for row in range(height)])
            for col in range(width)
        ]

    def constant_column(value):
        return FF(np.full(height, int(value)))

    def at_offset(column, offset):
        # Offset 0 reuses the column object itself (no copy).
        return column if offset == 0 else np.roll(column, -offset)

    # Pre-convert trace matrices to FF column arrays for fast lookup.
    main_columns = [to_columns(part) for part in partitioned_main]
    prep_columns = to_columns(preprocessed) if preprocessed is not None else None

    selector_kinds = (
        SymbolicNodeKind.IS_FIRST_ROW,
        SymbolicNodeKind.IS_LAST_ROW,
        SymbolicNodeKind.IS_TRANSITION,
    )

    nodes = dag.nodes
    node_values = [None] * len(nodes)
    for position, node in enumerate(nodes):
        kind = node.kind
        if kind == SymbolicNodeKind.VARIABLE:
            var = node.variable
            entry = var.entry
            if entry.kind == EntryType.MAIN:
                column = at_offset(
                    main_columns[entry.part_index][var.index], entry.offset
                )
            elif entry.kind == EntryType.PREPROCESSED:
                column = at_offset(prep_columns[var.index], entry.offset)
            elif entry.kind == EntryType.PUBLIC:
                column = constant_column(public_values[var.index])
            else:
                column = FF.Zeros(height)
        elif kind == SymbolicNodeKind.CONSTANT:
            column = constant_column(node.constant_value)
        elif kind in selector_kinds:
            column = FF.Zeros(height)
        elif kind == SymbolicNodeKind.ADD:
            column = node_values[node.left_idx] + node_values[node.right_idx]
        elif kind == SymbolicNodeKind.SUB:
            column = node_values[node.left_idx] - node_values[node.right_idx]
        elif kind == SymbolicNodeKind.MUL:
            column = node_values[node.left_idx] * node_values[node.right_idx]
        elif kind == SymbolicNodeKind.NEG:
            column = -node_values[node.idx]
        else:
            raise ValueError(f"Unknown node kind: {kind}")
        node_values[position] = column

    return node_values
# --------------------------------------------------------------------------- # Verification error # ---------------------------------------------------------------------------
class VerificationError(Exception):
    """Base exception for a failed STARK constraint verification.

    Reference:
        stark-backend/src/verifier/error.rs (enum VerificationError)
    """
class OodEvaluationMismatch(VerificationError):
    """The out-of-domain check failed.

    Signals that constraints(zeta) != quotient(zeta) * Z_H(zeta).

    Reference:
        stark-backend/src/verifier/error.rs
        (VerificationError::OodEvaluationMismatch)
    """
# --------------------------------------------------------------------------- # Main per-AIR verification function # --------------------------------------------------------------------------- # <doc-anchor id="verify-single-rap">
def verify_single_rap_constraints(
    constraints: SymbolicExpressionDag,
    preprocessed_values: AdjacentOpenedValues | None,
    partitioned_main_values: list[AdjacentOpenedValues],
    after_challenge_values: list[AdjacentOpenedValues],
    quotient_chunks: list[list],
    domain: TwoAdicMultiplicativeCoset,
    qc_domains: list[TwoAdicMultiplicativeCoset],
    zeta: FF4,
    alpha: FF4,
    challenges: list[list],
    public_values: list[Fe],
    exposed_values_after_challenge: list[list],
) -> None:
    """Check the out-of-domain constraint identity for one RAP (AIR with interactions).

    Steps:
    1. Compute selectors at zeta (is_first_row, is_last_row, is_transition,
       inv_zeroifier).
    2. Split preprocessed openings into local/next rows (empty lists when
       there is no preprocessed trace).
    3. Unflatten after_challenge values into FF4 elements.
    4. Evaluate the constraint DAG at the opened values.
    5. Fold the constraint values with alpha (random linear combination).
    6. Reconstruct the quotient value from its chunks.
    7. Check: folded_constraints * inv_zeroifier == quotient_value.

    Returns:
        None when the check passes.

    Raises:
        OodEvaluationMismatch: If the check in step 7 fails.

    Reference:
        stark-backend/src/verifier/constraints.rs lines 21-140
        (verify_single_rap_constraints)
    """
    # Step 1: selectors at the out-of-domain point.
    selectors = domain.selectors_at_point(zeta)

    # Step 2: preprocessed openings are optional.
    if preprocessed_values is None:
        preprocessed_local, preprocessed_next = [], []
    else:
        preprocessed_local = preprocessed_values.local
        preprocessed_next = preprocessed_values.next

    # Step 3: after_challenge openings arrive flattened (each FF4 element is
    # 4 consecutive Challenge values); rebuild them before DAG evaluation.
    unflattened_after_challenge: list[AdjacentOpenedValues] = [
        AdjacentOpenedValues(
            local=unflatten_ext_values(opened.local),
            next=unflatten_ext_values(opened.next),
        )
        for opened in after_challenge_values
    ]

    # Step 4: evaluate every constraint at the opened values.
    constraint_evals = eval_symbolic_dag(
        dag=constraints,
        selectors=selectors,
        preprocessed_local=preprocessed_local,
        preprocessed_next=preprocessed_next,
        partitioned_main_values=partitioned_main_values,
        after_challenge_values=unflattened_after_challenge,
        challenges=challenges,
        public_values=public_values,
        exposed_values_after_challenge=exposed_values_after_challenge,
    )

    # Steps 5-6: the two sides of the identity.
    folded_constraints = fold_constraints(constraint_evals, alpha)
    quotient = reconstruct_quotient(quotient_chunks, qc_domains, zeta)

    # Step 7: folded_constraints * inv_zeroifier must equal the quotient.
    lhs = folded_constraints * selectors.inv_zeroifier
    if lhs != quotient:
        message = (
            f"OOD evaluation mismatch: "
            f"folded_constraints * inv_zeroifier = {lhs} != quotient = {quotient}"
        )
        raise OodEvaluationMismatch(message)