"""PCS (Polynomial Commitment Scheme) for TwoAdicFriPcs.
Implements both prover and verifier sides of the FRI-based PCS.
Verifier: given committed polynomial batches opened at evaluation points,
reduces them to a single FRI verification problem.
Prover: commits polynomials via LDE + Merkle, evaluates at opening points,
computes reduced polynomials, and delegates to the FRI prover.
Reference:
p3-fri-0.4.1/src/two_adic_pcs.rs -- TwoAdicFriPcs::commit, open, verify
p3-fri-0.4.1/src/verifier.rs -- verify_fri, verify_query, open_input
p3-fri-0.4.1/src/prover.rs -- prove (FRI prover)
"""
from dataclasses import dataclass

import numpy as np

from primitives.field import (
    BABYBEAR_PRIME,
    Digest,
    Fe,
    FF,
    FF4,
    GENERATOR,
    W,
    bit_reverse_list,
    eval_poly_ef4_batch,
    get_omega,
    inv_mod,
    reverse_bits_len,
)
# Short alias for the BabyBear modulus; the module body uses `p` in
# modular arithmetic (`pow(..., p)`, `% p`).
from primitives.field import BABYBEAR_PRIME as p
from primitives.merkle import get_opening_proof, verify_opening_prehashed
from primitives.ntt import coset_lde_batch as _coset_lde_batch_ffi, intt_batch as _intt_batch_ffi
from primitives.poseidon2 import compress, compress_batch, hash_batch, hash_to_digest
from primitives.transcript import Challenger, check_witness, grind
from protocol.domain import TwoAdicMultiplicativeCoset
from protocol.fri import answer_query, commit_phase as fri_commit_phase, fold_row, hash_fri_leaf
from protocol.proof import (
    BatchOpening,
    CommitPhaseProofStep,
    FriParameters,
    FriProof,
)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class PcsRound:
    """One round of PCS verification data.

    Corresponds to a single MMCS commitment (batch of matrices) and the
    domains/openings associated with it.

    Reference:
        p3-fri/src/two_adic_pcs.rs (CommitmentWithOpeningPoints type alias)
    """

    # Merkle root of the batch commitment.  Consumed by _open_input when it
    # calls _verify_batch_opening (`round_data.commitment`); mirrors the
    # first element of the Rust CommitmentWithOpeningPoints tuple.
    commitment: Digest
    # Per-matrix (domain, [(point z, claimed values per column at z)]) pairs.
    domains_and_openings: list[tuple[TwoAdicMultiplicativeCoset, list[tuple[list[int], list[list[int]]]]]]
# ---------------------------------------------------------------------------
# Merkle batch verification for input MMCS
# ---------------------------------------------------------------------------
# <doc-anchor id="verify-batch-opening">
def _verify_batch_opening(
    commitment: Digest,
    dimensions: list[int],
    index: int,
    batch_opening: BatchOpening,
) -> None:
    """Verify a batch Merkle opening for multiple matrices of varying heights.

    This implements the MerkleTreeMmcs::verify_batch logic from Plonky3.
    Matrices are sorted by height (tallest first). The Merkle tree has
    log_max_height levels. Shorter matrices are "rolled in" at the
    appropriate level by compressing their leaf hash with the running root.

    Args:
        commitment: The Merkle root commitment for this batch.
        dimensions: Heights of each matrix in the batch (in original order).
        index: The query index (at the tallest matrix's height).
        batch_opening: The opened values and Merkle proof.

    Raises:
        AssertionError: On any shape mismatch or final root mismatch.
        ValueError: If the batch contains no matrices.

    Reference:
        p3-merkle-tree-0.4.1/src/mmcs.rs (MerkleTreeMmcs::verify_batch, lines 230-338)
    """
    opened_values = batch_opening.opened_values
    opening_proof = batch_opening.opening_proof
    assert len(dimensions) == len(opened_values), (
        f"batch size mismatch: {len(dimensions)} dimensions vs {len(opened_values)} openings"
    )
    if not dimensions:
        raise ValueError("empty batch: no matrices to verify")
    # Sort matrices by height descending, keeping original indices.
    # This mirrors heights_tallest_first in the Rust code.
    sorted_entries = sorted(
        enumerate(dimensions), key=lambda x: -x[1]
    )
    max_height = sorted_entries[0][1]
    curr_height_padded = _next_power_of_two(max_height)
    # The proof carries exactly one sibling digest per tree level.
    log_max_height = curr_height_padded.bit_length() - 1
    assert len(opening_proof) == log_max_height, (
        f"wrong proof height: expected {log_max_height} siblings, got {len(opening_proof)}"
    )
    # NOTE(review): this bounds `index` by the tallest matrix's exact height,
    # not its padded height -- for LDE matrices the two coincide (heights are
    # powers of two); confirm against the Rust code if that ever changes.
    assert index < max_height, (
        f"index {index} out of bounds for max_height {max_height}"
    )
    # Consume entries from sorted_entries using a pointer, mimicking
    # Rust's peeking_take_while iterator pattern.
    entry_ptr = 0
    # Hash all matrix openings at the initial (tallest) padded height.
    initial_group = []
    while (entry_ptr < len(sorted_entries)
           and _next_power_of_two(sorted_entries[entry_ptr][1]) == curr_height_padded):
        initial_group.append(sorted_entries[entry_ptr])
        entry_ptr += 1
    root = _hash_matrix_rows(initial_group, opened_values)
    current_index = index
    for sibling in opening_proof:
        # Combine current root with sibling based on index parity:
        # an even index means the current node is the left child.
        if current_index & 1 == 0:
            root = compress(root, sibling)
        else:
            root = compress(sibling, root)
        current_index >>= 1
        curr_height_padded >>= 1
        # Check if there are new matrix rows to inject at the current level.
        # We peek at the next entry and check if its padded height matches.
        if (entry_ptr < len(sorted_entries)
                and _next_power_of_two(sorted_entries[entry_ptr][1]) == curr_height_padded):
            next_height = sorted_entries[entry_ptr][1]
            # Collect all entries with this exact height (peeking_take_while)
            inject_group = []
            while (entry_ptr < len(sorted_entries)
                   and sorted_entries[entry_ptr][1] == next_height):
                inject_group.append(sorted_entries[entry_ptr])
                entry_ptr += 1
            # Hash their rows and compress with the running root; this is how
            # shorter matrices are bound into the same commitment.
            inject_digest = _hash_matrix_rows(inject_group, opened_values)
            root = compress(root, inject_digest)
    assert root == commitment, "Merkle root mismatch in batch verification"
