Skip to content

Commit

Permalink
fix(bandersnatch): GLV bounds + test (#516)
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni authored Jul 18, 2024
1 parent 7e5f929 commit 6cf8884
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 92 deletions.
22 changes: 18 additions & 4 deletions ecc/bls12-377/twistededwards/point.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 18 additions & 4 deletions ecc/bls12-378/twistededwards/point.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

94 changes: 53 additions & 41 deletions ecc/bls12-381/bandersnatch/endomorpism.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package bandersnatch

import (
"math"
"math/big"

"github.com/consensys/gnark-crypto/ecc"
Expand Down Expand Up @@ -30,14 +29,13 @@ func (p *PointProj) phi(p1 *PointProj) *PointProj {
return p
}

// ScalarMultiplication scalar multiplication (GLV) of a point
// scalarMulGLV is the GLV scalar multiplication of a point
// p1 in projective coordinates with a scalar in big.Int
func (p *PointProj) scalarMulGLV(p1 *PointProj, scalar *big.Int) *PointProj {

initOnce.Do(initCurveParams)

var table [15]PointProj
var zero big.Int
var res PointProj
var k1, k2 fr.Element

Expand All @@ -50,38 +48,45 @@ func (p *PointProj) scalarMulGLV(p1 *PointProj, scalar *big.Int) *PointProj {
// split the scalar, modifies +-p1, phi(p1) accordingly
k := ecc.SplitScalar(scalar, &curveParams.glvBasis)

if k[0].Cmp(&zero) == -1 {
if k[0].Sign() == -1 {
k[0].Neg(&k[0])
table[0].Neg(&table[0])
}
if k[1].Cmp(&zero) == -1 {
if k[1].Sign() == -1 {
k[1].Neg(&k[1])
table[3].Neg(&table[3])
}

// precompute table (2 bits sliding window)
// table[b3b2b1b0-1] = b3b2*phi(p1) + b1b0*p1 if b3b2b1b0 != 0
table[1].Double(&table[0])
table[2].Set(&table[1]).Add(&table[2], &table[0])
table[4].Set(&table[3]).Add(&table[4], &table[0])
table[5].Set(&table[3]).Add(&table[5], &table[1])
table[6].Set(&table[3]).Add(&table[6], &table[2])
table[2].Add(&table[1], &table[0])
table[4].Add(&table[3], &table[0])
table[5].Add(&table[3], &table[1])
table[6].Add(&table[3], &table[2])
table[7].Double(&table[3])
table[8].Set(&table[7]).Add(&table[8], &table[0])
table[9].Set(&table[7]).Add(&table[9], &table[1])
table[10].Set(&table[7]).Add(&table[10], &table[2])
table[11].Set(&table[7]).Add(&table[11], &table[3])
table[12].Set(&table[11]).Add(&table[12], &table[0])
table[13].Set(&table[11]).Add(&table[13], &table[1])
table[14].Set(&table[11]).Add(&table[14], &table[2])

// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max
table[8].Add(&table[7], &table[0])
table[9].Add(&table[7], &table[1])
table[10].Add(&table[7], &table[2])
table[11].Add(&table[7], &table[3])
table[12].Add(&table[11], &table[0])
table[13].Add(&table[11], &table[1])
table[14].Add(&table[11], &table[2])

// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
// this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
k1 = k1.SetBigInt(&k[0]).Bits()
k2 = k2.SetBigInt(&k[1]).Bits()

// loop starts from len(k1)/2 due to the bounds
// fr.Limbs == Order.limbs
for i := int(math.Ceil(fr.Limbs/2. - 1)); i >= 0; i-- {
// we don't target constant-timeness so we check first if we increase the bounds or not
maxBit := k1.BitLen()
if k2.BitLen() > maxBit {
maxBit = k2.BitLen()
}
hiWordIndex := (maxBit - 1) / 64

// loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
for i := hiWordIndex; i >= 0; i-- {
mask := uint64(3) << 62
for j := 0; j < 32; j++ {
res.Double(&res).Double(&res)
Expand Down Expand Up @@ -121,13 +126,13 @@ func (p *PointExtended) phi(p1 *PointExtended) *PointExtended {
return p
}

// ScalarMultiplication scalar multiplication (GLV) of a point
// scalarMulGLV is the GLV scalar multiplication of a point
// p1 in projective coordinates with a scalar in big.Int
func (p *PointExtended) scalarMulGLV(p1 *PointExtended, scalar *big.Int) *PointExtended {

initOnce.Do(initCurveParams)

var table [15]PointExtended
var zero big.Int
var res PointExtended
var k1, k2 fr.Element

Expand All @@ -140,38 +145,45 @@ func (p *PointExtended) scalarMulGLV(p1 *PointExtended, scalar *big.Int) *PointE
// split the scalar, modifies +-p1, phi(p1) accordingly
k := ecc.SplitScalar(scalar, &curveParams.glvBasis)

if k[0].Cmp(&zero) == -1 {
if k[0].Sign() == -1 {
k[0].Neg(&k[0])
table[0].Neg(&table[0])
}
if k[1].Cmp(&zero) == -1 {
if k[1].Sign() == -1 {
k[1].Neg(&k[1])
table[3].Neg(&table[3])
}

// precompute table (2 bits sliding window)
// table[b3b2b1b0-1] = b3b2*phi(p1) + b1b0*p1 if b3b2b1b0 != 0
table[1].Double(&table[0])
table[2].Set(&table[1]).Add(&table[2], &table[0])
table[4].Set(&table[3]).Add(&table[4], &table[0])
table[5].Set(&table[3]).Add(&table[5], &table[1])
table[6].Set(&table[3]).Add(&table[6], &table[2])
table[2].Add(&table[1], &table[0])
table[4].Add(&table[3], &table[0])
table[5].Add(&table[3], &table[1])
table[6].Add(&table[3], &table[2])
table[7].Double(&table[3])
table[8].Set(&table[7]).Add(&table[8], &table[0])
table[9].Set(&table[7]).Add(&table[9], &table[1])
table[10].Set(&table[7]).Add(&table[10], &table[2])
table[11].Set(&table[7]).Add(&table[11], &table[3])
table[12].Set(&table[11]).Add(&table[12], &table[0])
table[13].Set(&table[11]).Add(&table[13], &table[1])
table[14].Set(&table[11]).Add(&table[14], &table[2])

// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max
table[8].Add(&table[7], &table[0])
table[9].Add(&table[7], &table[1])
table[10].Add(&table[7], &table[2])
table[11].Add(&table[7], &table[3])
table[12].Add(&table[11], &table[0])
table[13].Add(&table[11], &table[1])
table[14].Add(&table[11], &table[2])

// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
// this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
k1 = k1.SetBigInt(&k[0]).Bits()
k2 = k2.SetBigInt(&k[1]).Bits()

// loop starts from len(k1)/2 due to the bounds
// fr.Limbs == Order.limbs
for i := int(math.Ceil(fr.Limbs/2. - 1)); i >= 0; i-- {
// we don't target constant-timeness so we check first if we increase the bounds or not
maxBit := k1.BitLen()
if k2.BitLen() > maxBit {
maxBit = k2.BitLen()
}
hiWordIndex := (maxBit - 1) / 64

// loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
for i := hiWordIndex; i >= 0; i-- {
mask := uint64(3) << 62
for j := 0; j < 32; j++ {
res.Double(&res).Double(&res)
Expand Down
63 changes: 63 additions & 0 deletions ecc/bls12-381/bandersnatch/point.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions ecc/bls12-381/bandersnatch/point_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6cf8884

Please sign in to comment.