diff --git a/soroban-sdk/src/crypto/bls12_381.rs b/soroban-sdk/src/crypto/bls12_381.rs index cbfe9d41..ba66723a 100644 --- a/soroban-sdk/src/crypto/bls12_381.rs +++ b/soroban-sdk/src/crypto/bls12_381.rs @@ -1,10 +1,14 @@ use crate::{ - env::internal::{self, BytesObject, U64Val}, + env::internal::{self, BytesObject, U256Val, U64Val}, impl_bytesn_repr, unwrap::{UnwrapInfallible, UnwrapOptimized}, Bytes, BytesN, ConversionError, Env, IntoVal, TryFromVal, Val, Vec, U256, }; -use core::{cmp::Ordering, fmt::Debug}; +use core::{ + cmp::Ordering, + fmt::Debug, + ops::{Add, Mul, Sub}, +}; /// Bls12_381 provides access to curve and field arithmetics on the BLS12-381 /// curve. @@ -84,6 +88,14 @@ pub struct Fp(BytesN<48>); #[repr(transparent)] pub struct Fp2(BytesN<96>); +/// `Fr` represents an element in the BLS12-381 scalar field, which is a prime +/// field of order `r` (the order of the G1 and G2 groups). The struct is +/// internally represented with an `U256`, all arithmetic operations follow +/// modulo `r`. +#[derive(Clone)] +#[repr(transparent)] +pub struct Fr(U256); + impl_bytesn_repr!(G1Affine, 96); impl_bytesn_repr!(G2Affine, 192); impl_bytesn_repr!(Fp, 48); @@ -103,7 +115,7 @@ impl G1Affine { } } -impl core::ops::Add for G1Affine { +impl Add for G1Affine { type Output = G1Affine; fn add(self, rhs: Self) -> Self::Output { @@ -111,10 +123,10 @@ impl core::ops::Add for G1Affine { } } -impl core::ops::Mul for G1Affine { +impl Mul for G1Affine { type Output = G1Affine; - fn mul(self, rhs: U256) -> Self::Output { + fn mul(self, rhs: Fr) -> Self::Output { self.env().crypto().bls12_381().g1_mul(&self, &rhs) } } @@ -133,7 +145,7 @@ impl G2Affine { } } -impl core::ops::Add for G2Affine { +impl Add for G2Affine { type Output = G2Affine; fn add(self, rhs: Self) -> Self::Output { @@ -141,10 +153,10 @@ impl core::ops::Add for G2Affine { } } -impl core::ops::Mul for G2Affine { +impl Mul for G2Affine { type Output = G2Affine; - fn mul(self, rhs: U256) -> Self::Output { + fn mul(self, rhs: Fr) -> Self::Output { self.env().crypto().bls12_381().g2_mul(&self, &rhs) } } @@ -169,6 +181,113 @@ impl Fp2 { } } +impl Fr { + pub fn env(&self) -> &Env { + self.0.env() + } + + pub fn from_u256(value: U256) -> Self { + value.into() + } + + pub fn to_u256(&self) -> U256 { + self.0.clone() + } + + pub fn as_u256(&self) -> &U256 { + &self.0 + } + + pub fn from_bytes(bytes: BytesN<32>) -> Self { + U256::from_be_bytes(bytes.env(), bytes.as_ref()).into() + } + + pub fn to_bytes(&self) -> BytesN<32> { + self.as_u256().to_be_bytes().try_into().unwrap_optimized() + } + + pub fn as_val(&self) -> &Val { + self.0.as_val() + } + + pub fn to_val(&self) -> Val { + self.0.to_val() + } + + pub fn pow(&self, rhs: u64) -> Self { + self.env().crypto().bls12_381().fr_pow(self, rhs) + } + + pub fn inv(&self) -> Self { + self.env().crypto().bls12_381().fr_inv(self) + } +} + +impl From for Fr { + fn from(value: U256) -> Self { + Self(value) + } +} + +impl From<&Fr> for U256Val { + fn from(value: &Fr) -> Self { + value.as_u256().into() + } +} + +impl IntoVal for Fr { + fn into_val(&self, e: &Env) -> Val { + self.0.into_val(e) + } +} + +impl TryFromVal for Fr { + type Error = ConversionError; + + fn try_from_val(env: &Env, val: &Val) -> Result { + let u = U256::try_from_val(env, val)?; + Ok(Fr(u)) + } +} + +impl Eq for Fr {} + +impl PartialEq for Fr { + fn eq(&self, other: &Self) -> bool { + self.as_u256().partial_cmp(other.as_u256()) == Some(Ordering::Equal) + } +} + +impl Debug for Fr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Fr({:?})", self.as_u256()) + } +} + +impl Add for Fr { + type Output = Fr; + + fn add(self, rhs: Self) -> Self::Output { + self.env().crypto().bls12_381().fr_add(&self, &rhs) + } +} + +impl Sub for Fr { + type Output = Fr; + + fn sub(self, rhs: Self) -> Self::Output { + self.env().crypto().bls12_381().fr_sub(&self, &rhs) + } +} + +impl Mul for Fr { + type Output = Fr; + + fn mul(self, rhs: Self) -> Self::Output { + self.env().crypto().bls12_381().fr_mul(&self, &rhs) + } +} + impl Bls12_381 { pub(crate) fn new(env: &Env) -> Bls12_381 { Bls12_381 { env: env.clone() } @@ -217,7 +336,7 @@ impl Bls12_381 { } /// Multiplies a point `p0` in G1 by a scalar. - pub fn g1_mul(&self, p0: &G1Affine, scalar: &U256) -> G1Affine { + pub fn g1_mul(&self, p0: &G1Affine, scalar: &Fr) -> G1Affine { let env = self.env(); let bin = internal::Env::bls12_381_g1_mul(env, p0.to_object(), scalar.into()).unwrap_infallible(); @@ -225,7 +344,7 @@ impl Bls12_381 { } /// Performs a multi-scalar multiplication (MSM) operation in G1. - pub fn g1_msm(&self, vp: Vec, vs: Vec) -> G1Affine { + pub fn g1_msm(&self, vp: Vec, vs: Vec) -> G1Affine { let env = self.env(); let bin = internal::Env::bls12_381_g1_msm(env, vp.into(), vs.into()).unwrap_infallible(); unsafe { G1Affine::from_bytes(BytesN::unchecked_new(env.clone(), bin)) } @@ -285,7 +404,7 @@ impl Bls12_381 { } /// Multiplies a point `p0` in G2 by a scalar. - pub fn g2_mul(&self, p0: &G2Affine, scalar: &U256) -> G2Affine { + pub fn g2_mul(&self, p0: &G2Affine, scalar: &Fr) -> G2Affine { let env = self.env(); let bin = internal::Env::bls12_381_g2_mul(env, p0.to_object(), scalar.into()).unwrap_infallible(); @@ -293,7 +412,7 @@ impl Bls12_381 { } /// Performs a multi-scalar multiplication (MSM) operation in G2. - pub fn g2_msm(&self, vp: Vec, vs: Vec) -> G2Affine { + pub fn g2_msm(&self, vp: Vec, vs: Vec) -> G2Affine { let env = self.env(); let bin = internal::Env::bls12_381_g2_msm(env, vp.into(), vs.into()).unwrap_infallible(); unsafe { G2Affine::from_bytes(BytesN::unchecked_new(env.clone(), bin)) } @@ -338,38 +457,38 @@ impl Bls12_381 { // scalar arithmetic /// Adds two scalars in the BLS12-381 scalar field `Fr`. - pub fn fr_add(&self, lhs: &U256, rhs: &U256) -> U256 { + pub fn fr_add(&self, lhs: &Fr, rhs: &Fr) -> Fr { let env = self.env(); let v = internal::Env::bls12_381_fr_add(env, lhs.into(), rhs.into()).unwrap_infallible(); - U256::try_from_val(env, &v).unwrap_infallible() + U256::try_from_val(env, &v).unwrap_infallible().into() } /// Subtracts one scalar from another in the BLS12-381 scalar field `Fr`. - pub fn fr_sub(&self, lhs: &U256, rhs: &U256) -> U256 { + pub fn fr_sub(&self, lhs: &Fr, rhs: &Fr) -> Fr { let env = self.env(); let v = internal::Env::bls12_381_fr_sub(env, lhs.into(), rhs.into()).unwrap_infallible(); - U256::try_from_val(env, &v).unwrap_infallible() + U256::try_from_val(env, &v).unwrap_infallible().into() } /// Multiplies two scalars in the BLS12-381 scalar field `Fr`. - pub fn fr_mul(&self, lhs: &U256, rhs: &U256) -> U256 { + pub fn fr_mul(&self, lhs: &Fr, rhs: &Fr) -> Fr { let env = self.env(); let v = internal::Env::bls12_381_fr_mul(env, lhs.into(), rhs.into()).unwrap_infallible(); - U256::try_from_val(env, &v).unwrap_infallible() + U256::try_from_val(env, &v).unwrap_infallible().into() } /// Raises a scalar to the power of a given exponent in the BLS12-381 scalar field `Fr`. - pub fn fr_pow(&self, lhs: &U256, rhs: u64) -> U256 { + pub fn fr_pow(&self, lhs: &Fr, rhs: u64) -> Fr { let env = self.env(); let rhs = U64Val::try_from_val(env, &rhs).unwrap_optimized(); let v = internal::Env::bls12_381_fr_pow(env, lhs.into(), rhs).unwrap_infallible(); - U256::try_from_val(env, &v).unwrap_infallible() + U256::try_from_val(env, &v).unwrap_infallible().into() } /// Computes the multiplicative inverse of a scalar in the BLS12-381 scalar field `Fr`. - pub fn fr_inv(&self, lhs: &U256) -> U256 { + pub fn fr_inv(&self, lhs: &Fr) -> Fr { let env = self.env(); let v = internal::Env::bls12_381_fr_inv(env, lhs.into()).unwrap_infallible(); - U256::try_from_val(env, &v).unwrap_infallible() + U256::try_from_val(env, &v).unwrap_infallible().into() } } diff --git a/soroban-sdk/src/tests/crypto_bls12_381.rs b/soroban-sdk/src/tests/crypto_bls12_381.rs index 04d03ac4..a8efef6f 100644 --- a/soroban-sdk/src/tests/crypto_bls12_381.rs +++ b/soroban-sdk/src/tests/crypto_bls12_381.rs @@ -1,6 +1,6 @@ use crate::{ bytes, bytesn, - crypto::bls12_381::{Bls12_381, Fp, Fp2, G1Affine, G2Affine}, + crypto::bls12_381::{Bls12_381, Fp, Fp2, Fr, G1Affine, G2Affine}, vec, Bytes, Env, Vec, U256, }; @@ -25,12 +25,16 @@ fn test_bls_g1() { assert!(res.is_some_and(|v| v == one)); // mul - let res = bls12_381.g1_mul(&one, &U256::from_u32(&env, 0)); + let res = bls12_381.g1_mul(&one, &U256::from_u32(&env, 0).into()); assert_eq!(res, zero); // msm let vp: Vec = vec![&env, one.clone(), one.clone()]; - let vs: Vec = vec![&env, U256::from_u32(&env, 1), U256::from_u32(&env, 0)]; + let vs: Vec = vec![ + &env, + U256::from_u32(&env, 1).into(), + U256::from_u32(&env, 0).into(), + ]; let res = bls12_381.g1_msm(vp, vs); assert_eq!(res, one); @@ -69,13 +73,23 @@ fn test_bls_g2() { assert!(res.is_some_and(|v| v == one)); // mul - let res = bls12_381.g2_mul(&one, &U256::from_u32(&env, 0)); + let res = bls12_381.g2_mul(&one, &U256::from_u32(&env, 0).into()); assert_eq!(res, zero); // msm let vp: Vec = vec![&env, one.clone(), one.clone()]; - let vs: Vec = vec![&env, U256::from_u32(&env, 1), U256::from_u32(&env, 0)]; - let res = bls12_381.g2_msm(vp, vs); + let vs: Vec = vec![ + &env, + Fr::from_bytes(bytesn!( + &env, + 0x0000000000000000000000000000000000000000000000000000000000000001 + )), + Fr::from_bytes(bytesn!( + &env, + 0x0000000000000000000000000000000000000000000000000000000000000000 + )), + ]; + let res = bls12_381.g2_msm(vp.clone(), vs); assert_eq!(res, one); // map to curve (test case from https://datatracker.ietf.org/doc/html/rfc9380) @@ -125,24 +139,33 @@ fn test_fr_arithmetic() { ), ); assert_eq!( - bls12_381.fr_add(&U256::from_u32(&env, 2), &U256::from_u32(&env, 3)), - U256::from_u32(&env, 5) + bls12_381.fr_add( + &U256::from_u32(&env, 2).into(), + &U256::from_u32(&env, 3).into() + ), + U256::from_u32(&env, 5).into() ); assert_eq!( - bls12_381.fr_sub(&U256::from_u32(&env, 2), &U256::from_u32(&env, 3)), - modulus.sub(&U256::from_u32(&env, 1)) + bls12_381.fr_sub( + &U256::from_u32(&env, 2).into(), + &U256::from_u32(&env, 3).into() + ), + modulus.sub(&U256::from_u32(&env, 1)).into() ); assert_eq!( - bls12_381.fr_mul(&U256::from_u32(&env, 2), &U256::from_u32(&env, 3)), - U256::from_u32(&env, 6) + bls12_381.fr_mul( + &U256::from_u32(&env, 2).into(), + &U256::from_u32(&env, 3).into() + ), + U256::from_u32(&env, 6).into() ); assert_eq!( - bls12_381.fr_pow(&U256::from_u32(&env, 5), 2), - U256::from_u32(&env, 25) + bls12_381.fr_pow(&U256::from_u32(&env, 5).into(), 2), + U256::from_u32(&env, 25).into() ); - let inverse_13 = bls12_381.fr_inv(&U256::from_u32(&env, 13)); + let inverse_13 = bls12_381.fr_inv(&U256::from_u32(&env, 13).into()); assert_eq!( - bls12_381.fr_mul(&inverse_13, &U256::from_u32(&env, 13)), - U256::from_u32(&env, 1) + bls12_381.fr_mul(&inverse_13, &U256::from_u32(&env, 13).into()), + U256::from_u32(&env, 1).into() ); }