def _next_power_of_two(n: int) -> int:
"""Round up to the next power of two (or n if already a power of two)."""
if n <= 0:
return 1
# Bit trick: if n is already a power of two, return n
if n & (n - 1) == 0:
return n
return 1 << n.bit_length()
def _hash_matrix_rows(
    entries: list[tuple[int, int]],
    opened_values: list[list[Fe]],
) -> Digest:
    """Hash the opened row values for a group of matrices.

    Mimics PaddingFreeSponge::hash_iter_slices: the row data of every matrix
    referenced by ``entries`` is concatenated (in the given order) and
    absorbed into a single Poseidon2 hash.

    Args:
        entries: List of (original_index, height) pairs, all at the same padded height.
        opened_values: The full list of opened row values.

    Returns:
        A Poseidon2 digest of the concatenated row data.

    Reference:
        p3-symmetric PaddingFreeSponge::hash_iter_slices
    """
    flattened: list[Fe] = [
        value
        for mat_idx, _height in entries
        for value in opened_values[mat_idx]
    ]
    return hash_to_digest(flattened)
# ---------------------------------------------------------------------------
# Reduced opening computation (open_input)
# ---------------------------------------------------------------------------
# <doc-anchor id="reduce-to-fri-input">
def _open_input(
    fri_params: FriParameters,
    log_global_max_height: int,
    index: int,
    input_proof: list[BatchOpening],
    alpha: FF4,
    rounds: list[PcsRound],
) -> list[tuple[int, FF4]]:
    """Open input polynomials and combine into FRI reduced openings.

    For each batch commitment and its opening proof:
      1. Verify the Merkle batch opening
      2. For each matrix, compute x = g * omega^(bit_reverse(index >> bits_reduced, log_height))
      3. For each (point z, values) pair, compute (p(z) - p(x)) / (z - x) weighted by alpha

    Args:
        fri_params: FRI protocol parameters (log_blowup is read here).
        log_global_max_height: Log2 of the tallest LDE height across all batches.
        index: The FRI query index, sampled at the global max height.
        input_proof: One BatchOpening per round, parallel to `rounds`.
        alpha: Batch-combination challenge; successive columns at the same
            height are weighted by successive powers of alpha.
        rounds: Per-batch commitment plus domains/openings data.

    Returns:
        Reduced openings as (log_height, reduced_value) pairs, sorted
        descending by log_height.

    Raises:
        AssertionError: On length mismatches, Merkle failures, or a non-zero
            reduced opening for a constant (height == blowup) polynomial.

    Reference:
        p3-fri-0.4.1/src/verifier.rs (open_input, lines 343-455)
    """
    # For each log_height, store (alpha_pow, reduced_opening).  alpha_pow is
    # a per-height running power of alpha, so every column across all batches
    # sharing that height receives a distinct alpha^k weight.
    reduced_openings: dict[int, tuple] = {}  # log_height -> (FF4_alpha_pow, FF4_reduced)
    alpha_one = FF4(1)
    alpha_zero = FF4.zero()
    assert len(input_proof) == len(rounds), (
        f"input_proof length {len(input_proof)} != rounds length {len(rounds)}"
    )
    for batch_opening, round_data in zip(input_proof, rounds):
        mats = round_data.domains_and_openings
        # Compute the height of each matrix in the batch (domain size * blowup)
        batch_heights = [
            domain.size() << fri_params.log_blowup
            for domain, _ in mats
        ]
        # Compute the reduced index for this batch.
        # If the max height of the batch is smaller than the global max height,
        # shift the index down to compensate.
        if batch_heights:
            max_batch_height = max(batch_heights)
            log_max_batch = max_batch_height.bit_length() - 1
            reduced_index = index >> (log_global_max_height - log_max_batch)
        else:
            reduced_index = 0
        # Verify the Merkle batch opening
        _verify_batch_opening(
            round_data.commitment,
            batch_heights,
            reduced_index,
            batch_opening,
        )
        # For each matrix in the commitment
        for mat_idx, (mat_domain, mat_points_and_values) in enumerate(mats):
            mat_opening = batch_opening.opened_values[mat_idx]
            log_height = mat_domain.log_n + fri_params.log_blowup
            bits_reduced = log_global_max_height - log_height
            rev_reduced_index = reverse_bits_len(index >> bits_reduced, log_height)
            # x = GENERATOR * two_adic_generator(log_height)^rev_reduced_index
            # This is the evaluation point on the LDE coset g*H.
            # NOTE(review): `p` is presumably the BabyBear modulus
            # (BABYBEAR_PRIME); it is not bound in this excerpt -- confirm the
            # module defines it.
            x = (GENERATOR * pow(get_omega(log_height), rev_reduced_index, p)) % p
            # Get or initialize the reduced opening accumulator for this log_height
            if log_height not in reduced_openings:
                reduced_openings[log_height] = (alpha_one, alpha_zero)
            alpha_pow_ef, ro_ef = reduced_openings[log_height]
            # For each (point z, claimed values) pair
            for z, ps_at_z in mat_points_and_values:
                # quotient = 1 / (z - x) in extension field
                z_ef = FF4(z)
                x_ef = FF4(x)
                quotient = (z_ef - x_ef) ** (-1)
                # For each column value
                for col_idx in range(len(mat_opening)):
                    p_at_x = FF4(mat_opening[col_idx])
                    p_at_z = FF4(ps_at_z[col_idx])
                    # ro += alpha_pow * (p_at_z - p_at_x) * quotient
                    ro_ef = ro_ef + alpha_pow_ef * (p_at_z - p_at_x) * quotient
                    alpha_pow_ef = alpha_pow_ef * alpha
            # Persist the advanced accumulator for the next matrix/batch at
            # this height.
            reduced_openings[log_height] = (alpha_pow_ef, ro_ef)
    # Check: if there's a reduced opening at log_height = log_blowup,
    # the polynomial must be constant, so the reduced opening must be zero.
    if fri_params.log_blowup in reduced_openings:
        _, ro_check = reduced_openings[fri_params.log_blowup]
        assert ro_check.to_list() == [0, 0, 0, 0], (
            "constant polynomial has non-zero reduced opening"
        )
    # Return reduced openings sorted descending by log_height
    result = []
    for log_h in sorted(reduced_openings.keys(), reverse=True):
        _, ro = reduced_openings[log_h]
        result.append((log_h, ro))
    return result
