Source code for protocol.fri

"""FRI (Fast Reed-Solomon IOP of Proximity) protocol.

Implements FRI folding, verification, and proving matching Plonky3's FRI.

Reference:
    p3-fri-0.4.1/src/ (prover.rs, verifier.rs, two_adic_pcs.rs)
"""

from dataclasses import dataclass

import numpy as np

from primitives.field import (
    BABYBEAR_PRIME,
    DIGEST_WIDTH,
    Digest,
    FF4,
    FIELD_EXTENSION_DEGREE,
    FF,
    Fe,
    MerklePath,
    TWO_INV,
    W,
    bit_reverse_list,
    get_omega,
    inv_mod,
    reverse_bits_len,
)
from primitives.merkle import (
    build_merkle_tree,
    get_opening_proof,
    verify_opening_prehashed,
)
from primitives.ntt import ef4_idft
from primitives.poseidon2 import hash_to_digest
from primitives.transcript import Challenger, grind

# Modulus for the base-field arithmetic in this module (the `% p` reductions
# and three-argument `pow(..., p)` calls below).
p = BABYBEAR_PRIME
# --- Data Structures ---


@dataclass
class CommitPhaseResult:
    """Prover-side artefacts collected while running the FRI commit phase.

    Apart from ``final_poly``, every list holds one entry per folding round,
    in round order.
    """

    # Merkle root committed in each round.
    commits: list[Digest]
    # Folding challenge sampled in each round, as a coefficient list.
    betas: list[list[int]]
    # Coefficients of the final polynomial (output of the inverse DFT).
    final_poly: list[list[int]]
    # Merkle tree built in each round; kept so queries can be opened later.
    trees: list[list[list[Digest]]]
    # Evaluations after each round's fold (and roll-in, when one applies).
    folded_per_round: list[list[list[int]]]
    # Evaluations that were committed in each round, i.e. before folding.
    all_round_evals: list[list[list[int]]]
    # Value returned by `grind` in each round.
    commit_pow_witnesses: list[int]
@dataclass
class FriQueryStep:
    """Opening data for one folding round of a single FRI query."""

    # Evaluation stored at the sibling position (round index XOR 1).
    sibling_value: list[int]
    # Merkle opening proof for the leaf that holds the queried pair.
    opening_proof: MerklePath
@dataclass
class FriQueryResult:
    """All openings produced for one sampled FRI query index."""

    # Query index sampled from the transcript.
    index: int
    # One opening step per commit-phase round, in round order.
    commit_phase_openings: list[FriQueryStep]
@dataclass
class CommitPhaseOutput:
    """Complete FRI proof as assembled by ``prove_fri``."""

    # Merkle roots of the commit-phase rounds.
    commit_phase_commits: list[Digest]
    # Coefficients of the final polynomial.
    final_poly: list[list[int]]
    # One query proof per sampled query index.
    query_proofs: list[FriQueryResult]
    # Folding challenges of the commit phase, as coefficient lists.
    betas: list[list[int]]
    # Folded evaluations recorded after each commit-phase round.
    folded_per_round: list[list[list[int]]]
# --- Verifier: Natural-Order Folding ---
def fri_fold(
    evals: list[list[int]],
    challenge: list[int],
    log_domain_size: int,
    coset_shift: Fe,
) -> list[list[int]]:
    """Halve a natural-order evaluation vector with one FRI folding step.

    Entry ``i`` is paired with entry ``i + len(evals) // 2``. The two are
    treated as ``f(x)`` and ``f(-x)`` for ``x = coset_shift * omega^i`` and
    combined as ``f_even + challenge * f_odd``.

    Args:
        evals: Extension-field evaluations, one coefficient list per point.
        challenge: Folding challenge as an extension-field coefficient list.
        log_domain_size: Passed to ``get_omega`` to obtain the generator.
        coset_shift: Base-field shift of the evaluation coset.

    Returns:
        ``len(evals) // 2`` folded evaluations as coefficient lists.

    Reference:
        p3-fri two_adic_pcs.rs (TwoAdicFriFolder::fold_row)
    """
    midpoint = len(evals) // 2
    fold_challenge = FF4(challenge)
    generator = get_omega(log_domain_size)

    output = []
    paired = zip(evals[:midpoint], evals[midpoint:])
    for k, (first_half, second_half) in enumerate(paired):
        at_pos = FF4(first_half)   # f(x),  x = shift * omega^k
        at_neg = FF4(second_half)  # f(-x), -x = shift * omega^(k + N/2)

        point = (coset_shift * pow(generator, k, p)) % p
        inv_two_point = inv_mod((2 * point) % p)  # 1 / (2x) mod p

        even_part = (at_pos + at_neg).mul_base(TWO_INV)
        odd_part = (at_pos - at_neg).mul_base(inv_two_point)
        output.append((even_part + fold_challenge * odd_part).to_list())
    return output
