"""STARK verifier orchestration.
Top-level verifier that coordinates transcript operations, PCS verification,
and constraint checking for a multi-AIR STARK proof.
The Fiat-Shamir transcript operations must match the Rust verifier EXACTLY in
order: any deviation produces wrong challenges and breaks verification.
Reference:
stark-backend/src/verifier/mod.rs -- MultiTraceStarkVerifier::verify_raps
stark-backend/src/verifier/error.rs -- VerificationError enum
stark-backend/src/interaction/fri_log_up.rs -- FriLogUpPhase::partially_verify
"""
from __future__ import annotations
from primitives.field import BABYBEAR_PRIME, FF4, Digest, Fe
from primitives.transcript import Challenger, check_witness, grind
from protocol.constraints import (
VerificationError,
verify_single_rap_constraints,
)
from protocol.domain import (
TwoAdicMultiplicativeCoset,
create_disjoint_domain,
natural_domain_for_degree,
)
from protocol.pcs import PcsRound, pcs_verify
from protocol.proof import (
AdjacentOpenedValues,
AirProofData,
BatchOpening,
Commitments,
CommitPhaseProofStep,
FriLogUpPartialProof,
FriParameters,
FriProof as ProofFriProof,
MultiStarkVerifyingKey,
OpenedValues,
OpeningProof,
Proof,
QueryProof,
StarkVerifyingKey,
)
# Number of challenges and exposed values for FRI LogUp phase.
# Reference: stark-backend/src/interaction/fri_log_up.rs lines 240-241
STARK_LU_NUM_CHALLENGES = 2
STARK_LU_NUM_EXPOSED_VALUES = 1

# Extension field degree (BabyBear quartic extension).
# Used when flattening extension-field trace columns to base-field columns
# (each FF4 element occupies 4 base-field columns) and when checking quotient
# chunk widths. This definition was missing although the constant is used
# below (after_challenge width check and quotient chunk width check).
EXT_DEGREE = 4
# ---------------------------------------------------------------------------
# Verification errors
# ---------------------------------------------------------------------------
# VerificationError and OodEvaluationMismatch are defined in protocol.constraints
# and re-exported here (imported above) to avoid circular imports.
# [docs]
class InvalidProofShape(VerificationError):
    """Raised when the proof's structure disagrees with the verifying key.

    Reference:
        stark-backend/src/verifier/error.rs (VerificationError::InvalidProofShape)
    """

    def __init__(self, message: str = "") -> None:
        # Mirror the Rust Display form: bare variant name, or "Name: detail".
        detail = f"InvalidProofShape: {message}" if message else "InvalidProofShape"
        super().__init__(detail)
# [docs]
class InvalidOpeningArgument(VerificationError):
    """Raised when PCS opening verification fails.

    Reference:
        stark-backend/src/verifier/error.rs (VerificationError::InvalidOpeningArgument)
    """

    def __init__(self, message: str = "") -> None:
        # Mirror the Rust Display form: bare variant name, or "Name: detail".
        detail = f"InvalidOpeningArgument: {message}" if message else "InvalidOpeningArgument"
        super().__init__(detail)
# [docs]
class ChallengePhaseError(VerificationError):
    """Raised when RAP challenge-phase verification fails (e.g. non-zero cumulative sum).

    Reference:
        stark-backend/src/verifier/error.rs (VerificationError::ChallengePhaseError)
    """

    def __init__(self, message: str = "") -> None:
        # Mirror the Rust Display form: bare variant name, or "Name: detail".
        detail = f"ChallengePhaseError: {message}" if message else "ChallengePhaseError"
        super().__init__(detail)
# [docs]
class InvalidDeepPowWitness(VerificationError):
    """Raised when the DEEP proof-of-work witness is invalid.

    Reference:
        stark-backend/src/verifier/error.rs (VerificationError::InvalidDeepPowWitness)
    """

    def __init__(self, message: str = "") -> None:
        # Mirror the Rust Display form: bare variant name, or "Name: detail".
        detail = f"InvalidDeepPowWitness: {message}" if message else "InvalidDeepPowWitness"
        super().__init__(detail)