# ---------------------------------------------------------------------------
# Query verification (verify_query)
# ---------------------------------------------------------------------------
# <doc-anchor id="verify-fri-query">
def _verify_query(
    fri_params: FriParameters,
    start_index: int,
    betas: list[FF4],
    commit_phase_commits: list[Digest],
    commit_phase_openings: list[CommitPhaseProofStep],
    reduced_openings: list[tuple[int, FF4]],
    log_global_max_height: int,
    log_final_height: int,
) -> tuple[FF4, int]:
    """Verify a single FRI query: fold chain with reduced openings rolled in.

    The fold chain is seeded with the reduced opening at
    log_global_max_height.  Each commit-phase round checks a Merkle opening
    for the (eval, sibling) pair, folds the pair with beta, and — whenever a
    reduced opening exists at the freshly folded height — rolls it in
    weighted by beta^2.

    Returns (folded_eval, final_domain_index).

    Reference:
        p3-fri-0.4.1/src/verifier.rs (verify_query, lines 236-321)
    """
    pending = list(reduced_openings)
    assert pending, "no reduced openings"
    # The tallest reduced opening must sit exactly at the global max height.
    top_log_height, acc = pending[0]
    assert top_log_height == log_global_max_height, (
        f"first reduced opening at log_height {top_log_height}, expected {log_global_max_height}"
    )
    consumed = 1
    num_rounds = len(commit_phase_commits)
    assert len(betas) == num_rounds
    assert len(commit_phase_openings) == num_rounds
    idx = start_index
    # Fold from log_global_max_height - 1 down to log_final_height.
    for round_idx, (beta, comm, step) in enumerate(
        zip(betas, commit_phase_commits, commit_phase_openings)
    ):
        log_folded_height = log_global_max_height - 1 - round_idx
        sib = FF4(step.sibling_value)
        # Arrange the pair so evals[0] holds the even position: the sibling
        # index has opposite parity to the current index.
        if (idx ^ 1) & 1 == 0:
            e0, e1 = sib, acc
        else:
            e0, e1 = acc, sib
        # Move to the parent index.
        idx >>= 1
        # FRI commit-phase leaves are hashed pairs of extension-field
        # evaluations: hash([e0_c0..e0_c3, e1_c0..e1_c3]).
        leaf_digest = hash_fri_leaf(e0, e1)
        assert verify_opening_prehashed(
            comm,
            leaf_digest,
            idx,
            step.opening_proof,
        ), f"FRI commit phase Merkle proof failed at round {round_idx}"
        acc = fold_row(idx, log_folded_height, beta, e0, e1)
        # Roll in the reduced opening landing at this folded height, if any,
        # using beta^2 as the random combination factor.
        if consumed < len(pending) and pending[consumed][0] == log_folded_height:
            acc = acc + beta * beta * pending[consumed][1]
            consumed += 1
    # Every reduced opening must have been folded in along the way.
    assert consumed == len(pending), (
        f"not all reduced openings consumed: {consumed} of {len(pending)}"
    )
    return acc, idx