# --- Verifier: Query Verification --- # <doc-anchor id="fri-fold">
def fold_row(
    index: int,
    log_height: int,
    beta: FF4,
    e0: FF4,
    e1: FF4,
) -> FF4:
    """Fold one pair of evaluations by interpolating them at ``beta``.

    Args:
        index: Row index of the pair in the folded (bit-reversed) layer.
        log_height: log2 of the folded layer's height.
        beta: Folding challenge.
        e0: Evaluation at the even position of the pair.
        e1: Evaluation at the odd position of the pair.

    Returns:
        ``e0 + (beta - x_even) * (e1 - e0) / (x_odd - x_even)``.

    Reference:
        p3-fri two_adic_pcs.rs (TwoAdicFriFolding::fold_row)
    """
    challenge, low, high = FF4(beta), FF4(e0), FF4(e1)

    # x_even = two_adic_generator(log_height + 1) ^ reverse_bits_len(index, log_height)
    exponent = reverse_bits_len(index, log_height)
    even_point = pow(W[log_height + 1], exponent, p)
    x_even = FF4(even_point)

    # The denominator (x_odd - x_even) is taken as -2 * x_even.
    inv_gap = FF4(inv_mod((-2 * even_point) % p))

    offset = (challenge - x_even) * (high - low) * inv_gap
    return low + offset
# <doc-anchor id="hash-fri-leaf">
def hash_fri_leaf(e0: FF4, e1: FF4) -> Digest:
    """Compute the FRI Merkle leaf digest for a pair of evaluations.

    The coefficient lists of ``e0`` and ``e1`` are concatenated, in that
    order, and passed to ``hash_to_digest``.

    Reference:
        p3-merkle-tree mmcs.rs (verify_batch leaf hashing)
    """
    first, second = FF4(e0), FF4(e1)
    leaf_coords = first.to_list() + second.to_list()
    return hash_to_digest(leaf_coords)
def ef4_pairs_to_leaves(evals: list[list[int]]) -> list[list[int]]:
    """Concatenate each consecutive pair of rows into one Merkle leaf.

    Rows ``2k`` and ``2k + 1`` become leaf ``k``. An odd number of rows
    raises ``IndexError`` when the unpaired last row is reached.
    """
    pair_count = (len(evals) + 1) // 2
    return [evals[2 * k] + evals[2 * k + 1] for k in range(pair_count)]