# ---------------------------------------------------------------------------
# Helper: VK view (analogous to MultiStarkVerifyingKeyView)
# ---------------------------------------------------------------------------
def _get_vk_view(
vk: MultiStarkVerifyingKey,
air_ids: list[int],
) -> list[StarkVerifyingKey]:
"""Index into the VK's per_air list by air_id.
Reference:
stark-backend/src/keygen/view.rs (MultiStarkVerifyingKey::view)
"""
return [vk.inner.per_air[air_id] for air_id in air_ids]
def _has_common_main(svk: StarkVerifyingKey) -> bool:
"""Check if this AIR has a common main trace.
Reference:
stark-backend/src/keygen/types.rs (StarkVerifyingKey::has_common_main)
"""
return svk.params.width.common_main != 0
def _has_interaction(svk: StarkVerifyingKey) -> bool:
"""Check if this AIR has bus interactions.
Reference:
stark-backend/src/keygen/types.rs (StarkVerifyingKey::has_interaction)
"""
return len(svk.symbolic_constraints.interactions) > 0
def _num_cached_mains(svk: StarkVerifyingKey) -> int:
"""Number of cached main trace partitions.
Reference:
stark-backend/src/keygen/types.rs (StarkVerifyingKey::num_cached_mains)
"""
return len(svk.params.width.cached_mains)
def _num_phases(per_air_vks: list[StarkVerifyingKey]) -> int:
"""Maximum number of challenge phases across all AIRs.
Reference:
stark-backend/src/keygen/view.rs (MultiStarkVerifyingKeyView::num_phases)
"""
if not per_air_vks:
return 0
return max(len(vk.params.width.after_challenge) for vk in per_air_vks)
def _flattened_preprocessed_commits(per_air_vks: list[StarkVerifyingKey]) -> list[Digest]:
"""Return all non-None preprocessed commitments.
Reference:
stark-backend/src/keygen/view.rs
(MultiStarkVerifyingKeyView::flattened_preprocessed_commits)
"""
commits = []
for vk in per_air_vks:
if vk.preprocessed_data is not None:
commits.append(vk.preprocessed_data.commit)
return commits
def _preprocessed_commits(per_air_vks: list[StarkVerifyingKey]) -> list[Digest | None]:
"""Return preprocessed commit for each AIR (None if not present).
Reference:
stark-backend/src/keygen/view.rs
(MultiStarkVerifyingKeyView::preprocessed_commits)
"""
result = []
for vk in per_air_vks:
if vk.preprocessed_data is not None:
result.append(vk.preprocessed_data.commit)
else:
result.append(None)
return result
# ---------------------------------------------------------------------------
# Helper: RAP phase partial verification (FRI LogUp)
# ---------------------------------------------------------------------------
def _partially_verify_fri_log_up(
    challenger: Challenger,
    partial_proof: FriLogUpPartialProof | None,
    exposed_values_per_air_per_phase: list[list[list[FF4]]],
    commitments_per_phase: list[Digest],
    log_up_pow_bits: int,
) -> tuple[list[list[FF4]], str | None]:
    """Partially verify the FRI LogUp challenge phase.

    Returns (challenges_per_phase, error_or_none).

    The challenges_per_phase is always returned (even on error) because the Rust
    verifier continues past ChallengePhaseError to check OodEvaluationMismatch.

    The order of transcript operations here (PoW check, challenge sampling,
    observing exposed values, observing the after-challenge commitment) is
    part of the Fiat-Shamir transcript and must not be changed.

    Reference:
        stark-backend/src/interaction/fri_log_up.rs
        (FriLogUpPhase::partially_verify, lines 153-237)
    """
    # Check if any AIR has exposed values (i.e., has interactions)
    has_any = any(
        len(ev_per_phase) > 0
        for ev_per_phase in exposed_values_per_air_per_phase
    )
    if not has_any:
        # No interactions at all: no challenges, no error
        return [], None
    # There are interactions, so we need a partial proof
    if partial_proof is None:
        return [], "MissingPartialProof"
    # PoW check for LogUp security
    # Reference: fri_log_up.rs lines 181-189
    if not check_witness(challenger, log_up_pow_bits, partial_proof.logup_pow_witness):
        return [], "InvalidPowWitness"
    # Sample interaction challenges (2 for FRI LogUp)
    # Reference: fri_log_up.rs lines 191-192
    challenges: list[FF4] = []
    for _ in range(STARK_LU_NUM_CHALLENGES):
        challenges.append(challenger.sample_ext())
    # Observe exposed values (cumulative sums)
    # Reference: fri_log_up.rs lines 194-200
    for exposed_values_per_phase in exposed_values_per_air_per_phase:
        if exposed_values_per_phase:
            exposed_values = exposed_values_per_phase[0]  # .first()
            for exposed_value in exposed_values:
                # exposed_value is FF4 (4 base field elements)
                # challenger.observe_slice(exposed_value.as_base_slice())
                challenger.observe_many(exposed_value)
    # Observe after_challenge commitment
    # Reference: fri_log_up.rs line 202
    # Only one challenge phase is supported, so index [0] is the whole list.
    challenger.observe_many(commitments_per_phase[0])
    # Check cumulative sum == 0
    # Reference: fri_log_up.rs lines 204-232
    # Sum all cumulative sums (exposed_values[0][0] for each AIR that has them)
    cumulative_sum = FF4.zero()
    for exposed_values_per_phase in exposed_values_per_air_per_phase:
        if exposed_values_per_phase and exposed_values_per_phase[0]:
            assert len(exposed_values_per_phase) <= 1, (
                "Verifier does not support more than 1 challenge phase"
            )
            assert len(exposed_values_per_phase[0]) == 1, (
                "Only exposed value should be cumulative sum"
            )
            csum = exposed_values_per_phase[0][0]
            # NOTE(review): csum appears to already be FF4; the FF4(...) wrap
            # presumably normalizes/copies — confirm against FF4's constructor.
            cumulative_sum = cumulative_sum + FF4(csum)
    error = None
    if cumulative_sum != FF4.zero():
        error = "NonZeroCumulativeSum"
    return [challenges], error