# ---------------------------------------------------------------------------
# Main PCS verification
# ---------------------------------------------------------------------------
# <doc-anchor id="pcs-verify">
[docs]
def pcs_verify(
    rounds: list[PcsRound],
    fri_proof: FriProof,
    challenger: Challenger,
    fri_params: FriParameters,
) -> None:
    """Verify PCS openings using FRI.

    This is the Python translation of TwoAdicFriPcs::verify followed by
    verify_fri. The algorithm:
      1. Sample FRI alpha from challenger (batch combination challenge)
      2. Replay FRI commit phase: for each commitment, observe it, check PoW,
         then sample beta
      3. Observe final polynomial
      4. Check query PoW
      5. For each query:
         a. Sample query index
         b. Compute reduced openings from the input proof (open_input)
         c. Verify FRI query (fold chain + Merkle proofs + final poly check)

    Args:
        rounds: List of PcsRound, one per batch commitment. Each contains
            the commitment digest and the domains/openings for that batch.
        fri_proof: The FRI proof containing commit phase commits, query proofs,
            final polynomial, and PoW witnesses.
        challenger: The Fiat-Shamir challenger (already has opened values observed).
        fri_params: FRI protocol parameters.

    Raises:
        AssertionError: If any verification check fails.

    Reference:
        p3-fri-0.4.1/src/two_adic_pcs.rs (TwoAdicFriPcs::verify, lines 523-554)
        p3-fri-0.4.1/src/verifier.rs (verify_fri, lines 54-204)
    """
    # --- Step 0: Observe all opened values (claimed polynomial evaluations) ---
    # The PCS verifier observes all evaluation claims into the transcript
    # before sampling the batch combination challenge.
    # Reference: two_adic_pcs.rs TwoAdicFriPcs::verify, lines 533-541
    for round_data in rounds:
        for _domain, points_and_values in round_data.domains_and_openings:
            for _point, values in points_and_values:
                for val in values:
                    challenger.observe_many(val)
    # --- Step 1: Sample batch combination challenge (alpha) ---
    alpha = challenger.sample_ext()
    # --- Step 2: Compute log_global_max_height ---
    # log_global_max_height = num_commit_rounds + log_blowup + log_final_poly_len
    num_commit_rounds = len(fri_proof.commit_phase_commits)
    log_global_max_height = (
        num_commit_rounds + fri_params.log_blowup + fri_params.log_final_poly_len
    )
    # --- Step 3: Validate proof shape ---
    assert len(fri_proof.commit_pow_witnesses) == num_commit_rounds, (
        f"commit_pow_witnesses length {len(fri_proof.commit_pow_witnesses)} != "
        f"commit_phase_commits length {num_commit_rounds}"
    )
    # --- Step 4: Replay commit phase — observe roots, check PoW, sample betas ---
    betas = []
    for round_idx in range(num_commit_rounds):
        comm = fri_proof.commit_phase_commits[round_idx]
        pow_witness = fri_proof.commit_pow_witnesses[round_idx]
        # Observe the commitment
        challenger.observe_many(comm)
        # Check per-round PoW
        assert check_witness(
            challenger, fri_params.commit_proof_of_work_bits, pow_witness
        ), f"commit phase PoW failed at round {round_idx}"
        # Sample folding challenge beta
        beta = challenger.sample_ext()
        betas.append(beta)
    # --- Step 5: Validate final polynomial length ---
    final_poly_len = 1 << fri_params.log_final_poly_len
    assert len(fri_proof.final_poly) == final_poly_len, (
        f"final poly length {len(fri_proof.final_poly)} != expected {final_poly_len}"
    )
    # --- Step 6: Observe final polynomial ---
    for coeff in fri_proof.final_poly:
        challenger.observe_many(coeff)
    # --- Step 7: Validate number of query proofs ---
    assert len(fri_proof.query_proofs) == fri_params.num_queries, (
        f"expected {fri_params.num_queries} query proofs, "
        f"got {len(fri_proof.query_proofs)}"
    )
    # --- Step 8: Check query-phase PoW ---
    assert check_witness(
        challenger, fri_params.query_proof_of_work_bits, fri_proof.query_pow_witness
    ), "query phase PoW failed"
    # --- Step 9: Process each query ---
    log_final_height = fri_params.log_blowup + fri_params.log_final_poly_len
    for qi, query_proof in enumerate(fri_proof.query_proofs):
        # Sample query index (extra_query_index_bits = 0 for TwoAdicFriFolding)
        query_index = challenger.sample_bits(log_global_max_height)
        # Compute reduced openings from the input proof
        reduced_openings = _open_input(
            fri_params,
            log_global_max_height,
            query_index,
            query_proof.input_proof,
            alpha,
            rounds,
        )
        # Verify FRI query: fold chain + Merkle proofs
        folded_eval, domain_index = _verify_query(
            fri_params,
            query_index,  # extra_query_index_bits = 0, so no shifting needed
            betas,
            fri_proof.commit_phase_commits,
            query_proof.commit_phase_openings,
            reduced_openings,
            log_global_max_height,
            log_final_height,
        )
        # --- Step 10: Verify final polynomial evaluation ---
        # Evaluate the final polynomial at
        #   x = two_adic_generator(log_global_max_height)
        #         ^ reverse_bits_len(domain_index, log_global_max_height)
        # After the fold chain, domain_index has only log_final_height
        # significant bits, yet the bit reversal is still taken at
        # log_global_max_height width (matching the Rust verifier): the
        # reversal places those bits in the high positions of the exponent.
        # NOTE(review): `p` is presumably the BabyBear modulus
        # (BABYBEAR_PRIME); it is not bound in this excerpt -- confirm the
        # module defines it.
        x_base = pow(
            W[log_global_max_height],
            reverse_bits_len(domain_index, log_global_max_height),
            p,
        )
        # Evaluate the final polynomial at x using Horner's method
        # final_poly is in coefficient form: f(x) = c0 + c1*x + c2*x^2 + ...
        eval_result = FF4.zero()
        for coeff in reversed(fri_proof.final_poly):
            eval_result = eval_result * FF4(x_base) + FF4(coeff)
        assert folded_eval.to_list() == eval_result.to_list(), (
            f"Query {qi}: final polynomial mismatch: "
            f"folded={folded_eval.to_list()}, expected={eval_result.to_list()}"
        )
