"""BabyBear field GF(p) and quartic extension GF(p^4).
Uses galois library for base field GF(p) arithmetic.
Uses custom FF4 class for fast quartic extension field arithmetic.
Type Discipline
---------------
FF (galois GF(p)):
Base field columns and scalars. Fast jit-compiled numpy operations.
FF4:
Extension field GF(p^4) scalars and columns. Stores coefficients as int64 arrays.
Supports operator overloads: +, -, *, /, **, negation.
Scalars: shape (4,). Columns: shape (N, 4).
Coefficients in ascending degree: c0 + c1*x + c2*x^2 + c3*x^3 as [c0, c1, c2, c3].
"""
import pickle
from pathlib import Path
import galois
import numpy as np
# --- Field Construction ---
# BabyBear prime: p = 2^31 - 2^27 + 1 = 15 * 2^27 + 1 = 2013265921.
BABYBEAR_PRIME = (1 << 31) - (1 << 27) + 1
FIELD_EXTENSION_DEGREE = 4  # GF(p^4) quartic extension
DIGEST_WIDTH = 8  # Poseidon2 digest width in BabyBear elements
TWO_ADICITY = 27  # largest k with 2^k | p - 1, since p - 1 = 2^27 * 15
GENERATOR = 31  # Multiplicative generator (Val::GENERATOR in Plonky3)
# 2 * (p + 1) / 2 = p + 1 = 1 (mod p), so (p + 1) // 2 is the inverse of 2.
TWO_INV = (BABYBEAR_PRIME + 1) // 2
# Montgomery radix R = 2^32 reduced mod p (= 268435454) and its inverse.
MONTY_R = (1 << 32) % BABYBEAR_PRIME
MONTY_RINV = pow(MONTY_R, -1, BABYBEAR_PRIME)
FF = galois.GF(BABYBEAR_PRIME)
"""Base field GF(p) - BabyBear prime field."""

# Load galois FF4 from cache (used only for test backward compat and JSON roundtrip).
# Building a galois extension field is comparatively slow, so the field object is
# pickled next to this module. The cache is best-effort. A missing, truncated, or
# stale cache falls back to an in-memory build. A failed write (e.g. a read-only
# install directory) is ignored instead of failing the import.
# NOTE(review): pickle.load executes arbitrary code from the cache file. That is
# acceptable only because this module writes the file inside its own package
# directory. Never point _FF4_CACHE_PATH at untrusted data.
_FF4_CACHE_PATH = Path(__file__).parent / "ff4_cache.pkl"


def _build_ff4():
    """Construct the quartic extension field GF(p^4) with irreducible polynomial x^4 - 11."""
    # Coefficients are listed highest degree first: x^4 + (p - 11) == x^4 - 11 over GF(p).
    irr = galois.Poly([1, 0, 0, 0, BABYBEAR_PRIME - 11], field=FF)
    return galois.GF(BABYBEAR_PRIME, 4, irreducible_poly=irr)


def _regenerate_ff4_cache(ff4_field=None) -> None:
    """Regenerate the FF4 cache file. Only needed if galois version changes.

    Args:
        ff4_field: Field class to store. When None, it is built with
            ``_build_ff4()``. Callers that already hold the field pass it in
            to avoid a second, slow construction.

    The pickle is written to a temporary file in the cache directory and then
    renamed over the cache path. An interrupted or concurrent write therefore
    never leaves a truncated cache behind.

    Raises:
        OSError: if the cache directory is not writable.
        Any exception ``pickle.dump`` raises if the field cannot be pickled.
    """
    import tempfile

    if ff4_field is None:
        ff4_field = _build_ff4()
    tmp = tempfile.NamedTemporaryFile(
        "wb",
        dir=_FF4_CACHE_PATH.parent,
        prefix=_FF4_CACHE_PATH.name,
        suffix=".tmp",
        delete=False,
    )
    tmp_path = Path(tmp.name)
    try:
        with tmp:
            pickle.dump(ff4_field, tmp)
        # Path.replace is an atomic rename on the same filesystem.
        tmp_path.replace(_FF4_CACHE_PATH)
    except BaseException:
        # Never leave a half-written temporary file next to the module.
        tmp_path.unlink(missing_ok=True)
        raise


def _load_ff4():
    """Return the galois GF(p^4) field class, preferring the on-disk cache.

    Falls back to building the field when the cache is missing or unreadable,
    then tries (best-effort) to refresh the cache.
    """
    try:
        with open(_FF4_CACHE_PATH, "rb") as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
            ImportError, TypeError, ValueError):
        # Missing file, truncated pickle, or a cache written by an incompatible
        # galois version: rebuild below instead of failing the import.
        pass
    ff4_field = _build_ff4()
    try:
        _regenerate_ff4_cache(ff4_field)
    except (OSError, pickle.PicklingError, TypeError, AttributeError):
        # Read-only install or unpicklable field: keep the in-memory field.
        pass
    return ff4_field


_GaloisFF4 = _load_ff4()
# --- Internal constants ---
# Short aliases used throughout the FF4 arithmetic below.
_P = BABYBEAR_PRIME
_P2 = _P ** 2  # p^2: place value of c2 in galois' integer encoding
_P3 = _P ** 3  # p^3: place value of c3 in galois' integer encoding
_W_EXT = 11  # Extension polynomial constant: x^4 = 11
# --- FF4: Fast Extension Field GF(p^4) ---
def _modpow_arr(base: np.ndarray, exp: int) -> np.ndarray:
    """Vectorized modular exponentiation: ``base ** exp mod p`` element-wise.

    Args:
        base: Integer array of any shape and any integer dtype. Negative or
            unreduced entries are reduced mod p first.
        exp: Non-negative integer exponent.

    Returns:
        int64 array with the shape of ``base`` and entries in [0, p).
        ``exp == 0`` yields all ones. Zero entries stay zero for ``exp > 0``.

    Raises:
        ValueError: if ``exp`` is negative. The square-and-multiply loop would
            otherwise return all ones silently.
    """
    if exp < 0:
        raise ValueError(f"exponent must be non-negative, got {exp}")
    # Promote to int64 before squaring. Reduced values are < p < 2^31, so each
    # product is < 2^62 and fits. A narrower dtype (e.g. int32) would overflow.
    base = np.asarray(base, dtype=np.int64) % _P
    result = np.ones_like(base)
    while exp > 0:
        if exp & 1:
            result = (result * base) % _P
        base = (base * base) % _P
        exp >>= 1
    return result