# ---------------------------------------------------------------------------
# Helper: Build PCS rounds from proof data
# ---------------------------------------------------------------------------
def _build_trace_domain_and_openings(
domain: TwoAdicMultiplicativeCoset,
zeta: FF4,
values: AdjacentOpenedValues,
) -> tuple[TwoAdicMultiplicativeCoset, list[tuple]]:
"""Build (domain, [(zeta, local_values), (next_point, next_values)]).
Reference:
stark-backend/src/verifier/mod.rs lines 221-226
(trace_domain_and_openings closure)
"""
next_point = domain.next_point(zeta)
return (
domain,
[
(zeta, values.local),
(next_point, values.next),
],
)
# ---------------------------------------------------------------------------
# Main verifier
# ---------------------------------------------------------------------------
# <doc-anchor id="verify-stark">
# [docs]
def verify_stark(
    vk: MultiStarkVerifyingKey,
    proof: Proof,
    fri_params: FriParameters,
) -> None:
    """Verify a multi-AIR STARK proof.

    This is the top-level verification function. It must reproduce the Rust
    verifier's Fiat-Shamir transcript operations in EXACTLY the same order.

    Args:
        vk: The multi-AIR verifying key.
        proof: The complete STARK proof.
        fri_params: FRI protocol parameters.

    Raises:
        VerificationError: If any verification check fails.

    Reference:
        stark-backend/src/verifier/mod.rs
        (MultiTraceStarkVerifier::verify + verify_raps, lines 38-430)
    """
    # --- Phase 1: Setup ---
    # Reference: mod.rs lines 46-47
    air_ids = [ap.air_id for ap in proof.per_air]
    per_air_vks = _get_vk_view(vk, air_ids)
    num_airs = len(air_ids)
    # --- Phase 2: Transcript initialization ---
    challenger = Challenger()
    # (1) Observe VK pre_hash (8 u32s, each observed individually)
    # Reference: mod.rs line 66
    challenger.observe_many(vk.pre_hash)
    # (2) Observe number of AIRs
    # Reference: mod.rs line 69
    # Rust: challenger.observe(Val::from_usize(num_airs)) puts Montgomery form into sponge.
    # Python: our Poseidon2 FFI converts canonical -> Montgomery internally, so we
    # observe the canonical integer value. The FFI equivalence is:
    #   Rust   sponge[i] = to_monty(x)
    #   Python sponge[i] = x, then FFI does BabyBear::new(x) = to_monty(x) before permute.
    challenger.observe(num_airs)
    # (3) Observe each air_id
    # Reference: mod.rs lines 70-72
    # Same canonical-form reasoning as num_airs above.
    for air_id in air_ids:
        challenger.observe(air_id)
    # --- Trace height constraint checks ---
    # Reference: mod.rs lines 74-83
    for constraint in vk.inner.trace_height_constraints:
        total = sum(
            constraint.coefficients[ap.air_id] * ap.degree
            for ap in proof.per_air
        )
        if total >= constraint.threshold:
            raise InvalidProofShape(
                f"trace height constraint violated: {total} >= {constraint.threshold}"
            )
    # --- Check air_ids are unique and sorted ---
    # Reference: mod.rs lines 85-93
    sorted_ids = sorted(air_ids)
    for i in range(len(sorted_ids) - 1):
        if sorted_ids[i] >= sorted_ids[i + 1]:
            raise VerificationError("DuplicateAirs: duplicate air_id found")
    # --- Get public values per AIR ---
    # Reference: mod.rs lines 96-104
    public_values = [ap.public_values for ap in proof.per_air]
    for pvs, svk in zip(public_values, per_air_vks):
        if len(pvs) != svk.params.num_public_values:
            raise InvalidProofShape(
                f"public values count mismatch: {len(pvs)} vs {svk.params.num_public_values}"
            )
    # (4) Observe public values for each AIR
    # Reference: mod.rs lines 106-108
    for pis in public_values:
        challenger.observe_many(pis)
    # (5) Observe preprocessed commitments
    # Reference: mod.rs lines 110-112
    for preprocessed_commit in _flattened_preprocessed_commits(per_air_vks):
        challenger.observe_many(preprocessed_commit)
    # --- Validate main trace commit count ---
    # Reference: mod.rs lines 115-125
    num_cached_mains = sum(
        len(svk.params.width.cached_mains) for svk in per_air_vks
    )
    num_main_commits = num_cached_mains + 1  # always 1 common main
    if len(proof.commitments.main_trace) != num_main_commits:
        raise InvalidProofShape(
            f"main trace commit count mismatch: "
            f"{len(proof.commitments.main_trace)} vs {num_main_commits}"
        )
    # (6) Observe main trace commitments
    # Reference: mod.rs line 127
    # challenger.observe_slice(&proof.commitments.main_trace)
    # Each commitment is a Digest (8 u32s), and observe(Hash) observes each element.
    for commit in proof.commitments.main_trace:
        challenger.observe_many(commit)
    # (7) Observe log2(degree) for each AIR
    # Reference: mod.rs lines 128-134
    # Same canonical-form reasoning as num_airs above.
    for ap in proof.per_air:
        log_degree = ap.degree.bit_length() - 1
        assert 1 << log_degree == ap.degree, f"degree must be power of 2: {ap.degree}"
        challenger.observe(log_degree)
    # --- Phase 3: RAP (interaction) challenge phase ---
    # Reference: mod.rs lines 136-184
    exposed_values_per_air_per_phase = [
        ap.exposed_values_after_challenge for ap in proof.per_air
    ]
    num_phases = _num_phases(per_air_vks)
    if num_phases != len(proof.commitments.after_challenge) or num_phases > 1:
        raise InvalidProofShape(
            f"after_challenge phase count mismatch: "
            f"num_phases={num_phases}, commits={len(proof.commitments.after_challenge)}"
        )
    # Validate exposed_values shape (T01c)
    # Reference: mod.rs lines 166-172
    for ev_per_phase, svk in zip(exposed_values_per_air_per_phase, per_air_vks):
        if len(ev_per_phase) != len(svk.params.num_exposed_values_after_challenge):
            raise InvalidProofShape("exposed_values_after_challenge phase count mismatch")
        for ev, n in zip(ev_per_phase, svk.params.num_exposed_values_after_challenge):
            if len(ev) != n:
                raise InvalidProofShape(
                    f"exposed_values count mismatch: {len(ev)} vs {n}"
                )
    # <doc-anchor id="verify-logup">
    # Call RAP phase partial verification
    # Reference: mod.rs lines 174-184
    challenges_per_phase, rap_phase_error = _partially_verify_fri_log_up(
        challenger,
        proof.rap_phase_seq_proof,
        exposed_values_per_air_per_phase,
        proof.commitments.after_challenge,
        vk.inner.log_up_pow_bits,
    )
    # Don't bail on error yet -- OodEvaluationMismatch takes precedence
    # Reference: mod.rs lines 181-184
    # --- Phase 4: Quotient challenge ---
    # Sample alpha (constraint combination challenge)
    # Reference: mod.rs lines 187-188
    alpha: FF4 = challenger.sample_ext()
    # Observe quotient commitment (8 u32s)
    # Reference: mod.rs line 191
    challenger.observe_many(proof.commitments.quotient)
    # DEEP proof-of-work check
    # Reference: mod.rs lines 193-198
    deep_pow_bits = vk.inner.deep_pow_bits
    if not check_witness(challenger, deep_pow_bits, proof.opening.deep_pow_witness):
        raise InvalidDeepPowWitness("DEEP proof-of-work witness is invalid")
    # Sample zeta (OOD evaluation point)
    # Reference: mod.rs lines 199-201
    zeta: FF4 = challenger.sample_ext()
    # --- Phase 5: Build domains ---
    # Reference: mod.rs lines 204-218
    domains: list[TwoAdicMultiplicativeCoset] = []
    quotient_chunks_domains: list[list[TwoAdicMultiplicativeCoset]] = []
    for svk, ap in zip(per_air_vks, proof.per_air):
        degree = ap.degree
        quotient_degree = svk.quotient_degree
        domain = natural_domain_for_degree(degree)
        quotient_domain = create_disjoint_domain(domain, degree * quotient_degree)
        qc_doms = quotient_domain.split_domains(quotient_degree)
        domains.append(domain)
        quotient_chunks_domains.append(qc_doms)
    # --- Phase 6: Build PCS rounds and observe opened values ---
    opened_values = proof.opening.values
    # 1. Preprocessed trace openings
    # Reference: mod.rs lines 231-258
    preprocessed_widths = [
        svk.params.width.preprocessed
        for svk in per_air_vks
        if svk.params.width.preprocessed is not None
    ]
    if len(preprocessed_widths) != len(opened_values.preprocessed):
        raise InvalidProofShape(
            f"preprocessed count mismatch: "
            f"{len(preprocessed_widths)} vs {len(opened_values.preprocessed)}"
        )
    for w, ov in zip(preprocessed_widths, opened_values.preprocessed):
        if w != len(ov.local) or w != len(ov.next):
            raise InvalidProofShape("preprocessed width mismatch")
    rounds: list[PcsRound] = []
    # Build preprocessed rounds: each AIR with preprocessed gets its own round
    prep_commits = _preprocessed_commits(per_air_vks)
    prep_idx = 0
    for i_air in range(num_airs):
        commit = prep_commits[i_air]
        if commit is not None:
            domain = domains[i_air]
            values = opened_values.preprocessed[prep_idx]
            dom_and_openings = _build_trace_domain_and_openings(domain, zeta, values)
            rounds.append(PcsRound(
                commitment=commit,
                domains_and_openings=[dom_and_openings],
            ))
            # prep_idx only advances for AIRs that actually have preprocessed data.
            prep_idx += 1
    # 2. Main trace openings
    # Reference: mod.rs lines 260-303
    num_main_commits_ov = len(opened_values.main)
    if num_main_commits_ov != len(proof.commitments.main_trace):
        raise InvalidProofShape("main opened values count != main trace commits")
    main_commit_idx = 0
    # Cached main traces (all commits except the last)
    # Reference: mod.rs lines 268-276
    for svk, domain in zip(per_air_vks, domains):
        for cached_main_width in svk.params.width.cached_mains:
            commit = proof.commitments.main_trace[main_commit_idx]
            if len(opened_values.main[main_commit_idx]) != 1:
                raise InvalidProofShape("cached main should have exactly 1 matrix")
            value = opened_values.main[main_commit_idx][0]
            if cached_main_width != len(value.local) or cached_main_width != len(value.next):
                raise InvalidProofShape("cached main width mismatch")
            dom_and_openings = _build_trace_domain_and_openings(domain, zeta, value)
            rounds.append(PcsRound(
                commitment=commit,
                domains_and_openings=[dom_and_openings],
            ))
            main_commit_idx += 1
    # Common main trace (last commit, multiple matrices for different AIRs)
    # Reference: mod.rs lines 283-303
    values_per_mat = opened_values.main[main_commit_idx]
    commit = proof.commitments.main_trace[main_commit_idx]
    common_main_domains_and_openings = []
    mat_idx = 0
    for svk, domain in zip(per_air_vks, domains):
        if _has_common_main(svk):
            if mat_idx >= len(values_per_mat):
                raise InvalidProofShape("not enough common main matrices")
            values = values_per_mat[mat_idx]
            width = svk.params.width.common_main
            if width != len(values.local) or width != len(values.next):
                raise InvalidProofShape("common main width mismatch")
            dom_and_openings = _build_trace_domain_and_openings(domain, zeta, values)
            common_main_domains_and_openings.append(dom_and_openings)
            mat_idx += 1
    if len(common_main_domains_and_openings) != len(values_per_mat):
        raise InvalidProofShape("common main matrix count mismatch")
    rounds.append(PcsRound(
        commitment=commit,
        domains_and_openings=common_main_domains_and_openings,
    ))
    # 3. After-challenge trace openings (at most 1 phase)
    # Reference: mod.rs lines 305-340
    has_any_interaction = any(_has_interaction(svk) for svk in per_air_vks)
    if not has_any_interaction:
        if (proof.commitments.after_challenge
                or opened_values.after_challenge):
            raise InvalidProofShape("no interactions but after_challenge data present")
        assert num_phases == 0
    else:
        if num_phases != 1 or len(opened_values.after_challenge) != 1:
            raise InvalidProofShape("after_challenge shape mismatch")
        after_challenge_commit = proof.commitments.after_challenge[0]
        ac_domains_and_openings = []
        ac_mat_idx = 0
        for svk, domain in zip(per_air_vks, domains):
            if _has_interaction(svk):
                if ac_mat_idx >= len(opened_values.after_challenge[0]):
                    raise InvalidProofShape("not enough after_challenge matrices")
                values = opened_values.after_challenge[0][ac_mat_idx]
                # NOTE(review): EXT_DEGREE is a module-level constant (the
                # quartic extension degree, i.e. 4) — each FF4 column is
                # flattened to EXT_DEGREE base-field columns.
                width = svk.params.width.after_challenge[0] * EXT_DEGREE
                if width != len(values.local) or width != len(values.next):
                    raise InvalidProofShape("after_challenge width mismatch")
                dom_and_openings = _build_trace_domain_and_openings(
                    domain, zeta, values
                )
                ac_domains_and_openings.append(dom_and_openings)
                ac_mat_idx += 1
        if len(ac_domains_and_openings) != len(opened_values.after_challenge[0]):
            raise InvalidProofShape("after_challenge matrix count mismatch")
        rounds.append(PcsRound(
            commitment=after_challenge_commit,
            domains_and_openings=ac_domains_and_openings,
        ))
    # 4. Quotient openings
    # Reference: mod.rs lines 341-366
    if len(opened_values.quotient) != num_airs:
        raise InvalidProofShape(
            f"quotient count mismatch: {len(opened_values.quotient)} vs {num_airs}"
        )
    for per_air_q, svk in zip(opened_values.quotient, per_air_vks):
        if len(per_air_q) != svk.quotient_degree:
            raise InvalidProofShape("quotient chunk count mismatch")
        for chunk in per_air_q:
            if len(chunk) != EXT_DEGREE:
                raise InvalidProofShape("quotient chunk width mismatch")
    quotient_domains_and_openings = []
    for air_chunks, qc_doms in zip(opened_values.quotient, quotient_chunks_domains):
        for chunk_values, qc_domain in zip(air_chunks, qc_doms):
            # Quotient chunks are only opened at zeta (no next-row opening).
            quotient_domains_and_openings.append(
                (qc_domain, [(zeta, chunk_values)])
            )
    rounds.append(PcsRound(
        commitment=proof.commitments.quotient,
        domains_and_openings=quotient_domains_and_openings,
    ))
    # <doc-anchor id="verify-pcs">
    # --- Phase 6b: PCS verification ---
    # Reference: mod.rs lines 362-363 (pcs.verify)
    try:
        pcs_verify(rounds, proof.opening.proof, challenger, fri_params)
    # NOTE(review): AssertionError is already an Exception subclass, so the
    # tuple is redundant (but harmless); any PCS failure is wrapped here.
    except (AssertionError, Exception) as e:
        raise InvalidOpeningArgument(str(e)) from e
    # --- Phase 7: Constraint verification ---
    # Reference: mod.rs lines 365-424
    preprocessed_idx = 0
    after_challenge_idx = [0] * num_phases
    cached_main_commit_idx = 0
    common_main_matrix_idx = 0
    for i_air in range(num_airs):
        domain = domains[i_air]
        qc_doms = quotient_chunks_domains[i_air]
        quotient_chunks = opened_values.quotient[i_air]
        svk = per_air_vks[i_air]
        air_proof = proof.per_air[i_air]
        # Preprocessed values
        # Reference: mod.rs lines 378-382
        preprocessed_values: AdjacentOpenedValues | None = None
        if svk.preprocessed_data is not None:
            preprocessed_values = opened_values.preprocessed[preprocessed_idx]
            preprocessed_idx += 1
        # Partitioned main values
        # Reference: mod.rs lines 383-397
        partitioned_main_values: list[AdjacentOpenedValues] = []
        for _ in range(_num_cached_mains(svk)):
            partitioned_main_values.append(
                opened_values.main[cached_main_commit_idx][0]
            )
            cached_main_commit_idx += 1
        if _has_common_main(svk):
            # The common main is always the LAST main commit's matrix list.
            partitioned_main_values.append(
                opened_values.main[-1][common_main_matrix_idx]
            )
            common_main_matrix_idx += 1
        # After challenge values
        # Reference: mod.rs lines 394-410
        after_challenge_values: list[AdjacentOpenedValues] = []
        if _has_interaction(svk):
            for phase_idx in range(num_phases):
                matrix_idx = after_challenge_idx[phase_idx]
                after_challenge_idx[phase_idx] += 1
                after_challenge_values.append(
                    opened_values.after_challenge[phase_idx][matrix_idx]
                )
        # <doc-anchor id="verify-constraints">
        # Verify constraints for this AIR
        # Reference: mod.rs lines 405-418
        verify_single_rap_constraints(
            constraints=svk.symbolic_constraints.constraints,
            preprocessed_values=preprocessed_values,
            partitioned_main_values=partitioned_main_values,
            after_challenge_values=after_challenge_values,
            quotient_chunks=quotient_chunks,
            domain=domain,
            qc_domains=qc_doms,
            zeta=zeta,
            alpha=alpha,
            challenges=challenges_per_phase,
            public_values=air_proof.public_values,
            exposed_values_after_challenge=air_proof.exposed_values_after_challenge,
        )
    # If we made it this far without OodEvaluationMismatch, check the RAP phase result
    # Reference: mod.rs lines 422-428
    if rap_phase_error is not None:
        raise ChallengePhaseError(rap_phase_error)