# ===========================================================================
# PCS Prover
# ===========================================================================
@dataclass
class CommittedData:
    """Prover data for a single MMCS commitment (batch of matrices).

    Reference:
        p3-merkle-tree MerkleTreeMmcs::commit
    """

    # Merkle root of the batch commitment (pcs_commit constructs this from
    # the first element of _build_mmcs_tree's return value).
    root: Digest
    # Merkle tree levels, leaves first (second element of _build_mmcs_tree).
    tree: list[list[Digest]]
    # Per-matrix LDE rows (bit-reversed): [mat][row][col]
    lde_rows: list[list[list[Fe]]]
    # Per-matrix coefficient form: [mat][col][coeff]
    coeffs: list[list[list[Fe]]]
    # Per-matrix domain
    domains: list[TwoAdicMultiplicativeCoset]
# ---------------------------------------------------------------------------
# PCS commit
# ---------------------------------------------------------------------------
# <doc-anchor id="pcs-commit">
[docs]
def pcs_commit(
    evaluations: list[tuple[TwoAdicMultiplicativeCoset, list[list[Fe]]]],
    log_blowup: int,
) -> CommittedData:
    """Commit to a batch of matrices via LDE + Merkle tree.

    For each matrix on its domain:
      1. INTT each column to coefficients (treating as subgroup evaluations).
      2. Zero-pad to LDE size (n * 2^log_blowup).
      3. Coset-shift coefficients: c'[j] = c[j] * (GENERATOR / domain.shift)^j.
      4. NTT on the extended subgroup.
      5. Bit-reverse the rows.
    Then build a single Merkle tree from the concatenated rows.

    Args:
        evaluations: List of (domain, matrix) where matrix is list of rows.
        log_blowup: Log2 of the FRI blowup factor.

    Returns:
        CommittedData with Merkle tree, LDE, and coefficients.

    Reference:
        p3-fri-0.4.1/src/two_adic_pcs.rs TwoAdicFriPcs::commit
    """
    all_lde_rows: list[list[list[Fe]]] = []
    all_coeffs: list[list[list[Fe]]] = []
    all_domains: list[TwoAdicMultiplicativeCoset] = []
    for domain, matrix in evaluations:
        n = domain.size()
        assert len(matrix) == n
        n_ext = n << log_blowup
        num_cols = len(matrix[0]) if matrix else 0
        # LDE shift = GENERATOR / domain.shift
        # NOTE(review): `p` is presumably the BabyBear modulus
        # (BABYBEAR_PRIME); it is not bound in this excerpt -- confirm the
        # module defines it.
        lde_shift = (GENERATOR * inv_mod(domain.shift)) % p
        # Extract columns (the NTT FFI helpers below take column-major data).
        columns = [[matrix[r][c] for r in range(n)] for c in range(num_cols)]
        # Compute coefficients via batch INTT (Rust FFI, parallelized)
        coeffs_columns = _intt_batch_ffi(columns) if columns else []
        coeffs_per_col = [list(c) for c in coeffs_columns]
        # Coset LDE via Rust FFI (INTT + pad + shift + NTT, all parallelized)
        # Note: coset_lde_batch does its own INTT internally, so we pass evals
        lde_columns = list(_coset_lde_batch_ffi(columns, lde_shift, log_blowup)) if columns else []
        # Transpose to rows and bit-reverse
        lde_matrix = [
            [lde_columns[c][r] for c in range(num_cols)]
            for r in range(n_ext)
        ]
        lde_matrix = bit_reverse_list(lde_matrix)
        all_lde_rows.append(lde_matrix)
        all_coeffs.append(coeffs_per_col)
        all_domains.append(domain)
    # Build Merkle tree using Plonky3's multi-height MMCS protocol.
    # Tallest matrices go at the leaf level; shorter matrices are "injected"
    # at higher tree levels via compress(node, hash(shorter_rows)).
    # This matches _verify_batch_opening (p3-merkle-tree mmcs.rs verify_batch).
    if not all_lde_rows:
        # Degenerate empty batch: all-zero root, empty tree.
        return CommittedData(
            root=[0] * 8, tree=[], lde_rows=[], coeffs=[], domains=[]
        )
    root, tree = _build_mmcs_tree(all_lde_rows)
    return CommittedData(
        root=root,
        tree=tree,
        lde_rows=all_lde_rows,
        coeffs=all_coeffs,
        domains=all_domains,
    )