def fri_verify_query(
    commit_phase_commits: list[Digest],
    betas: list[list[int]],
    query_index: int,
    query_proof: dict,
    reduced_opening: list[int],
    final_poly: list[list[int]],
    log_max_height: int,
    log_final_poly_len: int,
) -> list[int]:
    """Verify single FRI query: fold chain + Merkle proofs + final poly check.

    Args:
        commit_phase_commits: Merkle root of each commit-phase round.
        betas: Folding challenge of each round, as coefficient lists.
        query_index: Index at which the first layer is queried.
        query_proof: Mapping with a ``"commit_phase_openings"`` list; each
            entry provides ``"sibling_value"`` and ``"opening_proof"``.
        reduced_opening: Starting evaluation of the fold chain.
        final_poly: Coefficients of the final polynomial.
        log_max_height: log2 of the first layer's height; round ``r`` folds
            into a layer of log-height ``log_max_height - 1 - r``.
        log_final_poly_len: log2 of the final polynomial's length.

    Returns:
        The final folded evaluation as a coefficient list.

    Raises:
        AssertionError: If a Merkle opening does not verify, or if the folded
            value differs from the final polynomial's evaluation. The checks
            are explicit ``raise`` statements, not ``assert``, so they stay
            active when Python runs with ``-O``.

    Reference:
        p3-fri verifier.rs (verify_query)
    """
    # NOTE(review): prove_fri emits FriQueryResult/FriQueryStep dataclasses,
    # while this function indexes plain dicts -- presumably callers convert
    # the proof first; confirm against the call site.
    # NOTE(review): no reduced-opening roll-in is applied between rounds here,
    # unlike commit_phase -- presumably only single-height inputs take this
    # path; confirm.
    num_rounds = len(commit_phase_commits)
    start_index = query_index
    folded_eval = FF4(reduced_opening)

    for round_idx in range(num_rounds):
        log_folded_height = log_max_height - 1 - round_idx
        opening = query_proof["commit_phase_openings"][round_idx]
        sibling = FF4(opening["sibling_value"])
        beta = FF4(betas[round_idx])

        # Arrange evals: e0 at even position, e1 at odd position.
        index_sibling = start_index ^ 1
        if index_sibling % 2 == 0:
            e0, e1 = sibling, folded_eval
        else:
            e0, e1 = folded_eval, sibling

        # Parent index (matching p3-fri: start_index >>= 1 before verify_batch).
        parent_index = start_index >> 1

        # Verify the Merkle proof for the pair of evaluations.
        leaf_digest = hash_fri_leaf(e0, e1)
        merkle_ok = verify_opening_prehashed(
            commit_phase_commits[round_idx],
            leaf_digest,
            parent_index,
            opening["opening_proof"],
        )
        if not merkle_ok:
            raise AssertionError(f"Merkle proof failed at round {round_idx}")

        # Fold via Lagrange interpolation, then move to the parent layer.
        folded_eval = fold_row(parent_index, log_folded_height, beta, e0, e1)
        start_index = parent_index

    # Evaluate the final polynomial at the point selected by the folded index.
    if log_final_poly_len == 0:
        expected = FF4(final_poly[0])
    else:
        # NOTE(review): when log_max_height is rounds + log_blowup +
        # log_final_poly_len (as verify_fri computes it), start_index can be
        # wider than log_final_poly_len bits here -- confirm that this
        # reverse_bits_len call selects the intended evaluation point.
        x = pow(
            W[log_final_poly_len],
            reverse_bits_len(start_index, log_final_poly_len),
            p,
        )
        expected = FF4.zero()
        for i, c in enumerate(final_poly):
            expected = expected + FF4(c) * FF4(pow(x, i, p))

    folded_coeffs = folded_eval.to_list()
    expected_coeffs = expected.to_list()
    if folded_coeffs != expected_coeffs:
        raise AssertionError(
            f"Final polynomial check failed: "
            f"{folded_coeffs} != {expected_coeffs}"
        )

    return folded_coeffs