# ===========================================================================
# STARK Prover
# ===========================================================================
# <doc-anchor id="prove-stark">
# [docs]
def prove_stark(
vk: MultiStarkVerifyingKey,
traces: list[list[list[Fe]]],
public_values_per_air: list[list[Fe]],
fri_params: FriParameters,
preprocessed_traces: list[list[list[Fe]] | None] | None = None,
cached_main_traces: list[list[list[list[Fe]]]] | None = None,
air_ids: list[int] | None = None,
) -> Proof:
"""Generate a multi-AIR STARK proof.
Prover counterpart to verify_stark. Reproduces the same Fiat-Shamir
transcript operations so that the resulting proof is accepted by the
verifier.
Args:
vk: The multi-AIR verifying key.
traces: Per-AIR common main trace matrices (list of rows).
public_values_per_air: Per-AIR public values.
fri_params: FRI protocol parameters.
preprocessed_traces: Per-AIR preprocessed trace matrices (None if absent).
cached_main_traces: Per-AIR list of cached main trace matrices.
air_ids: Maps proof index to VK AIR index. If None, assumes identity
mapping (all VK AIRs have data).
Returns:
A Proof matching the Rust prover output.
Reference:
stark-backend/src/prover/coordinator.rs prove
"""
from protocol.pcs import (
CommittedData,
PcsOpeningRound,
_generate_batch_opening,
pcs_commit,
pcs_open,
)
from protocol.quotient import compute_quotient_chunks
if air_ids is None:
air_ids = list(range(len(vk.inner.per_air)))
per_air_vks = _get_vk_view(vk, air_ids)
num_airs = len(air_ids)
# --- Phase 1: Transcript initialization ---
challenger = Challenger()
# <doc-anchor id="seed-transcript">
# (1) Observe VK pre_hash
challenger.observe_many(vk.pre_hash)
# (2) Observe number of AIRs
challenger.observe(num_airs)
# (3) Observe each air_id
for air_id in air_ids:
challenger.observe(air_id)
# --- Phase 2: Commit traces ---
# Build domains for each AIR
domains: list[TwoAdicMultiplicativeCoset] = []
for i_air in range(num_airs):
# Determine height from available trace data
if traces[i_air]:
height = len(traces[i_air])
elif cached_main_traces and cached_main_traces[i_air]:
height = len(cached_main_traces[i_air][0])
elif preprocessed_traces and preprocessed_traces[i_air] is not None:
height = len(preprocessed_traces[i_air])
else:
raise ValueError(f"AIR {i_air} has no trace data to determine height")
domain = natural_domain_for_degree(height)
domains.append(domain)
# (a) Commit preprocessed traces (each AIR gets its own commitment).
# Re-commit from raw data so we have Merkle trees for PCS openings.
# Verify roots match VK to ensure consistency.
preprocessed_committed_list: list[tuple[int, CommittedData]] = []
for i_air in range(num_airs):
svk = per_air_vks[i_air]
if svk.preprocessed_data is not None:
prep = preprocessed_traces[i_air] if preprocessed_traces else None
assert prep is not None, (
f"VK expects preprocessed for AIR {i_air} but none provided"
)
committed = pcs_commit([(domains[i_air], prep)], fri_params.log_blowup)
assert committed.root == svk.preprocessed_data.commit, (
f"Preprocessed commitment mismatch for AIR {i_air}"
)
preprocessed_committed_list.append((i_air, committed))
# (b) Commit cached main traces (each partition gets its own commitment).
# Order: for each AIR, for each cached main partition -> separate commit.
cached_committed_list: list[CommittedData] = []
main_commits: list[Digest] = []
for i_air in range(num_airs):
svk = per_air_vks[i_air]
num_cached = len(svk.params.width.cached_mains)
if num_cached > 0:
air_cached = cached_main_traces[i_air] if cached_main_traces else []
assert len(air_cached) == num_cached, (
f"AIR {i_air}: expected {num_cached} cached mains, got {len(air_cached)}"
)
for cached_matrix in air_cached:
committed = pcs_commit(
[(domains[i_air], cached_matrix)], fri_params.log_blowup
)
cached_committed_list.append(committed)
main_commits.append(committed.root)
# <doc-anchor id="commit-main-trace">
# (c) Commit common main trace (all AIRs in one batch).
common_main_evals = []
for i_air in range(num_airs):
if _has_common_main(per_air_vks[i_air]):
common_main_evals.append((domains[i_air], traces[i_air]))
main_committed = pcs_commit(common_main_evals, fri_params.log_blowup)
main_commits.append(main_committed.root)
# --- Phase 3: Observe into transcript ---
# (4) Observe public values
for pis in public_values_per_air:
challenger.observe_many(pis)
# (5) Observe preprocessed commitments (from VK, not re-committed roots)
for commit in _flattened_preprocessed_commits(per_air_vks):
challenger.observe_many(commit)
# (6) Observe main trace commitments (cached + common)
for commit in main_commits:
challenger.observe_many(commit)
# (7) Observe log_degree for each AIR
for i_air in range(num_airs):
degree = domains[i_air].size()
log_degree = degree.bit_length() - 1
challenger.observe(log_degree)
# --- Phase 4: RAP phase (interactions, FRI-LogUp) ---
# Skipped entirely when no AIR declares interactions; the defaults below
# (None / empty lists) then flow through to the final Proof unchanged.
has_any_interaction = any(_has_interaction(svk) for svk in per_air_vks)
rap_phase_seq_proof = None
challenges_per_phase: list[list[FF4]] = []
after_challenge_commits: list[Digest] = []
after_challenge_per_air: list[list[list[FF4]] | None] = [None] * num_airs
exposed_values_per_air: list[list[list[FF4]]] = [[] for _ in range(num_airs)]
ac_committed: CommittedData | None = None
# <doc-anchor id="commit-after-challenge">
if has_any_interaction:
    # Local import — presumably to avoid a circular import at module load
    # time (protocol.logup likely imports from this package); TODO confirm.
    from protocol.logup import (
        compute_after_challenge_trace,
        compute_max_constraint_degree,
        find_interaction_chunks,
    )
    max_cd = compute_max_constraint_degree(vk.inner.per_air)
    # (a) LogUp PoW grinding
    logup_pow_witness = grind(challenger, vk.inner.log_up_pow_bits)
    # (b) Sample 2 interaction challenges
    interaction_challenges: list[FF4] = [
        challenger.sample_ext() for _ in range(STARK_LU_NUM_CHALLENGES)
    ]
    alpha_lu, beta_lu = interaction_challenges[0], interaction_challenges[1]
    # (c) Compute after_challenge trace per AIR
    for i_air in range(num_airs):
        svk = per_air_vks[i_air]
        interactions = svk.symbolic_constraints.interactions
        if not interactions:
            # AIRs without interactions contribute no after-challenge trace.
            continue
        dag = svk.symbolic_constraints.constraints
        partitions = find_interaction_chunks(interactions, dag, max_cd)
        # Build partitioned_main: [cached_mains..., common_main]
        partitioned_main: list[list[list[Fe]]] = []
        if cached_main_traces and cached_main_traces[i_air]:
            for cm in cached_main_traces[i_air]:
                partitioned_main.append(cm)
        if _has_common_main(svk):
            partitioned_main.append(traces[i_air])
        prep = None
        if preprocessed_traces and preprocessed_traces[i_air] is not None:
            prep = preprocessed_traces[i_air]
        height = domains[i_air].size()
        perm_trace, cum_sum = compute_after_challenge_trace(
            interactions, partitions, dag,
            partitioned_main, prep,
            public_values_per_air[i_air],
            alpha_lu, beta_lu, height,
        )
        after_challenge_per_air[i_air] = perm_trace
        # One phase, one exposed value per AIR: the LogUp cumulative sum
        # (matches STARK_LU_NUM_EXPOSED_VALUES = 1).
        exposed_values_per_air[i_air] = [[cum_sum]]
    # (d) Observe exposed values (cumulative sums)
    for i_air in range(num_airs):
        evs = exposed_values_per_air[i_air]
        if evs:
            for ev in evs[0]:  # phase 0
                challenger.observe_many(ev)
    # (e) Flatten after_challenge traces to base field and commit.
    # Each FF4 element → 4 base field columns.
    ac_evals = []
    for i_air in range(num_airs):
        if after_challenge_per_air[i_air] is not None:
            domain = domains[i_air]
            perm_trace = after_challenge_per_air[i_air]
            height = len(perm_trace)
            perm_width = len(perm_trace[0])
            base_field_rows: list[list[Fe]] = []
            for row_idx in range(height):
                row: list[Fe] = []
                for col in range(perm_width):
                    # Extending with an FF4 cell appends its 4 base-field
                    # coordinates as separate columns.
                    row.extend(perm_trace[row_idx][col])
                base_field_rows.append(row)
            ac_evals.append((domain, base_field_rows))
    ac_committed = pcs_commit(ac_evals, fri_params.log_blowup)
    # (f) Observe after_challenge commitment
    challenger.observe_many(ac_committed.root)
    rap_phase_seq_proof = FriLogUpPartialProof(
        logup_pow_witness=logup_pow_witness
    )
    # FRI-LogUp has a single RAP phase, hence one entry per list.
    challenges_per_phase = [interaction_challenges]
    after_challenge_commits = [ac_committed.root]