# <doc-anchor id="build-mmcs-tree">
def _build_mmcs_tree(
    all_lde_rows: list[list[list[Fe]]],
) -> tuple[Digest, list[list[Digest]]]:
    """Build multi-height MMCS Merkle tree matching Plonky3.

    Tallest matrices are hashed at the leaf level. Shorter matrices are
    "injected" at higher levels by compressing their row hashes with tree
    nodes. This mirrors the verify_batch logic in _verify_batch_opening.

    Args:
        all_lde_rows: Per-matrix bit-reversed LDE rows, [mat][row][col].

    Returns:
        (root digest, list of tree levels from leaves up to the root).

    Reference:
        p3-merkle-tree-0.4.1/src/mmcs.rs MerkleTreeMmcs::commit
    """
    # Group matrices by padded height, sorted tallest first.
    heights = [(i, len(rows)) for i, rows in enumerate(all_lde_rows)]
    sorted_entries = sorted(heights, key=lambda x: -x[1])
    max_height = sorted_entries[0][1]
    curr_height_padded = _next_power_of_two(max_height)
    # Collect tallest group (all matrices whose padded height == curr_height_padded)
    entry_ptr = 0
    tallest_group: list[int] = []  # matrix indices
    while (entry_ptr < len(sorted_entries)
           and _next_power_of_two(sorted_entries[entry_ptr][1]) == curr_height_padded):
        tallest_group.append(sorted_entries[entry_ptr][0])
        entry_ptr += 1
    # Hash leaf rows from tallest matrices only (batch parallel)
    leaf_inputs: list[list[Fe]] = []
    for row_idx in range(max_height):
        leaf_data: list[Fe] = []
        for mat_idx in tallest_group:
            leaf_data.extend(all_lde_rows[mat_idx][row_idx])
        leaf_inputs.append(leaf_data)
    # Pad to power-of-two if needed (empty rows hash to the padding digest)
    pad_count = curr_height_padded - max_height
    for _ in range(pad_count):
        leaf_inputs.append([])
    leaf_digests: list[Digest] = hash_batch(leaf_inputs)
    # Build tree bottom-up, injecting shorter matrices at appropriate levels
    tree_levels: list[list[Digest]] = [leaf_digests]
    current_level = leaf_digests
    level_height = curr_height_padded
    while level_height > 1:
        # Pair siblings to form parent level (batch parallel)
        lefts = current_level[0::2]
        rights = current_level[1::2]
        next_level: list[Digest] = compress_batch(lefts, rights)
        level_height >>= 1
        # Check if shorter matrices should be injected at this level
        if (entry_ptr < len(sorted_entries)
                and _next_power_of_two(sorted_entries[entry_ptr][1]) == level_height):
            inject_height = sorted_entries[entry_ptr][1]
            inject_group: list[int] = []  # matrix indices
            while (entry_ptr < len(sorted_entries)
                   and sorted_entries[entry_ptr][1] == inject_height):
                inject_group.append(sorted_entries[entry_ptr][0])
                entry_ptr += 1
            # Batch: collect all inject data, hash in parallel, then compress
            inject_inputs: list[list[Fe]] = []
            inject_positions: list[int] = []
            for pos in range(len(next_level)):
                # NOTE(review): for pos >= inject_height this wraps modulo the
                # matrix height; when heights are powers of two (the LDE case)
                # inject_height == level_height so the wrap never fires --
                # confirm if non-power-of-two heights can be committed.
                row_idx = pos if pos < inject_height else pos % inject_height
                inject_data: list[Fe] = []
                for mat_idx in inject_group:
                    if row_idx < len(all_lde_rows[mat_idx]):
                        inject_data.extend(all_lde_rows[mat_idx][row_idx])
                if inject_data:
                    inject_inputs.append(inject_data)
                    inject_positions.append(pos)
            if inject_inputs:
                inject_digests = hash_batch(inject_inputs)
                # Batch compress: node with inject_digest
                node_lefts = [next_level[pos] for pos in inject_positions]
                inject_results = compress_batch(node_lefts, inject_digests)
                for pos, result in zip(inject_positions, inject_results):
                    next_level[pos] = result
        current_level = next_level
        tree_levels.append(current_level)
    root = list(current_level[0]) if current_level else hash_to_digest([])
    return root, tree_levels
# ---------------------------------------------------------------------------
# Polynomial evaluation from coefficients
# ---------------------------------------------------------------------------
def _eval_poly_ef4(
    coeffs: list[Fe],
    point: list[int],
    domain_shift: Fe,
) -> list[int]:
    """Evaluate one column polynomial at an extension field point.

    The coefficients come from INTT on the domain subgroup, so the INTT
    polynomial is f(x) = p(domain_shift * x).  Hence the original polynomial
    satisfies p(z) = f(z / domain_shift), which Horner's rule computes below.

    Args:
        coeffs: INTT coefficients of one column.
        point: Extension field evaluation point.
        domain_shift: Domain's coset shift.

    Returns:
        p(point) as list[int].

    Reference:
        Horner's method over BinomialExtensionField
    """
    z = FF4(point)
    # Undo the coset shift unless the domain is the plain subgroup.
    eval_point = z if domain_shift == 1 else z * FF4(inv_mod(domain_shift))
    acc = FF4.zero()
    for coeff in reversed(coeffs):
        acc = acc * eval_point + FF4(coeff)
    return acc.to_list()
def _compute_x_array(log_height: int, x_arrays: dict[int, FF]) -> FF:
    """Compute and cache bit-reversed domain points x_i for a given log_height.

    Args:
        log_height: Log2 of the LDE height.
        x_arrays: Mutable cache mapping log_height to precomputed x arrays.

    Returns:
        Array of x_i = GENERATOR * omega^{bit_reverse(i)} for i in
        [0, 2^log_height) -- the LDE coset points in bit-reversed row order.
    """
    if log_height not in x_arrays:
        height = 1 << log_height
        omega = get_omega(log_height)
        # Use the imported BabyBear modulus explicitly instead of the bare
        # name `p`, which is not defined in this module's visible scope.
        x_arr = FF([
            (GENERATOR * pow(omega, reverse_bits_len(i, log_height), BABYBEAR_PRIME))
            % BABYBEAR_PRIME
            for i in range(height)
        ])
        x_arrays[log_height] = x_arr
    return x_arrays[log_height]
