diff --git a/ecc/bls12-377/fp/element_mul_amd64.s b/ecc/bls12-377/fp/element_mul_amd64.s index 3e7650e5aa..1e19c4d3fd 100644 --- a/ecc/bls12-377/fp/element_mul_amd64.s +++ b/ecc/bls12-377/fp/element_mul_amd64.s @@ -67,7 +67,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), R8 // x[0] -> R10 @@ -570,7 +570,7 @@ TEXT ·mul(SB), $24-24 MOVQ DI, 40(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -595,7 +595,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -850,7 +850,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ DI, 40(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index a4045c3bc8..582d8b4aff 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -711,6 +711,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build 
tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bls12-377/fp/vector.go b/ecc/bls12-377/fp/vector.go index 3fe1137102..0df05e3373 100644 --- a/ecc/bls12-377/fp/vector.go +++ b/ecc/bls12-377/fp/vector.go @@ -201,6 +201,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-377/fr/element_mul_amd64.s b/ecc/bls12-377/fr/element_mul_amd64.s index dc601e91e3..ab18162458 100644 --- a/ecc/bls12-377/fr/element_mul_amd64.s +++ b/ecc/bls12-377/fr/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index e40a9caed5..21568255de 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -45,6 +45,42 @@ func reduce(res *Element) //go:noescape func 
Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index afe75ff25e..ffa3b7bcae 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 
24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $0x0a11800000000001, R11 + MOVQ $0x59aa76fed0000001, R12 + MOVQ $0x60b44d1e5c37b001, R13 + MOVQ $0x12ab655e9a2ca556, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this 
is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + 
m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ 
q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go index fe434ed616..9c34ebecce 100644 --- a/ecc/bls12-377/fr/element_ops_purego.go +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 819e4d57ea..9b41902857 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/bls12-377/fr/vector.go b/ecc/bls12-377/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/bls12-377/fr/vector.go +++ b/ecc/bls12-377/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-381/fp/element_mul_amd64.s b/ecc/bls12-381/fp/element_mul_amd64.s index 9e03b1c0ac..e95c984037 100644 --- a/ecc/bls12-381/fp/element_mul_amd64.s +++ b/ecc/bls12-381/fp/element_mul_amd64.s @@ -67,7 +67,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), R8 // x[0] -> R10 @@ -570,7 +570,7 @@ TEXT ·mul(SB), $24-24 MOVQ DI, 40(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -595,7 +595,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -850,7 +850,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ DI, 40(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index 2ca4c51017..d070a1814a 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -711,6 +711,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark 
against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bls12-381/fp/vector.go b/ecc/bls12-381/fp/vector.go index 3fe1137102..0df05e3373 100644 --- a/ecc/bls12-381/fp/vector.go +++ b/ecc/bls12-381/fp/vector.go @@ -201,6 +201,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-381/fr/element_mul_amd64.s b/ecc/bls12-381/fr/element_mul_amd64.s index ef89cc5dfd..396d990b75 100644 --- a/ecc/bls12-381/fr/element_mul_amd64.s +++ b/ecc/bls12-381/fr/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index e40a9caed5..21568255de 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -45,6 +45,42 @@ func reduce(res *Element) //go:noescape func 
Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index dde3813281..caffb72b1f 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 
24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $0xffffffff00000001, R11 + MOVQ $0x53bda402fffe5bfe, R12 + MOVQ $0x3339d80809a1d805, R13 + MOVQ $0x73eda753299d7d48, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this 
is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + 
m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ 
q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go index 258157ab79..50e839865c 100644 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index 40527fda5b..684ea15253 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/bls12-381/fr/vector.go b/ecc/bls12-381/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/bls12-381/fr/vector.go +++ b/ecc/bls12-381/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-315/fp/element_mul_amd64.s b/ecc/bls24-315/fp/element_mul_amd64.s index 51165684d5..92bba4f58a 100644 --- a/ecc/bls24-315/fp/element_mul_amd64.s +++ b/ecc/bls24-315/fp/element_mul_amd64.s @@ -63,7 +63,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), DI // x[0] -> R9 @@ -435,7 +435,7 @@ TEXT ·mul(SB), $24-24 MOVQ SI, 32(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -460,7 +460,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -649,7 +649,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ SI, 32(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 68abac8759..665ffce6a4 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -709,6 +709,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark 
against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bls24-315/fp/vector.go b/ecc/bls24-315/fp/vector.go index 4d428082f2..01b326d498 100644 --- a/ecc/bls24-315/fp/vector.go +++ b/ecc/bls24-315/fp/vector.go @@ -200,6 +200,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-315/fr/element_mul_amd64.s b/ecc/bls24-315/fr/element_mul_amd64.s index e32f783544..d028fed20a 100644 --- a/ecc/bls24-315/fr/element_mul_amd64.s +++ b/ecc/bls24-315/fr/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index e40a9caed5..21568255de 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -45,6 +45,42 @@ func reduce(res *Element) //go:noescape func 
Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index e09e5ee149..2e52c653bf 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 
24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $0x19d0c5fd00c00001, R11 + MOVQ $0xc8c480ece644e364, R12 + MOVQ $0x25fc7ec9cf927a98, R13 + MOVQ $0x196deac24a9da12b, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this 
is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + 
m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ 
q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go index 8dcb67f76e..7b6cfd87b4 100644 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index 2000441ed9..ac030b6d05 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/bls24-315/fr/vector.go b/ecc/bls24-315/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/bls24-315/fr/vector.go +++ b/ecc/bls24-315/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-317/fp/element_mul_amd64.s b/ecc/bls24-317/fp/element_mul_amd64.s index 56bfe818a3..bfc863eeba 100644 --- a/ecc/bls24-317/fp/element_mul_amd64.s +++ b/ecc/bls24-317/fp/element_mul_amd64.s @@ -63,7 +63,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), DI // x[0] -> R9 @@ -435,7 +435,7 @@ TEXT ·mul(SB), $24-24 MOVQ SI, 32(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -460,7 +460,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -649,7 +649,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ SI, 32(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 1324638f06..7bbabe2599 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -709,6 +709,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark 
against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bls24-317/fp/vector.go b/ecc/bls24-317/fp/vector.go index 4d428082f2..01b326d498 100644 --- a/ecc/bls24-317/fp/vector.go +++ b/ecc/bls24-317/fp/vector.go @@ -200,6 +200,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-317/fr/element_mul_amd64.s b/ecc/bls24-317/fr/element_mul_amd64.s index 150801a1a8..6e58b40d60 100644 --- a/ecc/bls24-317/fr/element_mul_amd64.s +++ b/ecc/bls24-317/fr/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index e40a9caed5..21568255de 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -45,6 +45,42 @@ func reduce(res *Element) //go:noescape func 
Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index a62d5598fd..fd237dad9f 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 
24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $0xf000000000000001, R11 + MOVQ $0x1cd1e79196bf0e7a, R12 + MOVQ $0xd0b097f28d83cd49, R13 + MOVQ $0x443f917ea68dafc2, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this 
is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + 
m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ 
q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go index 7b7f9352d9..14505483c7 100644 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 37cabe823e..c533cc1c92 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/bls24-317/fr/vector.go b/ecc/bls24-317/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/bls24-317/fr/vector.go +++ b/ecc/bls24-317/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bn254/fp/element_mul_amd64.s b/ecc/bn254/fp/element_mul_amd64.s index e58b316819..9357a21d75 100644 --- a/ecc/bn254/fp/element_mul_amd64.s +++ b/ecc/bn254/fp/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index 83bba45aed..6f16baf686 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -45,6 +45,57 @@ func reduce(res *Element) //go:noescape func Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +// It is a no-op when the vectors are empty. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + if len(a) == 0 { + // nothing to process; &a[0] would panic on an empty slice + return + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +// It is a no-op when the vectors are empty. 
+func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + if len(a) == 0 { + // nothing to process; &a[0] would panic on an empty slice + return + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +// It is a no-op when the vectors are empty. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + if len(a) == 0 { + // nothing to process; &a[0] would panic on an empty slice + return + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 48f34db8fe..cbfba4ee59 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ 
b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $0x3c208c16d87cfd47, R11 + MOVQ $0x97816a916871ca8d, R12 + MOVQ $0xb85045b68181585d, R13 + MOVQ $0x30644e72e131a029, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := 
t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ 
+ MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), 
DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go index 93aca54ddb..250ac5bce0 100644 --- a/ecc/bn254/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 51db5e1e32..a923ef657d 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bn254/fp/vector.go 
b/ecc/bn254/fp/vector.go index acf1e44ea9..850b3603d8 100644 --- a/ecc/bn254/fp/vector.go +++ b/ecc/bn254/fp/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bn254/fr/element_mul_amd64.s b/ecc/bn254/fr/element_mul_amd64.s index b51bc69986..4a9321837d 100644 --- a/ecc/bn254/fr/element_mul_amd64.s +++ b/ecc/bn254/fr/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index e40a9caed5..21568255de 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ 
b/ecc/bn254/fr/element_ops_amd64.go @@ -45,6 +45,57 @@ func reduce(res *Element) //go:noescape func Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +// It is a no-op when the vectors are empty. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + if len(a) == 0 { + // nothing to process; &a[0] would panic on an empty slice + return + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +// It is a no-op when the vectors are empty. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + if len(a) == 0 { + // nothing to process; &a[0] would panic on an empty slice + return + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +// It is a no-op when the vectors are empty. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + if len(a) == 0 { + // nothing to process; &a[0] would panic on an empty slice + return + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index b9d8a5bfa2..d077b11246 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ 
$0x43e1f593f0000001, R11 + MOVQ $0x2833e84879b97091, R12 + MOVQ $0xb85045b68181585d, R13 + MOVQ $0x30644e72e131a029, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, 
CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ 
AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go index be0f085832..cd5c53d8fc 100644 --- a/ecc/bn254/fr/element_ops_purego.go +++ 
b/ecc/bn254/fr/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 47d3eac90f..3be23d96a4 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: 
-tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bn254/fr/vector.go b/ecc/bn254/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/bn254/fr/vector.go +++ b/ecc/bn254/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-633/fp/element_mul_amd64.s b/ecc/bw6-633/fp/element_mul_amd64.s index f0b6172282..62a7d4dda2 100644 --- a/ecc/bw6-633/fp/element_mul_amd64.s +++ b/ecc/bw6-633/fp/element_mul_amd64.s @@ -83,7 +83,7 @@ TEXT ·mul(SB), $64-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), R12 MOVQ y+16(FP), R13 @@ -1323,7 +1323,7 @@ TEXT ·mul(SB), $64-24 MOVQ R11, 72(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -1348,7 +1348,7 @@ TEXT ·fromMont(SB), $64-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -1967,7 +1967,7 @@ TEXT ·fromMont(SB), $64-8 MOVQ R11, 72(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bw6-633/fp/element_ops_amd64.s b/ecc/bw6-633/fp/element_ops_amd64.s index 119efe44f0..12a0789638 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.s +++ b/ecc/bw6-633/fp/element_ops_amd64.s @@ -374,7 +374,7 @@ TEXT ·Butterfly(SB), $56-16 SBBQ 56(AX), R10 SBBQ 64(AX), R11 SBBQ 72(AX), R12 - JCC l1 + JCC noReduce_1 MOVQ $0xd74916ea4570000d, AX ADDQ AX, DX MOVQ $0x3d369bd31147f73c, AX @@ -396,7 +396,7 @@ TEXT ·Butterfly(SB), $56-16 MOVQ $0x0126633cc0f35f63, AX ADCQ AX, R12 -l1: +noReduce_1: MOVQ b+8(FP), AX MOVQ DX, 0(AX) MOVQ CX, 8(AX) diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 08729f0bea..169cd6701c 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -719,6 +719,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected 
Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bw6-633/fp/vector.go b/ecc/bw6-633/fp/vector.go index bd9388a4b5..1bd71a36e5 100644 --- a/ecc/bw6-633/fp/vector.go +++ b/ecc/bw6-633/fp/vector.go @@ -205,6 +205,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-633/fr/element_mul_amd64.s b/ecc/bw6-633/fr/element_mul_amd64.s index 51165684d5..92bba4f58a 100644 --- a/ecc/bw6-633/fr/element_mul_amd64.s +++ b/ecc/bw6-633/fr/element_mul_amd64.s @@ -63,7 +63,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), DI // x[0] -> R9 @@ -435,7 +435,7 @@ TEXT ·mul(SB), $24-24 MOVQ SI, 32(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -460,7 +460,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -649,7 +649,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ SI, 32(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index b69182ef40..e232de8c8b 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -709,6 +709,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" 
version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bw6-633/fr/vector.go b/ecc/bw6-633/fr/vector.go index 0a99d9a2f3..1c9b6b9752 100644 --- a/ecc/bw6-633/fr/vector.go +++ b/ecc/bw6-633/fr/vector.go @@ -200,6 +200,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-761/fp/element_mul_amd64.s b/ecc/bw6-761/fp/element_mul_amd64.s index 478922653b..fd48d8606c 100644 --- a/ecc/bw6-761/fp/element_mul_amd64.s +++ b/ecc/bw6-761/fp/element_mul_amd64.s @@ -91,7 +91,7 @@ TEXT ·mul(SB), $96-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), AX // x[0] -> s0-8(SP) @@ -1865,7 +1865,7 @@ TEXT ·mul(SB), $96-24 MOVQ R13, 88(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -1890,7 +1890,7 @@ TEXT ·fromMont(SB), $96-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -2751,7 +2751,7 @@ TEXT ·fromMont(SB), $96-8 MOVQ R13, 88(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bw6-761/fp/element_ops_amd64.s b/ecc/bw6-761/fp/element_ops_amd64.s index c0f7ed2392..476e9e39ec 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.s +++ b/ecc/bw6-761/fp/element_ops_amd64.s @@ -430,7 +430,7 @@ TEXT ·Butterfly(SB), $88-16 SBBQ 72(AX), R12 SBBQ 
80(AX), R13 SBBQ 88(AX), R14 - JCC l1 + JCC noReduce_1 MOVQ $0xf49d00000000008b, AX ADDQ AX, DX MOVQ $0xe6913e6870000082, AX @@ -456,7 +456,7 @@ TEXT ·Butterfly(SB), $88-16 MOVQ $0x0122e824fb83ce0a, AX ADCQ AX, R14 -l1: +noReduce_1: MOVQ b+8(FP), AX MOVQ DX, 0(AX) MOVQ CX, 8(AX) diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index 4bdf251c7b..fbba1f2864 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -723,6 +723,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func 
TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bw6-761/fp/vector.go b/ecc/bw6-761/fp/vector.go index f136ce9b30..87105028b8 100644 --- a/ecc/bw6-761/fp/vector.go +++ b/ecc/bw6-761/fp/vector.go @@ -207,6 +207,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-761/fr/element_mul_amd64.s b/ecc/bw6-761/fr/element_mul_amd64.s index 3e7650e5aa..1e19c4d3fd 100644 --- a/ecc/bw6-761/fr/element_mul_amd64.s +++ b/ecc/bw6-761/fr/element_mul_amd64.s @@ -67,7 +67,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), R8 // x[0] -> R10 @@ -570,7 +570,7 @@ TEXT ·mul(SB), $24-24 MOVQ DI, 40(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -595,7 +595,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -850,7 +850,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ DI, 40(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index dafeea8ed1..0596297e8c 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -711,6 +711,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" 
version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/ecc/bw6-761/fr/vector.go b/ecc/bw6-761/fr/vector.go index ee8dc421db..8dd4774c5a 100644 --- a/ecc/bw6-761/fr/vector.go +++ b/ecc/bw6-761/fr/vector.go @@ -201,6 +201,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_ops_purego.go index 31298ba1fd..a8624a511e 100644 --- a/ecc/secp256k1/fp/element_ops_purego.go +++ b/ecc/secp256k1/fp/element_ops_purego.go @@ -57,6 +57,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 6790f233b1..6f8165b18d 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -705,6 +705,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/secp256k1/fp/vector.go b/ecc/secp256k1/fp/vector.go index acf1e44ea9..850b3603d8 100644 --- a/ecc/secp256k1/fp/vector.go +++ b/ecc/secp256k1/fp/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_ops_purego.go index 7d3b6e3ddf..1a46f6d791 100644 --- a/ecc/secp256k1/fr/element_ops_purego.go +++ b/ecc/secp256k1/fr/element_ops_purego.go @@ -57,6 +57,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index 64906afcab..f554db8e3c 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -705,6 +705,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/secp256k1/fr/vector.go b/ecc/secp256k1/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/secp256k1/fr/vector.go +++ b/ecc/secp256k1/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/stark-curve/fp/element_mul_amd64.s b/ecc/stark-curve/fp/element_mul_amd64.s index eaa4d8216e..fab328c861 100644 --- a/ecc/stark-curve/fp/element_mul_amd64.s +++ b/ecc/stark-curve/fp/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/element_ops_amd64.go index 83bba45aed..6f16baf686 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.go +++ b/ecc/stark-curve/fp/element_ops_amd64.go @@ -45,6 +45,42 @@ func reduce(res *Element) //go:noescape func Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index 78bbdc60dd..914653b705 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + 
MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $1, R11 + MOVQ $0, R12 + MOVQ $0, R13 + MOVQ $0x0800000000000011, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ 
qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ 
AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), 
AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go index b946a6f4d0..4906d13e0c 100644 --- a/ecc/stark-curve/fp/element_ops_purego.go +++ b/ecc/stark-curve/fp/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 897893cdf3..87e38f7c10 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/stark-curve/fp/vector.go b/ecc/stark-curve/fp/vector.go index acf1e44ea9..850b3603d8 100644 --- a/ecc/stark-curve/fp/vector.go +++ b/ecc/stark-curve/fp/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/stark-curve/fr/element_mul_amd64.s b/ecc/stark-curve/fr/element_mul_amd64.s index ec316ceb71..8eb931e77a 100644 --- a/ecc/stark-curve/fr/element_mul_amd64.s +++ b/ecc/stark-curve/fr/element_mul_amd64.s @@ -59,7 +59,7 @@ TEXT ·mul(SB), $24-24 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE l1 + JNE noAdx_1 MOVQ x+8(FP), SI // x[0] -> DI @@ -322,7 +322,7 @@ TEXT ·mul(SB), $24-24 MOVQ BX, 24(AX) RET -l1: +noAdx_1: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -347,7 +347,7 @@ TEXT ·fromMont(SB), $8-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE l2 + JNE noAdx_2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R13 @@ -480,7 +480,7 @@ TEXT ·fromMont(SB), $8-8 MOVQ BX, 24(AX) RET -l2: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go index e40a9caed5..21568255de 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ b/ecc/stark-curve/fr/element_ops_amd64.go @@ -45,6 +45,42 @@ func reduce(res *Element) //go:noescape func Butterfly(a, b *Element) +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 2cb4e1b3d2..245dcb8959 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -228,3 +228,400 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 MOVQ SI, 16(AX) MOVQ DI, 24(AX) RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + 
MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $0x1e66a241adc64d2f, R11 + MOVQ $0xb781126dcae7b232, R12 + MOVQ $0xffffffffffffffff, R13 + MOVQ $0x0800000000000010, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + 
ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += 
carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET 
+ +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go index 6458506dd4..b04f5202fd 100644 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_ops_purego.go @@ -60,6 +60,24 @@ func reduce(z *Element) { _reduceGeneric(z) } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index f983e8e780..b81aff116e 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -707,6 +707,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git 
a/ecc/stark-curve/fr/vector.go b/ecc/stark-curve/fr/vector.go index 00ad8a8986..f39828547f 100644 --- a/ecc/stark-curve/fr/vector.go +++ b/ecc/stark-curve/fr/vector.go @@ -199,6 +199,33 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/field/generator/asm/amd64/asm_macros.go b/field/generator/asm/amd64/asm_macros.go index 8a5eeee82d..45d324c948 100644 --- a/field/generator/asm/amd64/asm_macros.go +++ b/field/generator/asm/amd64/asm_macros.go @@ -86,7 +86,6 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 CMOVQCS rb{{$i}}, ra{{$i}}; \ {{- end}} - ` func (f *FFAmd64) GenerateDefines() { diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index ed022ee185..b760ad3e3f 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -158,6 +158,13 @@ func Generate(w io.Writer, F *config.FieldConfig) error { // fft butterflies f.generateButterfly() + // generate vector operations for "small" modulus + if f.NbWords == 4 { + f.generateAddVec() + f.generateSubVec() + f.generateScalarMulVec() + } + return nil } diff --git a/field/generator/asm/amd64/element_butterfly.go b/field/generator/asm/amd64/element_butterfly.go index 87a2c54683..2d996754ae 100644 --- a/field/generator/asm/amd64/element_butterfly.go +++ b/field/generator/asm/amd64/element_butterfly.go @@ -107,7 +107,7 @@ func (f *FFAmd64) generateButterfly() { f.Sub(r, t0) // t0 = a - b // reduce t0 - noReduce := f.NewLabel() + noReduce := f.NewLabel("noReduce") f.JCC(noReduce) q := r f.MOVQ(f.Q[0], q) diff --git a/field/generator/asm/amd64/element_frommont.go b/field/generator/asm/amd64/element_frommont.go index b19bea1687..79b717dccc 100644 --- a/field/generator/asm/amd64/element_frommont.go +++ b/field/generator/asm/amd64/element_frommont.go @@ -54,7 +54,7 @@ func (f *FFAmd64) generateFromMont(forceADX bool) { // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C`) - noAdx := f.NewLabel() + noAdx := f.NewLabel("noAdx") if !forceADX { // check ADX instruction support f.CMPB("·supportAdx(SB)", 1) diff --git a/field/generator/asm/amd64/element_mul.go b/field/generator/asm/amd64/element_mul.go index 5533e8ca38..df63f1ad4b 100644 --- 
a/field/generator/asm/amd64/element_mul.go +++ b/field/generator/asm/amd64/element_mul.go @@ -169,7 +169,7 @@ func (f *FFAmd64) generateMul(forceADX bool) { f.WriteLn("NO_LOCAL_POINTERS") } - noAdx := f.NewLabel() + noAdx := f.NewLabel("noAdx") if !forceADX { // check ADX instruction support diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go new file mode 100644 index 0000000000..05c2cf3f19 --- /dev/null +++ b/field/generator/asm/amd64/element_vec.go @@ -0,0 +1,242 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package amd64 + +import "github.com/consensys/bavard/amd64" + +// addVec res = a + b +// func addVec(res, a, b *{{.ElementName}}, n uint64) +func (f *FFAmd64) generateAddVec() { + f.Comment("addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n]") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + registers := f.FnHeader("addVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(&registers) + addrB := f.Pop(&registers) + addrRes := f.Pop(&registers) + len := f.Pop(&registers) + + a := f.PopN(&registers) + t := f.PopN(&registers) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("res+0(FP)", addrRes) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // a = a + b + f.LabelRegisters("a", a...) + f.Mov(addrA, a) + f.Add(addrB, a) + + // reduce a + f.ReduceElement(a, t) + + // save a into res + f.Mov(a, addrRes) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.ADDQ("$32", addrB) + f.ADDQ("$32", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + f.RET() + + f.Push(&registers, a...) + f.Push(&registers, t...) 
+ f.Push(&registers, addrA, addrB, addrRes, len) + +} + +// subVec res = a - b +// func subVec(res, a, b *{{.ElementName}}, n uint64) +func (f *FFAmd64) generateSubVec() { + f.Comment("subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n]") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*2+5, 0, 0) + registers := f.FnHeader("subVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers + addrA := f.Pop(&registers) + addrB := f.Pop(&registers) + addrRes := f.Pop(&registers) + len := f.Pop(&registers) + zero := f.Pop(&registers) + + a := f.PopN(&registers) + q := f.PopN(&registers) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("res+0(FP)", addrRes) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.XORQ(zero, zero) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // a = a - b + f.LabelRegisters("a", a...) + f.Mov(addrA, a) + f.Sub(addrB, a) + + // reduce a + f.Comment("reduce (a-b) mod q") + f.LabelRegisters("q", q...) + f.Mov(f.Q, q) + for i := 0; i < f.NbWords; i++ { + f.CMOVQCC(zero, q[i]) + } + // add registers (q or 0) to a, and set to result + f.Comment("add registers (q or 0) to a, and set to result") + f.Add(q, a) + + // save a into res + f.Mov(a, addrRes) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.ADDQ("$32", addrB) + f.ADDQ("$32", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + f.RET() + + f.Push(&registers, a...) + f.Push(&registers, q...) 
+ f.Push(&registers, addrA, addrB, addrRes, len, zero) + +} + +// scalarMulVec res = a * b +// func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) +func (f *FFAmd64) generateScalarMulVec() { + f.Comment("scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b") + + const argSize = 4 * 8 + const minStackSize = 7 * 8 // 2 slices (3 words each) + pointer to the scalar + stackSize := f.StackSize(f.NbWords*2+3, 2, minStackSize) + reserved := []amd64.Register{amd64.DX, amd64.AX} + registers := f.FnHeader("scalarMulVec", stackSize, argSize, reserved...) + defer f.AssertCleanStack(stackSize, minStackSize) + + // labels & registers we need + noAdx := f.NewLabel("noAdx") + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + t := registers.PopN(f.NbWords) + scalar := registers.PopN(f.NbWords) + + addrB := registers.Pop() + addrA := registers.Pop() + addrRes := addrB + len := registers.Pop() + + // check ADX instruction support + f.CMPB("·supportAdx(SB)", 1) + f.JNE(noAdx) + + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + // we store b, the scalar, fully in registers + f.LabelRegisters("scalar", scalar...) 
+ f.Mov(addrB, scalar) + + xat := func(i int) string { + return string(scalar[i]) + } + + f.MOVQ("res+0(FP)", addrRes) + + f.LABEL(loop) + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + yat := func(i int) string { + return addrA.At(i) + } + + f.Comment("TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function") + + f.MulADX(&registers, xat, yat, t) + + // registers.Push(addrA) + + // reduce; we need at least 4 extra registers + registers.Push(amd64.AX, amd64.DX) + f.Comment("reduce t mod q") + f.Reduce(&registers, t) + f.Mov(t, addrRes) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.ADDQ("$32", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + f.RET() + + // no ADX support + f.LABEL(noAdx) + + f.MOVQ("n+24(FP)", amd64.DX) + + f.MOVQ("res+0(FP)", amd64.AX) + f.MOVQ(amd64.AX, "(SP)") + f.MOVQ(amd64.DX, "8(SP)") // len + f.MOVQ(amd64.DX, "16(SP)") // cap + f.MOVQ("a+8(FP)", amd64.AX) + f.MOVQ(amd64.AX, "24(SP)") + f.MOVQ(amd64.DX, "32(SP)") // len + f.MOVQ(amd64.DX, "40(SP)") // cap + f.MOVQ("b+16(FP)", amd64.AX) + f.MOVQ(amd64.AX, "48(SP)") + f.WriteLn("CALL ·scalarMulVecGeneric(SB)") + f.RET() + +} diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 8f5c01510a..ffa7231e15 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -29,7 +29,43 @@ func reduce(res *{{.ElementName}}) //go:noescape func Butterfly(a, b *{{.ElementName}}) +{{- if eq .NbWords 4}} +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func addVec(res, a, b *{{.ElementName}}, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} +//go:noescape +func subVec(res, a, b *{{.ElementName}}, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) +} + +//go:noescape +func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) +{{- end}} // Mul z = x * y (mod q) // diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index 9447465c1e..a4fde0d053 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -50,7 +50,25 @@ func reduce(z *{{.ElementName}}) { _reduceGeneric(z) } +{{- if eq .NbWords 4}} +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + scalarMulVecGeneric(*vector, a, b) +} +{{- end}} // Mul z = x * y (mod q) {{- if $.NoCarry}} diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 5a58b065b4..416b3f30e6 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -728,6 +728,78 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { } +func Test{{toTitle .ElementName}}VecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected {{.ElementName}} + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected {{.ElementName}} + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected {{.ElementName}} + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} + +func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + 
c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + {{template "testBinaryOp" dict "all" . "Op" "Add"}} {{template "testBinaryOp" dict "all" . "Op" "Sub"}} diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 35003afddb..8f06b54c9f 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -192,6 +192,56 @@ func (vector Vector) Swap(i, j int) { } +{{/* For 4-word elements, we have a special assembly path; Add, Sub and ScalarMul are defined in ops_asm.go and ops_purego.go instead */}} +{{- if ne .NbWords 4}} +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + scalarMulVecGeneric(*vector, a, b) +} +{{- end}} + + + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *{{.ElementName}}) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 195919bfe0..339fb4ea61 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -652,6 +652,77 @@ func TestElementLexicographicallyLargest(t *testing.T) { } +func TestElementVecOps(t *testing.T) { + assert := require.New(t) + + const N = 7 + a := make(Vector, N) + b := make(Vector, N) + c := make(Vector, N) + for i := 0; i < N; i++ { + a[i].SetRandom() + b[i].SetRandom() + } + + // Vector addition + c.Add(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } + + // Vector subtraction + c.Sub(a, b) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &b[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } + + // Vector scaling + c.ScalarMul(a, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&a[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } +} 
+ +func BenchmarkElementVecOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1024 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + for i := 0; i < N; i++ { + a1[i].SetRandom() + b1[i].SetRandom() + } + + b.Run("Add", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Add(a1, b1) + } + }) + + b.Run("Sub", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.Sub(a1, b1) + } + }) + + b.Run("ScalarMul", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + c1.ScalarMul(a1, &b1[0]) + } + }) +} + func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() diff --git a/field/goldilocks/vector.go b/field/goldilocks/vector.go index 5c0d964b12..3de71afb88 100644 --- a/field/goldilocks/vector.go +++ b/field/goldilocks/vector.go @@ -196,6 +196,51 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/go.mod b/go.mod index f1e439cbc4..1cc1f399bb 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.13 + github.com/consensys/bavard v0.1.15 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index e6a6f98bbc..2c324c00b1 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= -github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/bavard v0.1.15 
h1:fxv2mg1afRMJvZgpwEgLmyr2MsQwaAYcyKf31UBHzw4= +github.com/consensys/bavard v0.1.15/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=