# --- Phase 5: Sample alpha and compute quotient ---
# alpha is the random combiner folding every constraint of an AIR into a
# single polynomial before division by the vanishing polynomial.
alpha: FF4 = challenger.sample_ext()
# Compute quotient for each AIR
all_quotient_chunks: list[list[list[list[Fe]]]] = []
quotient_chunk_domains: list[list[TwoAdicMultiplicativeCoset]] = []
for i_air in range(num_airs):
    svk = per_air_vks[i_air]
    domain = domains[i_air]
    # Build partitioned traces for quotient evaluation
    partitioned_traces_list: list[list[list[Fe]]] | None = None
    num_cached = _num_cached_mains(svk)
    # NOTE(review): this condition tests preprocessed_traces[i_air] truthily,
    # while the `prep` assignment below uses `is not None` — differs only if
    # an empty-list preprocessed trace can occur; confirm intended.
    if num_cached > 0 or (preprocessed_traces and preprocessed_traces[i_air]):
        # Multi-partition or preprocessed: need partitioned traces
        parts: list[list[list[Fe]]] = []
        if cached_main_traces and cached_main_traces[i_air]:
            for cm in cached_main_traces[i_air]:
                parts.append(cm)
        if _has_common_main(svk):
            parts.append(traces[i_air])
        partitioned_traces_list = parts
    # Prepare challenges and exposed values for quotient
    chall = None
    exp_vals = None
    if has_any_interaction and _has_interaction(svk):
        chall = [challenges_per_phase[0]]
        evs = exposed_values_per_air[i_air]
        exp_vals = evs if evs else None
    prep = None
    if preprocessed_traces and preprocessed_traces[i_air] is not None:
        prep = preprocessed_traces[i_air]
    chunks, qc_doms = compute_quotient_chunks(
        trace=traces[i_air],
        constraints_dag=svk.symbolic_constraints.constraints,
        public_values=public_values_per_air[i_air],
        alpha=alpha,
        trace_domain=domain,
        quotient_degree=svk.quotient_degree,
        preprocessed_trace=prep,
        after_challenge_trace=after_challenge_per_air[i_air],
        challenges=chall,
        exposed_values=exp_vals,
        partitioned_traces=partitioned_traces_list,
    )
    all_quotient_chunks.append(chunks)
    quotient_chunk_domains.append(qc_doms)