def _compute_inv_diff(
    log_height: int,
    point: list[int],
    x_arrays: dict[int, FF],
    inv_diff_cache: dict[tuple, FF4],
) -> FF4:
    """Compute and cache (z - x_i)^{-1} for all domain points x_i.

    Args:
        log_height: Log2 of the LDE height.
        point: Extension field evaluation point z.
        x_arrays: Mutable cache for precomputed x arrays (passed to _compute_x_array).
        inv_diff_cache: Mutable cache mapping (log_height, point) to inverse differences.

    Returns:
        FF4 column of (z - x_i)^{-1} for each domain point x_i.
    """
    cache_key = (log_height, tuple(point))
    cached = inv_diff_cache.get(cache_key)
    if cached is None:
        n = 1 << log_height
        xs = _compute_x_array(log_height, x_arrays)
        z_column = FF4.broadcast(FF4(point), n)
        cached = (z_column - FF4.from_base(xs)).inv()
        inv_diff_cache[cache_key] = cached
    return cached
# ---------------------------------------------------------------------------
# PCS open (prover side)
# ---------------------------------------------------------------------------
@dataclass
class PcsOpeningRound:
    """One round of PCS opening data for the prover.

    Reference:
        p3-fri-0.4.1/src/two_adic_pcs.rs TwoAdicFriPcs::open
    """

    # The committed batch (Merkle tree, LDE rows, coefficients, domains).
    committed: CommittedData
    # Per-matrix list of opening points; each point is an extension-field
    # element given as a list of base-field limbs.
    points_per_mat: list[list[list[int]]]