# <doc-anchor id="verify-fri">
def verify_fri(
    commit_phase_commits: list[Digest],
    final_poly: list[list[int]],
    query_proofs: list[dict],
    log_blowup: int,
    log_final_poly_len: int,
    num_queries: int,
) -> bool:
    """Verify FRI proof: transcript replay and structural consistency.

    Note:
        Full fold-chain verification requires reduced_openings from PCS.
        This function verifies transcript replay, query index derivation,
        and proof structure (lengths, digest sizes).

    Args:
        commit_phase_commits: Merkle root of each commit-phase round.
        final_poly: Coefficients of the final polynomial.
        query_proofs: One mapping per query, each with a
            ``"commit_phase_openings"`` list of per-round openings.
        log_blowup: log2 of the blowup factor.
        log_final_poly_len: log2 of the final polynomial's length.
        num_queries: Number of query proofs that must be present.

    Returns:
        ``True`` when every structural check passes.

    Raises:
        AssertionError: On the first structural mismatch. The checks are
            explicit ``raise`` statements, not ``assert``, so they stay
            active when Python runs with ``-O``.

    Reference:
        p3-fri verifier.rs (verify_fri)
    """
    challenger = Challenger()
    num_rounds = len(commit_phase_commits)
    log_max_height = num_rounds + log_blowup + log_final_poly_len

    # Phase 1: replay the commit phase -- observe each root, then sample the
    # round challenge. The sampled value is not used by the structural checks
    # below; the call is kept only to advance the transcript.
    # NOTE(review): commit_phase calls grind(...) between observing a root and
    # sampling beta, and this replay does not -- presumably grind leaves the
    # transcript untouched when commit_pow_bits is 0; confirm.
    for commit in commit_phase_commits:
        challenger.observe_many(commit)
        challenger.sample_ext()

    # Observe final polynomial coefficients.
    for coeff in final_poly:
        challenger.observe_many(coeff)

    # Phase 2: check each query.
    if len(query_proofs) != num_queries:
        raise AssertionError(
            f"Expected {num_queries} query proofs, got {len(query_proofs)}"
        )

    for qi, query_proof in enumerate(query_proofs):
        # Derive the query index from the transcript.
        query_index = challenger.sample_bits(log_max_height)
        if not 0 <= query_index < (1 << log_max_height):
            raise AssertionError(
                f"Query {qi}: index {query_index} out of range [0, {1 << log_max_height})"
            )

        openings = query_proof["commit_phase_openings"]
        if len(openings) != num_rounds:
            raise AssertionError(
                f"Query {qi}: expected {num_rounds} opening steps, got {len(openings)}"
            )

        for round_idx, opening in enumerate(openings):
            # One Merkle sibling is expected per level of the folded layer.
            log_folded_height = log_max_height - 1 - round_idx
            expected_proof_len = log_folded_height

            proof = opening["opening_proof"]
            if len(proof) != expected_proof_len:
                raise AssertionError(
                    f"Query {qi}, round {round_idx}: "
                    f"expected {expected_proof_len} Merkle siblings, got {len(proof)}"
                )

            # The sibling value must be an extension-field element.
            if len(opening["sibling_value"]) != FIELD_EXTENSION_DEGREE:
                raise AssertionError(
                    f"Query {qi}, round {round_idx}: "
                    f"sibling_value should have {FIELD_EXTENSION_DEGREE} components"
                )

            # Every Merkle sibling must have the digest width.
            for si, sibling in enumerate(proof):
                if len(sibling) != DIGEST_WIDTH:
                    raise AssertionError(
                        f"Query {qi}, round {round_idx}, sibling {si}: "
                        f"digest should have {DIGEST_WIDTH} elements, got {len(sibling)}"
                    )

    return True
# --- Prover: Bit-Reversed Folding --- # <doc-anchor id="fold-matrix">
def fold_matrix(
    evals_bit_reversed: list[list[int]],
    beta: FF4,
    log_height: int,
) -> list[list[int]]:
    """Fold a bit-reversed evaluation layer to half its length.

    Rows ``2k`` and ``2k + 1`` form a pair and are combined as
    ``(lo + hi) / 2 + (lo - hi) * beta * t_k``, where ``t_k`` is entry ``k``
    of the bit-reversed table ``g_inv^i / 2``.

    Args:
        evals_bit_reversed: Evaluations in bit-reversed order, one
            coefficient list per row.
        beta: Folding challenge.
        log_height: log2 of the folded layer's height; ``W[log_height + 1]``
            supplies the generator.

    Returns:
        The folded rows, ``len(evals_bit_reversed) // 2`` of them.

    Reference:
        p3-fri two_adic_pcs.rs (TwoAdicFriFolding::fold_matrix)
    """
    num_pairs = len(evals_bit_reversed) // 2
    challenge = FF4(beta)

    # g_inv = two_adic_generator(log_height + 1)^(-1)
    gen_inverse = inv_mod(W[log_height + 1])

    # scaled_powers[i] = g_inv^i / 2, built in natural order first.
    scaled_powers = []
    running = TWO_INV
    while len(scaled_powers) < num_pairs:
        scaled_powers.append(running)
        running = (running * gen_inverse) % p

    twiddles = FF(bit_reverse_list(scaled_powers))
    halves = FF(np.full(num_pairs, TWO_INV))

    # Split the rows into the first and second member of every pair.
    first = FF4.from_rows(evals_bit_reversed[0::2])
    second = FF4.from_rows(evals_bit_reversed[1::2])

    even_term = (first + second).mul_base(halves)
    odd_term = ((first - second) * challenge).mul_base(twiddles)
    return (even_term + odd_term).to_rows()