# <doc-anchor id="commit-quotient">
# --- Phase 6: Commit quotient ---
# All quotient chunks of all AIRs go into one PCS commitment, flattened in
# (AIR, chunk) order — the same flat order used when opening at zeta below.
quotient_evals = []
for i_air in range(num_airs):
    for chunk_idx, chunk_matrix in enumerate(all_quotient_chunks[i_air]):
        qc_domain = quotient_chunk_domains[i_air][chunk_idx]
        quotient_evals.append((qc_domain, chunk_matrix))
quotient_committed = pcs_commit(quotient_evals, fri_params.log_blowup)
# (8) Observe quotient commitment
challenger.observe_many(quotient_committed.root)
# --- Phase 7: DEEP PoW ---
deep_pow_witness = grind(challenger, vk.inner.deep_pow_bits)
# <doc-anchor id="open-at-zeta">
# --- Phase 8: Sample zeta ---
# zeta is the out-of-domain point at which all committed polynomials are
# opened; it must be sampled only after the quotient commitment is observed.
zeta: FF4 = challenger.sample_ext()
# --- Phase 9: Build PCS opening rounds ---
# Round order must match the verifier exactly:
# preprocessed (each own round) → cached mains (each own round) →
# common main (one round) → after_challenge (one round) → quotient (one round)
opening_rounds: list[PcsOpeningRound] = []
# Trace matrices are opened at zeta and at the domain's "next row" point
# (domain.next_point(zeta)); quotient chunks are opened at zeta only.
next_points = [domains[i_air].next_point(zeta) for i_air in range(num_airs)]
# (a) Preprocessed traces (each in its own round)
# preprocessed_committed_list holds (air_index, committed) pairs.
for i_air, committed in preprocessed_committed_list:
    opening_rounds.append(PcsOpeningRound(
        committed=committed,
        points_per_mat=[[zeta, next_points[i_air]]],
    ))