# <doc-anchor id="pcs-open">
[docs]
def pcs_open(
    rounds: list[PcsOpeningRound],
    challenger: Challenger,
    fri_params: FriParameters,
) -> tuple[list[list[list[list[list[int]]]]], dict, list[int]]:
    """Open committed polynomials at specified points.

    Prover-side PCS open: evaluates polynomials, computes reduced
    polynomials, runs FRI, and generates Merkle openings.

    Args:
        rounds: Per-commitment opening data.
        challenger: The Fiat-Shamir challenger (quotient already observed).
        fri_params: FRI protocol parameters.

    Returns:
        (all_opened_values, fri_proof_data, query_indices) where:
        - all_opened_values[round][mat][point] = list[list[int]]
        - fri_proof_data contains the FRI proof components
        - query_indices for Merkle opening generation

    Reference:
        p3-fri-0.4.1/src/two_adic_pcs.rs TwoAdicFriPcs::open lines 286-440
    """
    # --- Step A: Evaluate polynomials at opening points ---
    all_opened_values: list[list[list[list[list[int]]]]] = []
    for rnd_idx, rnd in enumerate(rounds):
        round_values: list[list[list[list[int]]]] = []
        for mat_idx in range(len(rnd.committed.domains)):
            domain = rnd.committed.domains[mat_idx]
            coeffs_per_col = rnd.committed.coeffs[mat_idx]
            points = rnd.points_per_mat[mat_idx]
            degree = len(coeffs_per_col[0]) if coeffs_per_col else 0
            mat_values: list[list[list[int]]] = []
            for point in points:
                # Compute eval_point = point / domain_shift (the batch FFI
                # evaluator takes the pre-shifted point; the scalar fallback
                # applies the shift itself in _eval_poly_ef4).
                if domain.shift == 1:
                    eval_pt = list(point)
                else:
                    eval_pt = (
                        FF4(point) * FF4(inv_mod(domain.shift))
                    ).to_list()
                # Batched FFI evaluation for larger polynomials (threshold 256).
                if degree >= 256 and len(coeffs_per_col) > 0:
                    col_values = eval_poly_ef4_batch(
                        coeffs_per_col, eval_pt,
                    )
                else:
                    col_values = [
                        _eval_poly_ef4(col_coeffs, point, domain.shift)
                        for col_coeffs in coeffs_per_col
                    ]
                mat_values.append(col_values)
            round_values.append(mat_values)
        all_opened_values.append(round_values)
    # --- Step B: Observe all opened values ---
    for round_values in all_opened_values:
        for mat_values in round_values:
            for point_values in mat_values:
                for val in point_values:
                    challenger.observe_many(val)
    # --- Step C: Sample FRI alpha ---
    alpha = challenger.sample_ext()
    # --- Step D: Compute reduced polynomials per height ---
    # Group by LDE height, accumulate alpha-weighted quotients.
    reduced_evals_np: dict[int, FF4] = {}  # log_height → FF4 column
    # Cache (z - x_i)^{-1} per (log_height, z_tuple) to avoid recomputation
    inv_diff_cache: dict[tuple, FF4] = {}
    # Pre-compute bit-reversed domain points x_i per log_height
    x_arrays: dict[int, FF] = {}
    # <doc-anchor id="per-height-alpha">
    # Per-height alpha_pow accumulators (matching Rust's num_reduced[log_height]).
    # Each height independently tracks alpha^k for its k-th column.
    # Reference: p3-fri two_adic_pcs.rs lines 226,253,271
    alpha_pow_per_height: dict[int, FF4] = {}
    for rnd_idx, rnd in enumerate(rounds):
        for mat_idx in range(len(rnd.committed.domains)):
            domain = rnd.committed.domains[mat_idx]
            lde_rows = rnd.committed.lde_rows[mat_idx]
            points = rnd.points_per_mat[mat_idx]
            log_height = domain.log_n + fri_params.log_blowup
            height = 1 << log_height
            if log_height not in reduced_evals_np:
                reduced_evals_np[log_height] = FF4.zeros(height)
            if log_height not in alpha_pow_per_height:
                alpha_pow_per_height[log_height] = FF4(1)
            num_cols = len(lde_rows[0]) if lde_rows else 0
            mat_opened_values = all_opened_values[rnd_idx][mat_idx]
            # Pre-extract LDE columns as field arrays for this matrix
            lde_cols_np = [
                FF([lde_rows[i][c] for i in range(height)])
                for c in range(num_cols)
            ]
            for pt_idx, point in enumerate(points):
                inv_diff = _compute_inv_diff(log_height, point, x_arrays, inv_diff_cache)
                point_values = mat_opened_values[pt_idx]
                for col_idx in range(num_cols):
                    # p_at_z is scalar FF4, p_at_x is base field array
                    p_at_z_v = FF4.broadcast(FF4(point_values[col_idx]), height)
                    p_at_x_v = FF4.from_base(lde_cols_np[col_idx])
                    # (p_at_z - p_at_x) * inv_diff
                    quotient = (p_at_z_v - p_at_x_v) * inv_diff
                    # alpha_pow * quotient (per-height alpha_pow)
                    scaled = quotient * alpha_pow_per_height[log_height]
                    # Accumulate
                    reduced_evals_np[log_height] = reduced_evals_np[log_height] + scaled
                    # Advance this height's alpha_pow (one power per column,
                    # mirroring the verifier's _open_input)
                    alpha_pow_per_height[log_height] = (
                        alpha_pow_per_height[log_height] * alpha
                    )
    # Convert FF4 columns back to list-of-list[int] for FRI
    reduced_evals: dict[int, list[list[int]]] = {}
    for log_h, ev in reduced_evals_np.items():
        reduced_evals[log_h] = ev.to_rows()
    # --- Step E: FRI prove ---
    # Collect reduced evaluations in descending height order
    sorted_heights = sorted(reduced_evals.keys(), reverse=True)
    log_global_max_height = sorted_heights[0] if sorted_heights else 0
    # Tallest height goes directly to FRI; shorter heights are rolled in
    reduced_br = reduced_evals[sorted_heights[0]]
    reduced_openings_by_height = None
    if len(sorted_heights) > 1:
        reduced_openings_by_height = {
            log_h: reduced_evals[log_h]
            for log_h in sorted_heights[1:]
        }
    # FRI commit phase
    fri_result = fri_commit_phase(
        reduced_br,
        fri_params.log_blowup,
        fri_params.log_final_poly_len,
        challenger,
        fri_params.commit_proof_of_work_bits,
        reduced_openings_by_height=reduced_openings_by_height,
    )
    # Query PoW
    query_pow_witness = grind(challenger, fri_params.query_proof_of_work_bits)
    # --- Step F: Query phase ---
    num_fri_rounds = len(fri_result.commits)
    query_indices: list[int] = []
    fri_query_proofs: list[tuple[int, list]] = []
    for _ in range(fri_params.num_queries):
        query_index = challenger.sample_bits(log_global_max_height)
        query_indices.append(query_index)
        # FRI commit phase openings
        fri_openings = answer_query(
            fri_result.trees,
            fri_result.all_round_evals,
            query_index,
            num_fri_rounds,
        )
        fri_query_proofs.append((query_index, fri_openings))
    # Package FRI proof data
    fri_proof_data = {
        "commit_phase_commits": fri_result.commits,
        "final_poly": fri_result.final_poly,
        "commit_pow_witnesses": fri_result.commit_pow_witnesses,
        "query_pow_witness": query_pow_witness,
        "fri_query_proofs": fri_query_proofs,
    }
    return all_opened_values, fri_proof_data, query_indices
# ---------------------------------------------------------------------------
# Batch opening generation for query proofs
# ---------------------------------------------------------------------------
def _generate_batch_opening(
    committed: CommittedData,
    query_index: int,
    log_global_max_height: int,
) -> BatchOpening:
    """Generate a BatchOpening for a single query index.

    Pulls the opened row of every matrix in the batch (each addressed by the
    query index reduced to that matrix's height) plus the Merkle proof at the
    batch's tallest height.

    Args:
        committed: The committed batch data.
        query_index: The global query index.
        log_global_max_height: Log2 of the global max LDE height.

    Returns:
        BatchOpening with row values and Merkle proof.

    Reference:
        p3-merkle-tree mmcs.rs MerkleTreeMmcs::open_batch
    """
    mat_heights = [len(rows) for rows in committed.lde_rows]
    tallest = max(mat_heights)
    log_tallest = tallest.bit_length() - 1
    # Index at this batch's tallest height.
    batch_index = query_index >> (log_global_max_height - log_tallest)
    rows_out: list[list[Fe]] = []
    for rows, height in zip(committed.lde_rows, mat_heights):
        if height == tallest:
            row_index = batch_index
        else:
            # Shorter matrices are addressed by the high bits of the index.
            log_mat = height.bit_length() - 1
            row_index = batch_index >> (log_tallest - log_mat)
        rows_out.append(list(rows[row_index]))
    # Merkle proof at the reduced (tallest-height) index.
    return BatchOpening(
        opened_values=rows_out,
        opening_proof=get_opening_proof(committed.tree, batch_index),
    )