Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete fix of 5.4 Reshare Denial-of-Service via Predicable Instance IDs #162

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions pkgs/initiator/initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/herumi/bls-eth-go-binary/bls"
"github.com/imroc/req/v3"
"go.uber.org/zap"

spec "github.com/ssvlabs/dkg-spec"
spec_crypto "github.com/ssvlabs/dkg-spec/crypto"
"github.com/ssvlabs/ssv-dkg/pkgs/consts"
"github.com/ssvlabs/ssv-dkg/pkgs/crypto"
"github.com/ssvlabs/ssv-dkg/pkgs/utils"
"github.com/ssvlabs/ssv-dkg/pkgs/wire"
"go.uber.org/zap"
)

type VerifyMessageSignatureFunc func(pub *rsa.PublicKey, msg, sig []byte) error
Expand Down Expand Up @@ -165,7 +165,7 @@ func ValidatedOperatorData(ids []uint64, operators wire.OperatorsCLI) ([]*spec.O
}

// messageFlowHandling main steps of DKG at initiator
func (c *Initiator) initMessageFlowHandling(init *spec.Init, id [24]byte, operators []*spec.Operator) ([][]byte, error) {
func (c *Initiator) initMessageFlowHandling(init *spec.Init, id, instanceID [24]byte, operators []*spec.Operator) ([][]byte, error) {
c.Logger.Info("phase 1: sending init message to operators")
exchangeMsgs, errs, err := c.SendInitMsg(id, init, operators)
if err != nil {
Expand All @@ -175,33 +175,33 @@ func (c *Initiator) initMessageFlowHandling(init *spec.Init, id [24]byte, operat
if err := checkThreshold(exchangeMsgs, errs, operators, operators, len(operators)); err != nil {
return nil, err
}
err = verifyMessageSignatures(id, exchangeMsgs, c.VerifyMessageSignature)
err = verifyMessageSignatures(instanceID, exchangeMsgs, c.VerifyMessageSignature)
if err != nil {
return nil, err
}
c.Logger.Info("phase 1: ✅ verified operator init responses signatures")
c.Logger.Info("phase 2: ➡️ sending operator data (exchange messages) required for dkg")
kyberMsgs, errs, err := c.SendExchangeMsgs(id, exchangeMsgs, operators)
kyberMsgs, errs, err := c.SendExchangeMsgs(instanceID, exchangeMsgs, operators)
if err != nil {
return nil, err
}
if err := checkThreshold(kyberMsgs, errs, operators, operators, len(operators)); err != nil {
return nil, err
}
err = verifyMessageSignatures(id, kyberMsgs, c.VerifyMessageSignature)
err = verifyMessageSignatures(instanceID, kyberMsgs, c.VerifyMessageSignature)
if err != nil {
return nil, err
}
c.Logger.Info("phase 2: ✅ verified operator responses (deal messages) signatures")
c.Logger.Info("phase 3: ➡️ sending deal dkg data to all operators")
dkgResult, errs, err := c.SendKyberMsgs(id, kyberMsgs, operators)
dkgResult, errs, err := c.SendKyberMsgs(instanceID, kyberMsgs, operators)
if err != nil {
return nil, err
}
if err := checkThreshold(dkgResult, errs, operators, operators, len(operators)); err != nil {
return nil, err
}
err = verifyMessageSignatures(id, dkgResult, c.VerifyMessageSignature)
err = verifyMessageSignatures(instanceID, dkgResult, c.VerifyMessageSignature)
if err != nil {
return nil, err
}
Expand All @@ -215,9 +215,13 @@ func (c *Initiator) initMessageFlowHandling(init *spec.Init, id [24]byte, operat

func (c *Initiator) ResignMessageFlowHandling(signedResign *wire.SignedResign, id [24]byte, operators []*spec.Operator) ([][][]byte, error) {
// reqIDtracker is used to track if all ceremony are in the responses in the expected order
pub, err := spec_crypto.EncodeRSAPublicKey(&c.PrivateKey.PublicKey)
if err != nil {
return nil, err
}
reqIDs := make([][24]byte, 0)
for _, msg := range signedResign.Messages {
msgID, err := utils.GetReqIDfromMsg(msg, id)
msgID, err := utils.GetInstanceIDfromMsg(msg, id, pub)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -269,10 +273,14 @@ func (c *Initiator) ReshareMessageFlowHandling(id [24]byte, signedReshare *wire.
if err != nil {
return nil, err
}
pub, err := spec_crypto.EncodeRSAPublicKey(&c.PrivateKey.PublicKey)
if err != nil {
return nil, err
}
// reqIDtracker is used to track if all ceremony are in the responses in the expected order
reqIDs := make([][24]byte, 0)
for _, msg := range signedReshare.Messages {
msgID, err := utils.GetReqIDfromMsg(msg, id)
msgID, err := utils.GetInstanceIDfromMsg(msg, id, pub)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -383,20 +391,32 @@ func (c *Initiator) StartDKG(id [24]byte, withdraw []byte, ids []uint64, network
zap.Uint64("nonce", init.Nonce),
zap.Any("operator IDs", ids))
c.Logger = c.Logger.With(instanceIDField)
dkgResultsBytes, err := c.initMessageFlowHandling(init, id, ops)
pub, err := spec_crypto.EncodeRSAPublicKey(&c.PrivateKey.PublicKey)
if err != nil {
return nil, nil, nil, err
}
return c.CreateCeremonyResults(dkgResultsBytes, id, init.Operators, init.WithdrawalCredentials, nil, init.Fork, init.Owner, init.Nonce, phase0.Gwei(init.Amount))
instanceID, err := utils.GetInstanceIDfromMsg(init, id, pub)
if err != nil {
return nil, nil, nil, err
}
dkgResultsBytes, err := c.initMessageFlowHandling(init, id, instanceID, ops)
if err != nil {
return nil, nil, nil, err
}
return c.CreateCeremonyResults(dkgResultsBytes, instanceID, init.Operators, init.WithdrawalCredentials, nil, init.Fork, init.Owner, init.Nonce, phase0.Gwei(init.Amount))
}

func (c *Initiator) StartResigning(id [24]byte, signedResign *wire.SignedResign) ([]*wire.DepositDataCLI, []*wire.KeySharesCLI, [][]*wire.SignedProof, error) {
if len(signedResign.Messages) == 0 {
return nil, nil, nil, errors.New("no resign messages")
}
pub, err := spec_crypto.EncodeRSAPublicKey(&c.PrivateKey.PublicKey)
if err != nil {
return nil, nil, nil, err
}
resignIDMap := make(map[[24]byte]*spec.Resign)
for _, msg := range signedResign.Messages {
msgID, err := utils.GetReqIDfromMsg(msg, id)
msgID, err := utils.GetInstanceIDfromMsg(msg, id, pub)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -511,9 +531,13 @@ func (c *Initiator) StartResharing(id [24]byte, signedReshare *wire.SignedReshar
if len(signedReshare.Messages) == 0 {
return nil, nil, nil, errors.New("no reshare messages")
}
pub, err := spec_crypto.EncodeRSAPublicKey(&c.PrivateKey.PublicKey)
if err != nil {
return nil, nil, nil, err
}
reshareIDMap := make(map[[24]byte]*spec.Reshare)
for _, msg := range signedReshare.Messages {
msgID, err := utils.GetReqIDfromMsg(msg, id)
msgID, err := utils.GetInstanceIDfromMsg(msg, id, pub)
if err != nil {
return nil, nil, nil, err
}
Expand Down
18 changes: 11 additions & 7 deletions pkgs/initiator/initiator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/attestantio/go-eth2-client/spec/phase0"
e2m_core "github.com/bloxapp/eth2-key-manager/core"
"github.com/bloxapp/ssv/logging"
kyber_bls12381 "github.com/drand/kyber-bls12381"
kyber_dkg "github.com/drand/kyber/share/dkg"
Expand All @@ -18,7 +19,6 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/zap"

e2m_core "github.com/bloxapp/eth2-key-manager/core"
spec "github.com/ssvlabs/dkg-spec"
spec_crypto "github.com/ssvlabs/dkg-spec/crypto"
"github.com/ssvlabs/dkg-spec/testing/stubs"
Expand Down Expand Up @@ -468,7 +468,11 @@ func TestDKGFailWithOperatorsMisbehave(t *testing.T) {

exchangeMsgs, _, err := intr.SendInitMsg(id, init, ops)
require.NoError(t, err)
kyberMsgs, _, err := intr.SendExchangeMsgs(id, exchangeMsgs, ops)
pub, err := spec_crypto.EncodeRSAPublicKey(&intr.PrivateKey.PublicKey)
require.NoError(t, err)
instanceID, err := utils.GetInstanceIDfromMsg(init, id, pub)
require.NoError(t, err)
kyberMsgs, _, err := intr.SendExchangeMsgs(instanceID, exchangeMsgs, ops)
require.NoError(t, err)

tsp := &wire.SignedTransport{}
Expand Down Expand Up @@ -508,7 +512,7 @@ func TestDKGFailWithOperatorsMisbehave(t *testing.T) {

trsp := &wire.Transport{
Type: wire.KyberMessageType,
Identifier: id,
Identifier: instanceID,
Data: byts,
Version: intr.Version,
}
Expand All @@ -519,19 +523,19 @@ func TestDKGFailWithOperatorsMisbehave(t *testing.T) {
sign, err := srv1.Srv.State.Sign(bts)
require.NoError(t, err)

pub, err := spec_crypto.EncodeRSAPublicKey(&srv1.Srv.State.PrivateKey.PublicKey)
pubOp, err := spec_crypto.EncodeRSAPublicKey(&srv1.Srv.State.PrivateKey.PublicKey)
require.NoError(t, err)

signed := &wire.SignedTransport{
Message: trsp,
Signer: pub,
Signer: pubOp,
Signature: sign,
}
final, err := signed.MarshalSSZ()
kyberMsgs[srv1.ID] = final
require.NoError(t, err)

dkgResult, errs, err := intr.SendKyberMsgs(id, kyberMsgs, ops)
dkgResult, errs, err := intr.SendKyberMsgs(instanceID, kyberMsgs, ops)
require.NoError(t, err)

for _, err := range errs {
Expand All @@ -542,7 +546,7 @@ func TestDKGFailWithOperatorsMisbehave(t *testing.T) {
finalResults = append(finalResults, res)
}

_, _, _, err = intr.CreateCeremonyResults(finalResults, id, init.Operators, init.WithdrawalCredentials, nil, init.Fork, init.Owner, init.Nonce, phase0.Gwei(init.Amount))
_, _, _, err = intr.CreateCeremonyResults(finalResults, instanceID, init.Operators, init.WithdrawalCredentials, nil, init.Fork, init.Owner, init.Nonce, phase0.Gwei(init.Amount))
require.ErrorContains(t, err, "protocol failed with response complaints")
})

Expand Down
18 changes: 13 additions & 5 deletions pkgs/operator/instances_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,20 @@ func (s *Switch) InitInstance(reqID [24]byte, initMsg *wire.Transport, initiator
return nil, fmt.Errorf("init: initiator signature isn't valid: %w", err)
}
s.Logger.Info("✅ init message signature is successfully verified", zap.String("from initiator", fmt.Sprintf("%x", initiatorPubKey.N.Bytes())))
if err := s.validateInstances(reqID); err != nil {
instanceID, err := utils.GetInstanceIDfromMsg(init, reqID, initiatorPub)
if err != nil {
return nil, err
}
inst, resp, err := s.CreateInstance(reqID, init.Operators, init, initiatorPubKey)
if err := s.validateInstances(instanceID); err != nil {
return nil, err
}
inst, resp, err := s.CreateInstance(instanceID, init.Operators, init, initiatorPubKey)
if err != nil {
return nil, fmt.Errorf("init: failed to create instance: %w", err)
}
s.Mtx.Lock()
s.Instances[reqID] = inst
s.InstanceInitTime[reqID] = time.Now()
s.Instances[instanceID] = inst
s.InstanceInitTime[instanceID] = time.Now()
s.Mtx.Unlock()
return resp, nil
}
Expand Down Expand Up @@ -286,7 +290,11 @@ func (s *Switch) validateInstances(reqID InstanceID) error {
}

func (s *Switch) runInstance(reqID [24]byte, instance interface{}, allOps []*spec.Operator, initiatorPubKey *rsa.PublicKey, operationType string) ([]byte, error) {
instanceID, err := utils.GetReqIDfromMsg(instance, reqID)
pub, err := spec_crypto.EncodeRSAPublicKey(initiatorPubKey)
if err != nil {
return nil, err
}
instanceID, err := utils.GetInstanceIDfromMsg(instance, reqID, pub)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion pkgs/operator/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,15 @@ func TestSwitch_cleanInstances(t *testing.T) {
require.NoError(t, err)
sig, err := spec_crypto.SignRSA(priv, tsssz)
require.NoError(t, err)
instanceID, err := utils.GetInstanceIDfromMsg(init, reqID, encPubKey)
require.NoError(t, err)
resp, err := swtch.InitInstance(reqID, initMessage, encPubKey, sig)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, swtch.cleanInstances(), 0)

require.Len(t, swtch.Instances, 1)
swtch.InstanceInitTime[reqID] = time.Now().Add(-time.Minute * 6)
swtch.InstanceInitTime[instanceID] = time.Now().Add(-time.Minute * 6)

require.Equal(t, swtch.cleanInstances(), 1)
require.Len(t, swtch.Instances, 0)
Expand Down
7 changes: 4 additions & 3 deletions pkgs/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,16 @@ func GetMessageHash(msg interface{}) ([32]byte, error) {
return hash, nil
}

func GetReqIDfromMsg(instance interface{}, id [24]byte) ([24]byte, error) {
func GetInstanceIDfromMsg(instance interface{}, id [24]byte, initiatorPub []byte) ([24]byte, error) {
// make a unique ID for each reshare using the instance hash
reqID := [24]byte{}
instanceHash, err := GetMessageHash(instance)
if err != nil {
return reqID, fmt.Errorf("failed to get reqID: %w", err)
}
copy(reqID[:12], id[:12])
copy(reqID[12:24], instanceHash[:12])
copy(reqID[:8], eth_crypto.Keccak256(initiatorPub)[:8])
copy(reqID[8:16], instanceHash[:8])
copy(reqID[16:24], id[:8])
return reqID, nil
}

Expand Down
Loading