class FF4:
    """Extension field GF(p^4) = GF(p)[x]/(x^4 - 11) element or column.

    Internal storage: int64 ndarray of shape (4,) for a scalar or (N, 4) for a
    column. Every public constructor reduces entries into [0, p).
    Coefficients in ascending degree: c0 + c1*x + c2*x^2 + c3*x^3 stored as
    [c0, c1, c2, c3].

    Supports operator overloads for clean mathematical notation:
        c = a + b   # addition
        c = a - b   # subtraction
        c = a * b   # polynomial multiplication (also FF4 * int)
        c = a / b   # division (multiply by inverse)
        c = a ** n  # exponentiation (n can be -1)
        c = -a      # negation
    """

    __slots__ = ('_d',)

    def __init__(self, data=None):
        """Build an FF4 scalar or column.

        Args:
            data: One of
                - FF4: copied.
                - int / np.integer: embedded as (val mod p, 0, 0, 0).
                - list / tuple / np.ndarray of shape (4,) or (N, 4):
                  coefficients, reduced mod p.
                - galois FieldArray (anything with ``.vector()``): GF(p^4)
                  values are converted coefficient-wise. Prime-field GF(p)
                  values are embedded as (val, 0, 0, 0).
                - None: the zero scalar.

        Raises:
            ValueError: if an array-like input has the wrong shape.
            TypeError: for any other input type.
        """
        if isinstance(data, FF4):
            self._d = data._d.copy()
        elif isinstance(data, int):
            # Embed base field element as (val, 0, 0, 0)
            self._d = np.array([data % _P, 0, 0, 0], dtype=np.int64)
        elif isinstance(data, np.integer):
            self._d = np.array([int(data) % _P, 0, 0, 0], dtype=np.int64)
        elif isinstance(data, (list, tuple)):
            arr = np.asarray(data, dtype=np.int64)
            if arr.ndim == 1 and arr.shape[0] == 4:
                self._d = arr % _P
            elif arr.ndim == 2 and arr.shape[1] == 4:
                self._d = arr % _P
            elif arr.ndim == 1 and arr.shape[0] != 4:
                raise ValueError(f"FF4 1D array must have 4 elements, got {arr.shape[0]}")
            else:
                raise ValueError(f"FF4 array must be shape (4,) or (N,4), got {arr.shape}")
        elif hasattr(data, 'vector'):
            # galois FieldArray. Must be checked before np.ndarray because
            # FieldArray subclasses it.
            self._d = FF4._coeffs_from_galois(data)
        elif isinstance(data, np.ndarray):
            if data.ndim == 1 and data.shape[0] == 4:
                self._d = data.astype(np.int64) % _P
            elif data.ndim == 2 and data.shape[1] == 4:
                self._d = data.astype(np.int64) % _P
            else:
                raise ValueError(f"FF4 array must be shape (4,) or (N,4), got {data.shape}")
        elif data is None:
            self._d = np.array([0, 0, 0, 0], dtype=np.int64)
        else:
            raise TypeError(f"Cannot create FF4 from {type(data)}")

    @staticmethod
    def _coeffs_from_galois(data) -> np.ndarray:
        """Convert a galois FieldArray to ascending-degree int64 coefficients.

        galois' ``vector()`` lists prime-subfield coefficients highest degree
        first, with one trailing entry per coefficient: 4 for GF(p^4), 1 for
        GF(p). We store lowest degree first.

        Raises:
            ValueError: for any other vector shape.
        """
        v = np.array(data.vector().view(np.ndarray), dtype=np.int64)
        if v.ndim not in (1, 2):
            raise ValueError(f"Unexpected galois vector shape: {v.shape}")
        if v.shape[-1] == 4:
            return v[..., ::-1] % _P
        if v.shape[-1] == 1:
            # Prime-field element(s): embed as (val, 0, 0, 0), like FF4(int).
            # Previously these produced a malformed (1,) / (N, 1) array.
            out = np.zeros(v.shape[:-1] + (4,), dtype=np.int64)
            out[..., 0] = v[..., 0] % _P
            return out
        raise ValueError(f"Unexpected galois vector shape: {v.shape}")

    @classmethod
    def _wrap(cls, data: np.ndarray) -> 'FF4':
        """Wrap raw int64 array without copying or mod. Internal use only."""
        obj = object.__new__(cls)
        obj._d = data
        return obj

    # --- Constructors ---
    @classmethod
    def zeros(cls, n: int) -> 'FF4':
        """Zero column of length n."""
        return cls._wrap(np.zeros((n, 4), dtype=np.int64))

    @classmethod
    def one(cls) -> 'FF4':
        """Multiplicative identity scalar."""
        return cls._wrap(np.array([1, 0, 0, 0], dtype=np.int64))

    @classmethod
    def zero(cls) -> 'FF4':
        """Additive identity scalar."""
        return cls._wrap(np.array([0, 0, 0, 0], dtype=np.int64))

    @classmethod
    def from_base(cls, base) -> 'FF4':
        """Lift base field values to FF4 as (val, 0, 0, 0).

        Args:
            base: int, FF array, list[int], or ndarray of base field elements.
                An int gives a scalar. Arrays and lists are flattened into a
                column.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(base, (int, np.integer)):
            return cls._wrap(np.array([int(base) % _P, 0, 0, 0], dtype=np.int64))
        if isinstance(base, np.ndarray):
            vals = base.view(np.ndarray).ravel().astype(np.int64) % _P
            d = np.zeros((len(vals), 4), dtype=np.int64)
            d[:, 0] = vals
            return cls._wrap(d)
        if isinstance(base, list):
            vals = np.array(base, dtype=np.int64) % _P
            d = np.zeros((len(vals), 4), dtype=np.int64)
            d[:, 0] = vals
            return cls._wrap(d)
        raise TypeError(f"Cannot create FF4.from_base from {type(base)}")

    @classmethod
    def from_rows(cls, rows: list) -> 'FF4':
        """Create a column from a list of [c0, c1, c2, c3] coefficient lists.

        An empty list yields an empty (0, 4) column. ``np.array([])`` has shape
        (0,), which the constructor would otherwise reject.
        """
        if len(rows) == 0:
            return cls.zeros(0)
        return cls(np.array(rows, dtype=np.int64))

    @classmethod
    def broadcast(cls, scalar: 'FF4', n: int) -> 'FF4':
        """Broadcast a scalar FF4 to a column of length n.

        A column argument contributes its first row. A list/tuple is taken as
        raw coefficients and reduced mod p.
        """
        if isinstance(scalar, cls):
            coeffs = scalar._d if scalar._d.ndim == 1 else scalar._d[0]
        elif isinstance(scalar, (list, tuple)):
            coeffs = np.array(scalar, dtype=np.int64) % _P
        else:
            raise TypeError(f"Cannot broadcast {type(scalar)}")
        return cls._wrap(np.tile(coeffs, (n, 1)))

    # --- Properties ---
    @property
    def is_scalar(self) -> bool:
        """True for a single element (shape (4,)), False for a column."""
        return self._d.ndim == 1

    @property
    def coeffs(self) -> tuple:
        """Return coefficients as tuple (c0, c1, c2, c3) for a scalar."""
        if self._d.ndim == 1:
            return tuple(int(x) for x in self._d)
        raise ValueError("coeffs only available for scalars, use to_rows() for columns")

    @property
    def c0(self) -> int:
        """Constant coefficient (base field component)."""
        if self._d.ndim == 1:
            return int(self._d[0])
        raise ValueError("c0 only available for scalars")

    # --- Access ---
    def __len__(self) -> int:
        # A scalar counts as length 1; a column has one entry per row.
        if self._d.ndim == 1:
            return 1
        return self._d.shape[0]

    def __getitem__(self, key):
        """Scalar: return coefficient ``key`` as int. Column: return row(s) as FF4."""
        if self._d.ndim == 1:
            if isinstance(key, (int, np.integer)):
                # Return individual coefficient
                return int(self._d[key])
            raise IndexError(f"Invalid index for scalar FF4: {key}")
        result = self._d[key]
        if result.ndim == 1:
            # Single row: copy so the returned scalar does not alias the column.
            return FF4._wrap(result.copy())
        # Slices stay numpy views of the underlying column.
        return FF4._wrap(result)

    def __setitem__(self, key, value):
        """Assign FF4 value(s) into rows of a column (in place)."""
        if self._d.ndim == 1:
            raise IndexError("Cannot set items on scalar FF4")
        if isinstance(value, FF4):
            self._d[key] = value._d
        else:
            raise TypeError(f"Can only assign FF4, got {type(value)}")

    def to_rows(self) -> list:
        """Convert to list of [c0, c1, c2, c3] coefficient lists."""
        if self._d.ndim == 1:
            return [self._d.tolist()]
        return self._d.tolist()

    def to_list(self) -> list:
        """Convert scalar to [c0, c1, c2, c3] list (backward compat with FF4Coeffs)."""
        if self._d.ndim == 1:
            return self._d.tolist()
        raise ValueError("to_list() only for scalars")

    # --- Arithmetic ---
    def __add__(self, other):
        if isinstance(other, FF4):
            return FF4._wrap((self._d + other._d) % _P)
        return NotImplemented

    def __radd__(self, other):
        if other == 0:  # Support sum(), which starts from int 0
            return self
        if isinstance(other, FF4):
            return other.__add__(self)
        return NotImplemented

    def __sub__(self, other):
        if isinstance(other, FF4):
            return FF4._wrap((self._d - other._d) % _P)
        return NotImplemented

    def __rsub__(self, other):
        if isinstance(other, FF4):
            return other.__sub__(self)
        return NotImplemented

    def __neg__(self):
        # p - 0 would be p (not reduced), so zero entries are kept as-is.
        d = self._d
        result = np.where(d == 0, d, _P - d)
        return FF4._wrap(result)

    def __mul__(self, other):
        if isinstance(other, FF4):
            return _ef4_mul(self._d, other._d)
        if isinstance(other, (int, np.integer)):
            # Base-field scalar: reduced factor < 2^31, so products fit int64.
            return FF4._wrap((self._d * (int(other) % _P)) % _P)
        return NotImplemented

    def __rmul__(self, other):
        if isinstance(other, (int, np.integer)):
            return FF4._wrap((self._d * (int(other) % _P)) % _P)
        return NotImplemented

    def __truediv__(self, other):
        if isinstance(other, FF4):
            return self * other.inv()
        return NotImplemented

    def __pow__(self, exp):
        """Integer power by square-and-multiply. Negative exponents invert first."""
        if exp == -1:
            return self.inv()
        if exp == 0:
            if self._d.ndim == 1:
                return FF4.one()
            return FF4._wrap(np.tile(np.array([1, 0, 0, 0], dtype=np.int64),
                                     (self._d.shape[0], 1)))
        if exp == 1:
            return FF4._wrap(self._d.copy())
        if exp < 0:
            return self.inv() ** (-exp)
        # Square-and-multiply
        result = self ** 0  # identity with matching shape
        base = FF4._wrap(self._d.copy())
        n = exp
        while n > 0:
            if n & 1:
                result = result * base
            base = base * base
            n >>= 1
        return result

    def inv(self) -> 'FF4':
        """Multiplicative inverse via tower decomposition (zero maps to zero)."""
        return _ef4_inv(self._d)

    def mul_base(self, base_vals) -> 'FF4':
        """Multiply each element by base field values (coefficient-wise scaling).

        More efficient than a * FF4.from_base(b) because it avoids
        the full polynomial multiply (only 4 multiplications vs 16).
        For a scalar FF4 with array/list input, only the first value is used.
        """
        if isinstance(base_vals, np.ndarray):
            b = base_vals.view(np.ndarray).ravel().astype(np.int64) % _P
            if self._d.ndim == 2:
                return FF4._wrap((self._d * b[:, np.newaxis]) % _P)
            return FF4._wrap((self._d * int(b[0])) % _P)
        if isinstance(base_vals, (int, np.integer)):
            return FF4._wrap((self._d * (int(base_vals) % _P)) % _P)
        if isinstance(base_vals, list):
            b = np.array(base_vals, dtype=np.int64) % _P
            if self._d.ndim == 2:
                return FF4._wrap((self._d * b[:, np.newaxis]) % _P)
            return FF4._wrap((self._d * int(b[0])) % _P)
        raise TypeError(f"Cannot mul_base with {type(base_vals)}")

    # --- Column operations ---
    def roll(self, shift: int) -> 'FF4':
        """Circular shift (for columns)."""
        return FF4._wrap(np.roll(self._d, shift, axis=0))

    def cumsum(self) -> 'FF4':
        """Prefix sum (for columns). Safe for height <= 2^27.

        Entries are < 2^31, so the unreduced int64 running sum cannot overflow
        at that height; a single mod is applied at the end.
        """
        d = np.cumsum(self._d, axis=0) % _P
        return FF4._wrap(d)

    # --- Comparison ---
    def __eq__(self, other):
        if isinstance(other, FF4):
            return np.array_equal(self._d, other._d)
        if isinstance(other, (list, tuple)):
            return np.array_equal(self._d, np.array(other, dtype=np.int64) % _P)
        return NotImplemented

    def __ne__(self, other):
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq

    def __hash__(self):
        # Only scalars are hashable; columns are mutable via __setitem__.
        if self._d.ndim == 1:
            return hash(tuple(self._d.tolist()))
        raise TypeError("Unhashable: column FF4")

    def __repr__(self):
        if self._d.ndim == 1:
            return f"FF4({self._d.tolist()})"
        return f"FF4(shape=({self._d.shape[0]}, 4))"

    def __bool__(self):
        raise ValueError("Truth value of FF4 is ambiguous. Use == FF4.zero() for comparison.")
def _ef4_mul(a: np.ndarray, b: np.ndarray) -> FF4:
    """Multiply two FF4 coefficient arrays in GF(p)[x]/(x^4 - W).

    Each operand is a scalar (shape (4,)) or a column (shape (N, 4)). Every
    mix of the two is accepted. A product term a_i*b_j with i + j >= 4 folds
    into coefficient i + j - 4, scaled by W = _W_EXT, because x^4 = W.
    """
    w = _W_EXT

    if a.ndim == 1 and b.ndim == 1:
        # Scalar * scalar: exact Python ints, one reduction per coefficient.
        x0, x1, x2, x3 = (int(a[k]) for k in range(4))
        y0, y1, y2, y3 = (int(b[k]) for k in range(4))
        raw = (
            x0 * y0 + w * (x1 * y3 + x2 * y2 + x3 * y1),
            x0 * y1 + x1 * y0 + w * (x2 * y3 + x3 * y2),
            x0 * y2 + x1 * y1 + x2 * y0 + w * x3 * y3,
            x0 * y3 + x1 * y2 + x2 * y1 + x3 * y0,
        )
        return FF4._wrap(np.array([r % _P for r in raw], dtype=np.int64))

    def split(arr):
        # A scalar becomes four np.int64 values that broadcast against columns.
        # A column becomes its four coefficient vectors.
        if arr.ndim == 1:
            return [np.int64(arr[k]) for k in range(4)]
        return [arr[:, k] for k in range(4)]

    x0, x1, x2, x3 = split(a)
    y0, y1, y2, y3 = split(b)

    def mm(u, v):
        # (p-1)^2 ~ 2^62, so every product is reduced before any summation.
        return (u * v) % _P

    # With every product below p, the largest sum is p + 11*3p = 34p < 2^37,
    # comfortably inside int64 before the final reduction.
    z0 = (mm(x0, y0) + w * (mm(x1, y3) + mm(x2, y2) + mm(x3, y1))) % _P
    z1 = (mm(x0, y1) + mm(x1, y0) + w * (mm(x2, y3) + mm(x3, y2))) % _P
    z2 = (mm(x0, y2) + mm(x1, y1) + mm(x2, y0) + w * mm(x3, y3)) % _P
    z3 = (mm(x0, y3) + mm(x1, y2) + mm(x2, y1) + mm(x3, y0)) % _P
    return FF4._wrap(np.column_stack([z0, z1, z2, z3]))
def _ef4_inv(d: np.ndarray) -> FF4:
    """Inverse via tower decomposition.

    Tower: GF(p^2) = GF(p)[u]/(u^2 - W) and GF(p^4) = GF(p^2)[y]/(y^2 - u),
    where y = x, u = x^2 and W = _W_EXT (from x^4 = W).
    Element a = a0 + a1*y + a2*u + a3*u*y = (a0 + a2*u) + (a1 + a3*u)*y
    = a_lo + a_hi*y.
    Multiplying by the conjugate a_lo - a_hi*y gives
        a^(-1) = (a_lo - a_hi*y) / (a_lo^2 - u*a_hi^2),
    whose denominator d = d0 + d1*u lies in GF(p^2). d is inverted through
    its GF(p) norm d0^2 - W*d1^2: d^(-1) = (d0 - d1*u) / norm.

    Args:
        d: int64 coefficient array, shape (4,) for a scalar or (N, 4) for a
            column. Entries are assumed reduced to [0, p) (the FF4 invariant);
            the column path relies on this to avoid int64 overflow.

    Returns:
        FF4 with the same shape as ``d`` holding the inverse(s).

    Note:
        There is no zero check. For a zero element the norm is 0 and
        0^(p-2) = 0, so zero is returned as its own "inverse".
    """
    W = np.int64(_W_EXT)
    if d.ndim == 1:
        # Scalar inverse using Python ints
        a0, a1, a2, a3 = int(d[0]), int(d[1]), int(d[2]), int(d[3])
        # a_lo = a0 + a2*u, a_hi = a1 + a3*u
        # a_lo^2 in GF(p^2)
        a2_0 = (a0 * a0 + _W_EXT * a2 * a2) % _P
        a2_1 = (2 * a0 * a2) % _P
        # a_hi^2 in GF(p^2)
        b2_0 = (a1 * a1 + _W_EXT * a3 * a3) % _P
        b2_1 = (2 * a1 * a3) % _P
        # u * a_hi^2: multiply by u where u^2=W
        b2u_0 = (_W_EXT * b2_1) % _P
        b2u_1 = b2_0
        # norm in GF(p^2)
        d0 = (a2_0 - b2u_0) % _P
        d1 = (a2_1 - b2u_1) % _P
        # GF(p) norm
        norm = (d0 * d0 - _W_EXT * d1 * d1) % _P
        # Fermat inverse; pow(0, p-2, p) == 0, so a zero norm yields zero.
        ni = pow(int(norm), _P - 2, _P)
        # d^(-1) in GF(p^2)
        e0 = (d0 * ni) % _P
        e1 = (_P - (d1 * ni) % _P) % _P
        # Multiply back: (a_lo - a_hi*y) * d^(-1); the y-terms carry the minus sign.
        r0 = (a0 * e0 + _W_EXT * a2 * e1) % _P
        r1 = (_P - (a1 * e0 + _W_EXT * a3 * e1) % _P) % _P
        r2 = (a0 * e1 + a2 * e0) % _P
        r3 = (_P - (a1 * e1 + a3 * e0) % _P) % _P
        return FF4._wrap(np.array([r0, r1, r2, r3], dtype=np.int64))
    # Column inverse — vectorized with intermediate mod to prevent int64 overflow.
    # W=11, so W*x*y can reach 11*(P-1)^2 ~ 5e19 > 2^63.
    a0 = d[:, 0]; a1 = d[:, 1]; a2 = d[:, 2]; a3 = d[:, 3]
    _m = lambda x, y: (x * y) % _P  # noqa: E731
    # a_lo^2 in GF(p^2): (a0 + a2*u)^2 = (a0^2 + W*a2^2) + 2*a0*a2*u
    a2_0 = (_m(a0, a0) + W * _m(a2, a2)) % _P
    a2_1 = (2 * _m(a0, a2)) % _P
    # a_hi^2 in GF(p^2): (a1 + a3*u)^2 = (a1^2 + W*a3^2) + 2*a1*a3*u
    b2_0 = (_m(a1, a1) + W * _m(a3, a3)) % _P
    b2_1 = (2 * _m(a1, a3)) % _P
    # u * a_hi^2: multiply (b2_0 + b2_1*u) by u where u^2=W
    b2u_0 = (W * b2_1) % _P
    b2u_1 = b2_0
    # norm in GF(p^2)
    d0 = (a2_0 - b2u_0) % _P
    d1 = (a2_1 - b2u_1) % _P
    # GF(p) norm: d0^2 - W*d1^2
    norm = (_m(d0, d0) - W * _m(d1, d1)) % _P
    # Element-wise Fermat inverse of the norms (zero norms stay zero).
    ni = _modpow_arr(norm, _P - 2)
    # d^(-1) in GF(p^2): (d0*ni, -d1*ni)
    e0 = _m(d0, ni)
    e1 = (_P - _m(d1, ni)) % _P
    # Multiply back: result = conj(a) * d^(-1)
    r0 = (_m(a0, e0) + W * _m(a2, e1)) % _P
    r1 = (_P - (_m(a1, e0) + W * _m(a3, e1)) % _P) % _P
    r2 = (_m(a0, e1) + _m(a2, e0)) % _P
    r3 = (_P - (_m(a1, e1) + _m(a3, e0)) % _P) % _P
    return FF4._wrap(np.column_stack([r0, r1, r2, r3]))
# --- Type Aliases ---
# The galois-backed GF(p^4) class stays importable for tests and backward compat.
GaloisFF4Poly = _GaloisFF4

# Semantic type aliases for protocol code.
Fe = int  # Base field element (BabyBear, in [0, p))
FF4Coeffs = list[int]  # [c0, c1, c2, c3] extension element — DEPRECATED, use FF4
Digest = list[int]  # 8-element Poseidon2 digest
MerklePath = list[Digest]  # Merkle opening proof (a list of digests)
# --- Galois FF4 / FF4 Conversion (boundary helpers) ---
def ff4_coeffs(elem) -> list[int]:
    """Extract ascending-order coefficients [a0, a1, a2, a3] from an FF4 element.

    Args:
        elem: A galois GF(p^4) scalar, whose integer representation is
            a0 + a1*p + a2*p^2 + a3*p^3, or an FF4 scalar (accepted for
            symmetry with ``ff4``, which also takes both).

    Returns:
        The four coefficients, each in [0, p).

    Raises:
        ValueError: if ``elem`` is an FF4 column rather than a scalar.
    """
    if isinstance(elem, FF4):
        return list(elem.coeffs)
    val = int(elem)
    # Peel off base-p digits, least significant (a0) first.
    coeffs = []
    for _ in range(4):
        val, digit = divmod(val, _P)
        coeffs.append(digit)
    return coeffs
def ff4(coeffs) -> 'GaloisFF4':
    """Build a galois GF(p^4) scalar from ascending coefficients [a0, a1, a2, a3].

    FF4 objects are accepted as well (backward compat). For an FF4 column,
    only its first row is used.
    """
    if isinstance(coeffs, FF4):
        row = coeffs._d if coeffs._d.ndim == 1 else coeffs._d[0]
    else:
        row = coeffs
    a0, a1, a2, a3 = (int(row[k]) % _P for k in range(4))
    # galois integer encoding a0 + a1*p + a2*p^2 + a3*p^3, in Horner form.
    return _GaloisFF4(a0 + _P * (a1 + _P * (a2 + _P * a3)))
def ff4_from_base(val: int) -> 'GaloisFF4':
    """Lift a base field value into galois GF(p^4) as (val, 0, 0, 0)."""
    # A constant polynomial's integer encoding is just its reduced value.
    reduced = int(val) % _P
    return _GaloisFF4(reduced)
def ff4_array(c0: list[int], c1: list[int], c2: list[int], c3: list[int]) -> 'GaloisFF4':
    """Build a galois GF(p^4) array from four parallel coefficient lists.

    Element k is c0[k] + c1[k]*x + c2[k]*x^2 + c3[k]*x^3. The length of c0
    sets the array length.
    """
    total = len(c0)
    packed = []
    for k in range(total):
        digits = (int(c0[k]), int(c1[k]), int(c2[k]), int(c3[k]))
        # Horner evaluation of the base-p encoding, highest coefficient first.
        value = 0
        for digit in reversed(digits):
            value = value * _P + digit % _P
        packed.append(value)
    return _GaloisFF4(packed)
def ff4_from_json(json_arr: list[list[int]]) -> 'GaloisFF4':
    """Parse JSON rows [[c0, c1, c2, c3], ...] into a galois GF(p^4) array.

    Only the first four entries of each row are read.
    """
    # Transpose rows into per-coefficient columns (coefficient 0 first).
    c0, c1, c2, c3 = ([row[j] for row in json_arr] for j in range(4))
    return ff4_array(c0, c1, c2, c3)
def ff4_to_json(arr) -> list[list[int]]:
    """Serialize a galois GF(p^4) array as JSON rows [[c0, c1, c2, c3], ...]."""
    return list(map(ff4_coeffs, arr))
# --- FF4 JSON converters ---
def ef4_from_json(json_arr: list[list[int]]) -> FF4:
    """Parse JSON [[c0,c1,c2,c3],...] to FF4 column.

    An empty list yields an empty (0, 4) column. ``np.array([])`` has shape
    (0,), which the FF4 constructor rejects with ValueError.

    Raises:
        ValueError: if a non-empty input is not shaped (N, 4).
    """
    if len(json_arr) == 0:
        return FF4.zeros(0)
    return FF4(np.array(json_arr, dtype=np.int64))
def ef4_to_json(ef4_col: FF4) -> list[list[int]]:
    """Serialize an FF4 column or scalar as JSON rows [[c0, c1, c2, c3], ...].

    A scalar becomes a single-row list.
    """
    rows = ef4_col.to_rows()
    return rows
# --- Deprecated Extension Field Helpers (kept for backward compat) ---
def _as_ff4(value) -> FF4:
    """Return ``value`` if it already is an FF4, otherwise ``FF4(value)``."""
    if isinstance(value, FF4):
        return value
    return FF4(value)


def ef4_from_base(x: int) -> FF4:
    """Lift a base field element into the extension as (x, 0, 0, 0).

    DEPRECATED: Use FF4(x) or FF4.from_base(x) instead.
    """
    as_int = int(x)
    return FF4(as_int)


def ef4_mul(a, b) -> FF4:
    """Product a * b of two extension field elements (non-FF4 inputs go through FF4()).

    DEPRECATED: Use a * b instead.
    """
    lhs = _as_ff4(a)
    rhs = _as_ff4(b)
    return lhs * rhs


def ef4_mul_base(a, b: int) -> FF4:
    """Scale an extension field element by a base field element.

    DEPRECATED: Use a.mul_base(b) instead.
    """
    return _as_ff4(a).mul_base(b)


def ef4_add(a, b) -> FF4:
    """Sum a + b of two extension field elements.

    DEPRECATED: Use a + b instead.
    """
    lhs = _as_ff4(a)
    rhs = _as_ff4(b)
    return lhs + rhs


def ef4_sub(a, b) -> FF4:
    """Difference a - b of two extension field elements.

    DEPRECATED: Use a - b instead.
    """
    lhs = _as_ff4(a)
    rhs = _as_ff4(b)
    return lhs - rhs


def ef4_neg(a) -> FF4:
    """Additive inverse -a of an extension field element.

    DEPRECATED: Use -a instead.
    """
    return -_as_ff4(a)


def ef4_inv(x) -> FF4:
    """Multiplicative inverse of an extension field element.

    DEPRECATED: Use x ** -1 instead.
    """
    return _as_ff4(x) ** -1


def ef4_div(a, b) -> FF4:
    """Quotient a / b in the extension field.

    DEPRECATED: Use a / b instead.
    """
    numerator = _as_ff4(a)
    denominator = _as_ff4(b)
    return numerator / denominator


def ef4_pow(x, n: int) -> FF4:
    """Power x ** n in the extension field.

    DEPRECATED: Use x ** n instead.
    """
    base = _as_ff4(x)
    return base ** n


def ef4_exp_power_of_2(x, log_power: int) -> FF4:
    """Compute x^(2^log_power).

    DEPRECATED: Use x ** (2 ** log_power) instead.
    """
    base = _as_ff4(x)
    exponent = 2 ** log_power
    return base ** exponent
# --- Deprecated FF4Vec Wrappers ---
FF4Vec = tuple  # Backward compat type alias


def ef4v_add(a: FF4, b: FF4) -> FF4:
    """DEPRECATED: Use a + b."""
    total = a + b
    return total


def ef4v_sub(a: FF4, b: FF4) -> FF4:
    """DEPRECATED: Use a - b."""
    difference = a - b
    return difference


def ef4v_neg(a: FF4) -> FF4:
    """DEPRECATED: Use -a."""
    negated = -a
    return negated


def ef4v_mul(a: FF4, b: FF4) -> FF4:
    """DEPRECATED: Use a * b."""
    product = a * b
    return product


def ef4v_mul_base(a: FF4, b) -> FF4:
    """DEPRECATED: Use a.mul_base(b)."""
    scaled = a.mul_base(b)
    return scaled


def ef4v_from_base(b) -> FF4:
    """DEPRECATED: Use FF4.from_base(b)."""
    lifted = FF4.from_base(b)
    return lifted


def ef4v_from_scalar(coeffs, n: int) -> FF4:
    """DEPRECATED: Use FF4.broadcast(scalar, n).

    Non-FF4 ``coeffs`` are converted with FF4() before broadcasting.
    """
    scalar = coeffs if isinstance(coeffs, FF4) else FF4(coeffs)
    return FF4.broadcast(scalar, n)


def ef4v_inv(a: FF4) -> FF4:
    """DEPRECATED: Use a.inv() or a ** -1."""
    inverse = a.inv()
    return inverse


def ef4v_mul_scalar(a: FF4, s) -> FF4:
    """DEPRECATED: Use a * s.

    Non-FF4 ``s`` is converted with FF4() first.
    """
    scale = s if isinstance(s, FF4) else FF4(s)
    return a * scale


def ef4v_zeros(n: int) -> FF4:
    """DEPRECATED: Use FF4.zeros(n)."""
    column = FF4.zeros(n)
    return column


def ef4v_from_rows(rows: list) -> FF4:
    """DEPRECATED: Use FF4.from_rows(rows)."""
    column = FF4.from_rows(rows)
    return column


def ef4v_to_rows(v: FF4) -> list:
    """DEPRECATED: Use v.to_rows()."""
    rows = v.to_rows()
    return rows


def ef4v_roll(v: FF4, shift: int) -> FF4:
    """DEPRECATED: Use v.roll(shift)."""
    rolled = v.roll(shift)
    return rolled


def ef4v_cumsum(v: FF4) -> FF4:
    """DEPRECATED: Use v.cumsum()."""
    prefix = v.cumsum()
    return prefix
# --- Deprecated Base Field Column Helpers ---
def ff_column(data) -> FF:
    """DEPRECATED: Use FF(data) directly."""
    column = FF(data)
    return column


def ff_zeros(n: int) -> FF:
    """DEPRECATED: Use FF.Zeros(n) directly."""
    column = FF.Zeros(n)
    return column


def ff_constant(val: int, n: int) -> FF:
    """DEPRECATED: Use FF(np.full(n, val % BABYBEAR_PRIME)) directly."""
    filled = np.full(n, val % BABYBEAR_PRIME, dtype=np.int64)
    return FF(filled)


def ff_roll(arr, shift: int):
    """DEPRECATED: Use np.roll(arr, shift) directly."""
    rolled = np.roll(arr, shift)
    return rolled
# --- NTT Support ---
# Precomputed roots of unity: W[n] is a primitive 2^n-th root of unity in BabyBear.
# GENERATOR^((p-1) / 2^27) has multiplicative order exactly 2^27.
_root_27 = pow(GENERATOR, (BABYBEAR_PRIME - 1) >> TWO_ADICITY, BABYBEAR_PRIME)
# Fill from the top. Squaring a primitive 2^(k+1)-th root gives a primitive
# 2^k-th root, so W[k] = _root_27^(2^(27-k)) and W[0] == 1.
W: list[int] = [0] * (TWO_ADICITY + 1)
W[-1] = _root_27
for _k in reversed(range(TWO_ADICITY)):
    W[_k] = W[_k + 1] * W[_k + 1] % BABYBEAR_PRIME
# Precomputed inverses: W_INV[n] = W[n]^(-1) mod p (0 stays 0, though W has none).
W_INV: list[int] = [0 if w == 0 else pow(w, BABYBEAR_PRIME - 2, BABYBEAR_PRIME) for w in W]
def _checked_two_adic_bits(n_bits: int) -> int:
    """Return n_bits unchanged if 0 <= n_bits <= TWO_ADICITY, else raise IndexError.

    A negative n_bits would otherwise wrap around Python list indexing and
    silently return one of the large roots from the end of W / W_INV.
    """
    if not 0 <= n_bits <= TWO_ADICITY:
        raise IndexError(f"n_bits must be in [0, {TWO_ADICITY}], got {n_bits}")
    return n_bits


def get_omega(n_bits: int) -> int:
    """Return primitive 2^n_bits-th root of unity.

    Raises:
        IndexError: if n_bits is outside [0, TWO_ADICITY]. p - 1 = 2^27 * 15, so
            BabyBear has no 2-power roots of unity of larger order.
    """
    return W[_checked_two_adic_bits(n_bits)]


def get_omega_inv(n_bits: int) -> int:
    """Return inverse of primitive 2^n_bits-th root of unity.

    Raises:
        IndexError: if n_bits is outside [0, TWO_ADICITY].
    """
    return W_INV[_checked_two_adic_bits(n_bits)]
def inv_mod(x: int) -> int:
    """Multiplicative inverse of x modulo BABYBEAR_PRIME via Fermat: x^(p-2) mod p.

    Inputs divisible by p yield 0 rather than raising.
    """
    p = BABYBEAR_PRIME
    return pow(x, p - 2, p)
# --- Montgomery Form and Bit Reversal Utilities ---
def to_monty(x: int) -> int:
    """Map a canonical integer x to BabyBear Montgomery form x * R mod p (R = 2^32 mod p)."""
    return x * MONTY_R % BABYBEAR_PRIME


def from_monty(x: int) -> int:
    """Map a Montgomery-form value back to canonical form x * R^(-1) mod p."""
    return x * MONTY_RINV % BABYBEAR_PRIME
def reverse_bits_len(x: int, bit_len: int) -> int:
    """Reverse the lowest bit_len bits of x.

    Bits of x at positions >= bit_len are discarded. For example,
    reverse_bits_len(0b110, 3) == 0b011.
    """
    result = 0
    for _ in range(bit_len):
        result = (result << 1) | (x & 1)
        x >>= 1
    return result


def bit_reverse_list(lst: list) -> list:
    """Reorder list elements by bit-reversing their indices.

    Args:
        lst: Sequence whose length is 0, 1, or a power of two.

    Returns:
        New list with out[i] = lst[reverse_bits_len(i, log2(len(lst)))].

    Raises:
        ValueError: if len(lst) > 1 is not a power of two. Bit reversal is only
            a permutation for power-of-two lengths. Otherwise indices at or above
            2^floor(log2 n) lose their high bit, so elements get duplicated and
            dropped.
    """
    n = len(lst)
    if n <= 1:
        return list(lst)
    if n & (n - 1):
        raise ValueError(f"bit_reverse_list requires a power-of-two length, got {n}")
    log_n = n.bit_length() - 1
    return [lst[reverse_bits_len(i, log_n)] for i in range(n)]
# --- Montgomery Batch Inversion ---
def batch_inverse(values):
    """Montgomery batch inversion for any galois array.

    Trades N field inversions for 3N-3 multiplications and a single inversion.
    Forward prefix products are built, their total is inverted once, and the
    individual inverses are peeled off walking backwards. The result array is
    allocated with ``type(values).Zeros``. If any element is zero, inverting
    the zero total product fails inside galois.
    """
    count = len(values)
    if count == 0:
        return values
    if count == 1:
        return values ** -1
    field_cls = type(values)
    # prefix[i] = values[0] * ... * values[i]
    prefix = field_cls.Zeros(count)
    prefix[0] = values[0]
    for idx in range(1, count):
        prefix[idx] = prefix[idx - 1] * values[idx]
    # running_inv holds (values[0] * ... * values[idx])^(-1) on entry to each step.
    running_inv = prefix[count - 1] ** -1
    out = field_cls.Zeros(count)
    for idx in range(count - 1, 0, -1):
        out[idx] = running_inv * prefix[idx - 1]
        running_inv = running_inv * values[idx]
    out[0] = running_inv
    return out
def ef4_batch_inverse(values: list) -> list:
    """Montgomery batch inversion for FF4 elements.

    Converts N extension field inversions into about 3N multiplications plus
    a single inversion. Accepts a list of FF4 scalars or of FF4Coeffs
    (list[int]).

    Zero scalars are left out of the running product and map to FF4.zero().
    That is the value a single ``FF4.inv()`` returns for zero. Including them
    would make the total product zero and silently turn every other result
    into zero as well.

    Returns:
        List of FF4 where result[i] * values[i] == 1 for every non-zero scalar.
    """
    n = len(values)
    if n == 0:
        return []
    # Convert to FF4 if needed
    vals = [v if isinstance(v, FF4) else FF4(v) for v in values]
    zero = FF4.zero()
    # Only scalars are tested for zero; columns keep element-wise behaviour.
    skip = [v.is_scalar and v == zero for v in vals]

    # prefixes[i] = product of the non-skipped elements before index i.
    prefixes = []
    acc = FF4.one()
    for v, is_zero in zip(vals, skip):
        prefixes.append(acc)
        if not is_zero:
            acc = acc * v
    # For scalars, acc is a product of non-zero field elements, so it is invertible.
    inv_acc = acc ** -1
    results = [None] * n
    for i in range(n - 1, -1, -1):
        if skip[i]:
            results[i] = FF4.zero()
            continue
        results[i] = inv_acc * prefixes[i]
        # Drop vals[i] from the inverted product before moving left.
        inv_acc = inv_acc * vals[i]
    return results
def batch_inverse_base(values: list) -> list:
    """Montgomery batch inversion in the base field.

    Uses about 3N modular multiplications and one modular exponentiation.
    Values divisible by p map to 0, consistent with ``inv_mod(0) == 0``. They
    are skipped in the running product, so a single zero no longer turns
    every other result into 0.

    Args:
        values: Integers (any sign; reduced mod p).

    Returns:
        List of ints in [0, p), where result[i] * values[i] = 1 (mod p) for
        every non-zero values[i].
    """
    p = BABYBEAR_PRIME
    n = len(values)
    if n == 0:
        return []
    reduced = [v % p for v in values]
    # prefix[i] = product of the non-zero reduced values before index i.
    prefix = [0] * n
    acc = 1
    for i, v in enumerate(reduced):
        prefix[i] = acc
        if v:
            acc = acc * v % p
    # acc is a product of non-zero residues mod a prime, hence invertible.
    inv_acc = pow(acc, p - 2, p)
    results = [0] * n
    for i in range(n - 1, -1, -1):
        v = reduced[i]
        if not v:
            continue
        results[i] = inv_acc * prefix[i] % p
        # Remove v from the inverted product before moving left.
        inv_acc = inv_acc * v % p
    return results
# --- Batch polynomial evaluation at a single FF4 point ---
def _ef4_mul_raw(a, b):
    """Multiply two GF(p^4) elements given as 4-sequences of Python ints.

    Uses the x^4 = 11 reduction and returns a new list [c0, c1, c2, c3]
    reduced mod p, with no FF4 wrapping overhead. Used for the scalar steps
    of the BSGS power precomputation.
    """
    x0, x1, x2, x3 = a
    y0, y1, y2, y3 = b
    w = _W_EXT
    return [
        (x0 * y0 + w * (x1 * y3 + x2 * y2 + x3 * y1)) % _P,
        (x0 * y1 + x1 * y0 + w * (x2 * y3 + x3 * y2)) % _P,
        (x0 * y2 + x1 * y1 + x2 * y0 + w * x3 * y3) % _P,
        (x0 * y3 + x1 * y2 + x2 * y1 + x3 * y0) % _P,
    ]
def _outer_flat(a: np.ndarray, b: np.ndarray, length: int) -> np.ndarray:
"""Compute outer product, flatten, and truncate to length."""
return np.outer(a, b).ravel()[:length]
def _outer_mod(a: np.ndarray, b: np.ndarray, length: int) -> np.ndarray:
"""Compute outer product, flatten, truncate, and reduce mod p."""
return _outer_flat(a, b, length) % _P
def _precompute_z_powers_bsgs(z, degree: int):
    """Precompute z^0..z^{degree-1} as 4 coefficient arrays using BSGS.

    Baby-step/giant-step: with B = floor(sqrt(degree)), the baby steps are
    z^0..z^{B-1} and the giant steps are (z^B)^0..(z^B)^{num_blocks-1}.
    Every power z^(j*B + i) equals giant[j] * baby[i]. All powers therefore
    come from about 2*sqrt(degree) scalar multiplies plus vectorized outer
    products.

    Args:
        z: Evaluation point as a 4-sequence of Python ints (ascending degree).
        degree: Number of powers to produce.

    Returns:
        Tuple (c0, c1, c2, c3) of int64 arrays of length ``degree``, where
        ``ck[n]`` is coefficient k of z^n. All four are empty when degree <= 0.
    """
    if degree <= 0:
        return tuple(np.empty(0, dtype=np.int64) for _ in range(4))
    B = max(1, int(degree ** 0.5))
    num_blocks = (degree + B - 1) // B  # ceil(degree / B) giant steps
    # Baby steps: small[i] = z^i for i in [0, B).
    small = [(1, 0, 0, 0)]
    cur = (1, 0, 0, 0)
    for _ in range(B - 1):
        cur = _ef4_mul_raw(cur, z)
        small.append(cur)
    z_B = _ef4_mul_raw(cur, z)
    # Giant steps: big[j] = (z^B)^j for j in [0, num_blocks).
    big = [(1, 0, 0, 0)]
    cur = (1, 0, 0, 0)
    for _ in range(num_blocks - 1):
        cur = _ef4_mul_raw(cur, z_B)
        big.append(cur)
    # Transpose to per-coefficient arrays: small_np[c][i] is coefficient c of z^i.
    small_np = [np.array([s[c] for s in small], dtype=np.int64) for c in range(4)]
    big_np = [np.array([b[c] for b in big], dtype=np.int64) for c in range(4)]
    # Full FF4 product big[j] * small[i] for every (j, i), flattened so index
    # j*B + i holds z^(j*B + i), truncated to `degree`. Terms with index sum
    # >= 4 are scaled by W (x^4 = W). Every small/big entry is already reduced
    # below p, so a raw outer-product entry is < p^2 < 2^62. At most two raw
    # entries are added before reducing mod p, which keeps sums below 2^63.
    c0 = _outer_mod(big_np[0], small_np[0], degree)
    w_sum = (_outer_flat(big_np[1], small_np[3], degree)
             + _outer_flat(big_np[2], small_np[2], degree)) % _P
    w_sum = (w_sum + _outer_mod(big_np[3], small_np[1], degree)) % _P
    c0 = (c0 + _W_EXT * w_sum) % _P
    t1 = (_outer_flat(big_np[0], small_np[1], degree)
          + _outer_flat(big_np[1], small_np[0], degree)) % _P
    w_sum1 = (_outer_flat(big_np[2], small_np[3], degree)
              + _outer_flat(big_np[3], small_np[2], degree)) % _P
    c1 = (t1 + _W_EXT * w_sum1) % _P
    t2 = (_outer_flat(big_np[0], small_np[2], degree)
          + _outer_flat(big_np[1], small_np[1], degree)) % _P
    t2 = (t2 + _outer_mod(big_np[2], small_np[0], degree)) % _P
    c2 = (t2 + _W_EXT * _outer_mod(big_np[3], small_np[3], degree)) % _P
    t3 = (_outer_flat(big_np[0], small_np[3], degree)
          + _outer_flat(big_np[1], small_np[2], degree)) % _P
    t3_2 = (_outer_flat(big_np[2], small_np[1], degree)
            + _outer_flat(big_np[3], small_np[0], degree)) % _P
    c3 = (t3 + t3_2) % _P
    return (c0, c1, c2, c3)
def eval_poly_ef4_batch(
    coeffs_per_col: list[list[int]],
    eval_point,
) -> list:
    """Evaluate multiple polynomials at a single FF4 point using BSGS.

    Args:
        coeffs_per_col: Polynomial coefficient vectors, all of the same length.
            Entry k of each vector is the coefficient of X^k. Values are
            reduced mod p, so negative or unreduced integers are accepted.
        eval_point: FF4 scalar or [c0,c1,c2,c3] list.

    Returns:
        List of FF4 scalars, one per polynomial.
    """
    num_cols = len(coeffs_per_col)
    if num_cols == 0:
        return []
    degree = len(coeffs_per_col[0])
    if degree == 0:
        return [FF4.zero()] * num_cols
    if isinstance(eval_point, FF4):
        z = tuple(int(c) % _P for c in eval_point._d[:4])
    else:
        z = tuple(int(c) % _P for c in eval_point[:4])
    # z_powers[c][k] is coefficient c of z^k, for k in [0, degree).
    z_powers = _precompute_z_powers_bsgs(z, degree)
    # Reduce first: coeff * z-power only fits in int64 when both factors are
    # below p (< 2^31). An unreduced coefficient above ~2^32 would wrap silently.
    coeff_mat = np.array(coeffs_per_col, dtype=np.int64) % _P
    results = []
    for zp in z_powers:
        products = (coeff_mat * zp) % _P
        # Row sums of values < p stay inside int64 for degree < 2^32.
        results.append(products.sum(axis=1) % _P)
    out = np.stack(results, axis=1)
    return [FF4._wrap(np.array(row, dtype=np.int64)) for row in out]