# --- Prover: Commit Phase --- # <doc-anchor id="commit-phase">
def commit_phase(
    evals_bit_reversed: list[list[int]],
    log_blowup: int,
    log_final_poly_len: int,
    challenger: Challenger,
    commit_pow_bits: int = 0,
    reduced_openings_by_height: dict[int, list[list[int]]] | None = None,
) -> CommitPhaseResult:
    """FRI commit phase: iterative folding with Merkle commitments.

    For multi-height FRI, reduced_openings_by_height maps log_height to
    bit-reversed reduced evaluations at that height. After folding to a
    given height, the corresponding reduced opening is rolled in using
    beta^2 as the combination factor.

    Args:
        evals_bit_reversed: Initial evaluations in bit-reversed order.
        log_blowup: log2 of the blowup factor.
        log_final_poly_len: log2 of the final polynomial's length.
        challenger: Transcript; roots and final coefficients are observed on
            it and the folding challenges are sampled from it.
        commit_pow_bits: Bit count handed to ``grind`` in every round.
        reduced_openings_by_height: Optional roll-in data keyed by the
            log-height of the folded layer.

    Returns:
        A ``CommitPhaseResult`` with the per-round data and final polynomial.

    Raises:
        AssertionError: If a roll-in entry's length differs from the folded
            layer it is added to. The check is an explicit ``raise``, not
            ``assert``, so it stays active when Python runs with ``-O``.

    Reference:
        p3-fri prover.rs (commit_phase)
    """
    folded = list(evals_bit_reversed)
    commits: list[Digest] = []
    betas: list[list[int]] = []
    trees: list[list[list[Digest]]] = []
    folded_per_round: list[list[list[int]]] = []
    all_round_evals: list[list[list[int]]] = []
    commit_pow_witnesses: list[int] = []

    blowup = 1 << log_blowup
    final_poly_len = 1 << log_final_poly_len

    while len(folded) > blowup * final_poly_len:
        height = len(folded) // 2
        log_height = height.bit_length() - 1

        # Store current evals for query phase.
        all_round_evals.append(folded)

        # Build Merkle tree from pairs of evaluations.
        # Each leaf = hash of [lo_c0..lo_c3, hi_c0..hi_c3] (8 base field elements).
        leaves = ef4_pairs_to_leaves(folded)
        root, tree = build_merkle_tree(leaves)

        # Observe commitment.
        challenger.observe_many(root)
        commits.append(root)
        trees.append(tree)

        # Per-round PoW.
        pow_witness = grind(challenger, commit_pow_bits)
        commit_pow_witnesses.append(pow_witness)

        # Sample folding challenge.
        beta = challenger.sample_ext()
        betas.append(beta.to_list())

        # Fold.
        folded = fold_matrix(folded, beta, log_height)

        # Roll in reduced openings at this folded height, if any.
        # This mirrors the verifier's: folded += beta^2 * reduced_opening
        # Reference: p3-fri verifier.rs verify_query (line ~310)
        if reduced_openings_by_height and log_height in reduced_openings_by_height:
            roll_in_data = reduced_openings_by_height[log_height]
            if len(roll_in_data) != len(folded):
                raise AssertionError(
                    f"Roll-in size mismatch at log_height {log_height}: "
                    f"{len(roll_in_data)} vs {len(folded)}"
                )
            beta_sq = beta * beta
            fv = FF4.from_rows(folded)
            rv = FF4.from_rows(roll_in_data)
            result = fv + rv * beta_sq
            folded = result.to_rows()

        folded_per_round.append(folded)

    # Compute final polynomial via IDFT:
    # truncate to final_poly_len, bit-reverse, IDFT.
    final_evals = folded[:final_poly_len]
    final_evals_natural = bit_reverse_list(final_evals)
    final_poly = ef4_idft(final_evals_natural)

    # Observe final polynomial.
    for coeff in final_poly:
        challenger.observe_many(coeff)

    return CommitPhaseResult(
        commits=commits,
        betas=betas,
        final_poly=final_poly,
        trees=trees,
        folded_per_round=folded_per_round,
        all_round_evals=all_round_evals,
        commit_pow_witnesses=commit_pow_witnesses,
    )