# (b) Cached main traces (each in its own round)
cached_idx = 0
for i_air in range(num_airs):
    svk = per_air_vks[i_air]
    for _ in svk.params.width.cached_mains:
        committed = cached_committed_list[cached_idx]
        opening_rounds.append(PcsOpeningRound(
            committed=committed,
            points_per_mat=[[zeta, next_points[i_air]]],
        ))
        cached_idx += 1
# (c) Common main trace (one round, multiple matrices)
common_main_points = []
for i_air in range(num_airs):
    if _has_common_main(per_air_vks[i_air]):
        common_main_points.append([zeta, next_points[i_air]])
opening_rounds.append(PcsOpeningRound(
    committed=main_committed,
    points_per_mat=common_main_points,
))
# (d) After-challenge traces (one round if present)
if has_any_interaction and ac_committed is not None:
    ac_points = []
    for i_air in range(num_airs):
        if _has_interaction(per_air_vks[i_air]):
            ac_points.append([zeta, next_points[i_air]])
    opening_rounds.append(PcsOpeningRound(
        committed=ac_committed,
        points_per_mat=ac_points,
    ))
# (e) Quotient chunks (one round; one matrix per chunk, opened at zeta only)
quotient_points = []
for i_air in range(num_airs):
    for _ in all_quotient_chunks[i_air]:
        quotient_points.append([zeta])