# --- Prover: Query Phase --- # <doc-anchor id="answer-query">
def answer_query(
    trees: list[list[list[Digest]]],
    all_round_evals: list[list[list[int]]],
    start_index: int,
    num_rounds: int,
) -> list[FriQueryStep]:
    """Build the per-round openings that answer one FRI query.

    Args:
        trees: Merkle tree of each commit-phase round.
        all_round_evals: Evaluations committed in each round.
        start_index: Query index into the first round's evaluations.
        num_rounds: Number of rounds to open.

    Returns:
        One ``FriQueryStep`` per round, in round order.

    Reference:
        p3-fri prover.rs (answer_query)
    """
    steps = []
    cursor = start_index  # equals start_index >> round_no at every iteration
    for round_no in range(num_rounds):
        # Merkle proof for the leaf that holds the queried pair.
        pair_proof = get_opening_proof(trees[round_no], cursor >> 1)
        # The other member of the pair, read from the stored evaluations.
        partner = all_round_evals[round_no][cursor ^ 1]
        steps.append(
            FriQueryStep(sibling_value=partner, opening_proof=pair_proof)
        )
        cursor >>= 1
    return steps
# <doc-anchor id="prove-fri">
def prove_fri(
    evals_bit_reversed: list[list[int]],
    log_blowup: int,
    log_final_poly_len: int,
    num_queries: int,
    challenger: Challenger,
) -> CommitPhaseOutput:
    """Produce a full FRI proof: commit phase, then the query openings.

    ``commit_phase`` is called without proof-of-work bits or reduced
    openings, so it runs with its defaults for both.

    Args:
        evals_bit_reversed: Initial evaluations in bit-reversed order.
        log_blowup: log2 of the blowup factor.
        log_final_poly_len: log2 of the final polynomial's length.
        num_queries: Number of query indices to sample and open.
        challenger: Transcript shared by the commit and query phases.

    Returns:
        A ``CommitPhaseOutput`` carrying commitments, final polynomial,
        query proofs, challenges and per-round folded evaluations.

    Reference:
        p3-fri prover.rs (prove)
    """
    commit = commit_phase(
        evals_bit_reversed, log_blowup, log_final_poly_len, challenger
    )

    index_bits = len(evals_bit_reversed).bit_length() - 1
    round_count = len(commit.commits)

    def open_next_query() -> FriQueryResult:
        # Sample an index from the transcript, then open every round at it.
        sampled = challenger.sample_bits(index_bits)
        steps = answer_query(
            commit.trees,
            commit.all_round_evals,
            sampled,
            round_count,
        )
        return FriQueryResult(index=sampled, commit_phase_openings=steps)

    query_proofs = [open_next_query() for _ in range(num_queries)]

    return CommitPhaseOutput(
        commit_phase_commits=commit.commits,
        final_poly=commit.final_poly,
        query_proofs=query_proofs,
        betas=commit.betas,
        folded_per_round=commit.folded_per_round,
    )