opening_rounds.append(PcsOpeningRound(
    committed=quotient_committed,
    points_per_mat=quotient_points,
))
# --- Phase 10: PCS open ---
# Runs the FRI opening protocol over all rounds; yields opened values indexed
# as [round][matrix][point], the raw FRI proof pieces, and the query indices
# the challenger sampled.
all_opened_values, fri_proof_data, query_indices = pcs_open(
    opening_rounds, challenger, fri_params
)
# --- Phase 11: Build proof structure ---
# Extract opened values from all_opened_values[round_idx][mat_idx][point_idx];
# round_idx walks rounds in the exact order they were registered in Phase 9.
round_idx = 0
# Preprocessed opened values
preprocessed_ov: list[AdjacentOpenedValues] = []
for _ in preprocessed_committed_list:
    mat_values = all_opened_values[round_idx][0]  # single matrix
    preprocessed_ov.append(AdjacentOpenedValues(
        local=mat_values[0],
        next=mat_values[1],
    ))
    round_idx += 1
# Main trace opened values: [cached_0, ..., cached_n, common_main_batch]
main_ov: list[list[AdjacentOpenedValues]] = []
# Cached mains (each has 1 matrix)
for _ in cached_committed_list:
    mat_values = all_opened_values[round_idx][0]
    main_ov.append([AdjacentOpenedValues(
        local=mat_values[0],
        next=mat_values[1],
    )])
    round_idx += 1
# Common main (multiple matrices, one per AIR with common main)
common_main_round = all_opened_values[round_idx]
common_main_avs: list[AdjacentOpenedValues] = []
for mat_values in common_main_round:
    common_main_avs.append(AdjacentOpenedValues(
        local=mat_values[0],
        next=mat_values[1],
    ))
main_ov.append(common_main_avs)
round_idx += 1
# After-challenge opened values
after_challenge_ov: list[list[AdjacentOpenedValues]] = []
if has_any_interaction and ac_committed is not None:
    ac_round = all_opened_values[round_idx]
    ac_avs: list[AdjacentOpenedValues] = []
    for mat_values in ac_round:
        ac_avs.append(AdjacentOpenedValues(
            local=mat_values[0],
            next=mat_values[1],
        ))
    after_challenge_ov.append(ac_avs)
    round_idx += 1
# Quotient opened values: regroup the flat (AIR, chunk) matrix order back
# into a per-AIR list of chunk openings.
quotient_round = all_opened_values[round_idx]
quotient_ov: list[list[list[FF4]]] = []
chunk_idx = 0
for i_air in range(num_airs):
    air_chunks_ov: list[list[FF4]] = []
    for _ in all_quotient_chunks[i_air]:
        # Point index 0: quotient chunks were opened at zeta only.
        air_chunks_ov.append(quotient_round[chunk_idx][0])
        chunk_idx += 1
    quotient_ov.append(air_chunks_ov)
opened_values = OpenedValues(
    preprocessed=preprocessed_ov,
    main=main_ov,
    after_challenge=after_challenge_ov,
    quotient=quotient_ov,
)
# Build query proofs
# log of the global maximum committed LDE height: number of commit-phase
# commits (FRI fold rounds) + log_blowup + log_final_poly_len.
log_global_max_height = (
    len(fri_proof_data["commit_phase_commits"])
    + fri_params.log_blowup
    + fri_params.log_final_poly_len
)
query_proofs: list[QueryProof] = []
for qi in range(fri_params.num_queries):
    query_index = query_indices[qi]
    _, fri_openings = fri_proof_data["fri_query_proofs"][qi]
    # Input proof: one Merkle batch opening per PCS round at this query index.
    input_proof: list[BatchOpening] = []
    for rnd in opening_rounds:
        batch_opening = _generate_batch_opening(
            rnd.committed, query_index, log_global_max_height
        )
        input_proof.append(batch_opening)
    # Commit-phase openings: sibling value + Merkle path for each fold step.
    commit_phase_openings: list[CommitPhaseProofStep] = []
    for step in fri_openings:
        commit_phase_openings.append(CommitPhaseProofStep(
            sibling_value=step.sibling_value,
            opening_proof=step.opening_proof,
        ))
    query_proofs.append(QueryProof(
        input_proof=input_proof,
        commit_phase_openings=commit_phase_openings,
    ))
# Build FRI proof
fri_proof = ProofFriProof(
    commit_phase_commits=fri_proof_data["commit_phase_commits"],
    query_proofs=query_proofs,
    final_poly=fri_proof_data["final_poly"],
    commit_pow_witnesses=fri_proof_data["commit_pow_witnesses"],
    query_pow_witness=fri_proof_data["query_pow_witness"],
)
# Wrap the FRI proof, opened values, and DEEP grinding witness together.
opening_proof = OpeningProof(
    proof=fri_proof,
    values=opened_values,
    deep_pow_witness=deep_pow_witness,
)
# Build per-AIR proof data: degree, exposed LogUp values, public values.
per_air_proof_data: list[AirProofData] = []
for i_air in range(num_airs):
    per_air_proof_data.append(AirProofData(
        air_id=air_ids[i_air],
        degree=domains[i_air].size(),
        exposed_values_after_challenge=exposed_values_per_air[i_air],
        public_values=public_values_per_air[i_air],
    ))
# Commitments in the order the verifier re-observes them.
commitments = Commitments(
    main_trace=main_commits,
    after_challenge=after_challenge_commits,
    quotient=quotient_committed.root,
)
return Proof(
    commitments=commitments,
    opening=opening_proof,
    per_air=per_air_proof_data,
    rap_phase_seq_proof=rap_phase_seq_proof,
)