diff --git a/mpc/binance/ecdsa/mpc.go b/mpc/binance/ecdsa/mpc.go index 5ec970d..3209749 100644 --- a/mpc/binance/ecdsa/mpc.go +++ b/mpc/binance/ecdsa/mpc.go @@ -26,6 +26,9 @@ import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/btcsuite/btcd/btcec/v2" + s256k1 "github.com/btcsuite/btcd/btcec/v2" + "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/any" ) @@ -106,14 +109,20 @@ type party struct { in chan tss.Message shareData *keygen.LocalPartySaveData closeChan chan struct{} + curve elliptic.Curve } -func NewParty(id uint16, logger Logger) *party { +func NewParty(id uint16, curve elliptic.Curve, logger Logger) *party { + if curve == nil { + curve = s256k1.S256() + } + return &party{ logger: logger, id: tss.NewPartyID(fmt.Sprintf("%d", id), "", big.NewInt(int64(id))), out: make(chan tss.Message, 1000), in: make(chan tss.Message, 1000), + curve: curve, } } @@ -190,7 +199,17 @@ func (p *party) ThresholdPK() ([]byte, error) { if err != nil { return nil, err } - return x509.MarshalPKIXPublicKey(pk) + + switch p.curve.Params().Name { + case string(tss.Secp256k1): + xFieldVal, yFieldVal := new(secp256k1.FieldVal), new(secp256k1.FieldVal) + xFieldVal.SetByteSlice(pk.X.Bytes()) + yFieldVal.SetByteSlice(pk.Y.Bytes()) + btcecPubKey := btcec.NewPublicKey(xFieldVal, yFieldVal) + return btcecPubKey.SerializeCompressed(), nil + default: + return x509.MarshalPKIXPublicKey(pk) + } } func (p *party) SetShareData(shareData []byte) error { @@ -199,9 +218,9 @@ func (p *party) SetShareData(shareData []byte) error { if err != nil { return fmt.Errorf("failed deserializing shares: %w", err) } - localSaveData.ECDSAPub.SetCurve(elliptic.P256()) + localSaveData.ECDSAPub.SetCurve(p.curve) for _, xj := range localSaveData.BigXj { - xj.SetCurve(elliptic.P256()) + xj.SetCurve(p.curve) } p.shareData = &localSaveData return nil @@ -210,7 +229,7 @@ func (p *party) SetShareData(shareData []byte) error { func (p *party) Init(parties []uint16, threshold int, sendMsg func(msg []byte, isBroadcast bool, to uint16)) { partyIDs := partyIDsFromNumbers(parties) ctx := tss.NewPeerContext(partyIDs) - p.params = tss.NewParameters(elliptic.P256(), ctx, p.id, len(parties), threshold) + p.params = tss.NewParameters(p.curve, ctx, p.id, len(parties), threshold) p.id.Index = p.locatePartyIndex(p.id) p.sendMsg = sendMsg p.closeChan = make(chan struct{}) @@ -237,7 +256,7 @@ func (p *party) Sign(ctx context.Context, msgHash []byte) ([]byte, error) { end := make(chan *common.SignatureData, 1) - msgToSign := hashToInt(msgHash, elliptic.P256()) + msgToSign := hashToInt(msgHash, p.curve) party := signing.NewLocalParty(msgToSign, p.params, *p.shareData, p.out, end) var endWG sync.WaitGroup diff --git a/mpc/binance/ecdsa/mpc_test.go b/mpc/binance/ecdsa/mpc_test.go index f90e628..273ab6a 100644 --- a/mpc/binance/ecdsa/mpc_test.go +++ b/mpc/binance/ecdsa/mpc_test.go @@ -9,6 +9,7 @@ package ecdsa import ( "context" "crypto/ecdsa" + "crypto/elliptic" "fmt" "math/big" "sync" @@ -17,6 +18,10 @@ import ( "time" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/btcsuite/btcd/btcec/v2" + s256k1 "github.com/btcsuite/btcd/btcec/v2" + btcecdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -108,45 +113,76 @@ func (parties parties) Mapping() map[string]*tss.PartyID { } func TestTSS(t *testing.T) { - pA := NewParty(1, logger("pA", t.Name())) - pB := NewParty(2, logger("pB", t.Name())) - pC := NewParty(3, logger("pC", t.Name())) + curves := []elliptic.Curve{ + elliptic.P256(), + s256k1.S256(), + } + + for _, tc := range curves { + t.Run(tc.Params().Name, func(t *testing.T) { + pA := NewParty(1, tc, logger("pA", t.Name())) + pB := NewParty(2, tc, logger("pB", t.Name())) + pC := NewParty(3, tc, logger("pC", t.Name())) + + t.Logf("Created parties") - t.Logf("Created parties") + parties := parties{pA, pB, pC} + parties.init(senders(parties)) - parties := parties{pA, pB, pC} - parties.init(senders(parties)) + t.Logf("Running DKG") - t.Logf("Running DKG") + t1 := time.Now() + shares, err := parties.keygen() + assert.NoError(t, err) + t.Logf("DKG elapsed %s", time.Since(t1)) - t1 := time.Now() - shares, err := parties.keygen() - assert.NoError(t, err) - t.Logf("DKG elapsed %s", time.Since(t1)) + parties.init(senders(parties)) - parties.init(senders(parties)) + parties.setShareData(shares) + t.Logf("Signing") - parties.setShareData(shares) - t.Logf("Signing") + msgToSign := []byte("bla bla") - msgToSign := []byte("bla bla") + t.Logf("Signing message") + t1 = time.Now() + sigs, err := parties.sign(digest(msgToSign)) + assert.NoError(t, err) + t.Logf("Signing completed in %v", time.Since(t1)) + + sigSet := make(map[string]struct{}) + for _, s := range sigs { + sigSet[string(s)] = struct{}{} + } + assert.Len(t, sigSet, 1) - t.Logf("Signing message") - t1 = time.Now() - sigs, err := parties.sign(digest(msgToSign)) - assert.NoError(t, err) - t.Logf("Signing completed in %v", time.Since(t1)) + pk, err := parties[0].TPubKey() + assert.NoError(t, err) - sigSet := make(map[string]struct{}) - for _, s := range sigs { - sigSet[string(s)] = struct{}{} + assert.True(t, verifySignature(tc.Params().Name, pk, msgToSign, sigs[0])) + }) } - assert.Len(t, sigSet, 1) +} + +func verifySignature(curveName string, pk *ecdsa.PublicKey, msg []byte, sig []byte) bool { + switch curveName { + case elliptic.P256().Params().Name: + return ecdsa.VerifyASN1(pk, digest(msg), sig) + case s256k1.S256().Params().Name: + // convert pk to s256k1.PublicKey + xFieldVal, yFieldVal := new(secp256k1.FieldVal), new(secp256k1.FieldVal) + xFieldVal.SetByteSlice(pk.X.Bytes()) + yFieldVal.SetByteSlice(pk.Y.Bytes()) + btcecPubKey := btcec.NewPublicKey(xFieldVal, yFieldVal) + + signature, err := btcecdsa.ParseDERSignature(sig) + if err != nil { + return false + } - pk, err := parties[0].TPubKey() - assert.NoError(t, err) + return signature.Verify(digest(msg), btcecPubKey) + } - assert.True(t, ecdsa.VerifyASN1(pk, digest(msgToSign), sigs[0])) + return false } func senders(parties parties) []Sender { diff --git a/test/binance/ecdsa_test.go b/test/binance/ecdsa_test.go index 3cc1c4d..0a4edc9 100644 --- a/test/binance/ecdsa_test.go +++ b/test/binance/ecdsa_test.go @@ -2,51 +2,94 @@ package binance_test import ( "crypto/ecdsa" + "crypto/elliptic" "crypto/x509" "testing" ecdsa_scheme "github.com/IBM/TSS/mpc/binance/ecdsa" . "github.com/IBM/TSS/types" + s256k1 "github.com/btcsuite/btcd/btcec/v2" + btcecdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/stretchr/testify/assert" ) func TestThresholdBinanceECDSA(t *testing.T) { - n := 4 + curves := []elliptic.Curve{ + elliptic.P256(), + s256k1.S256(), + } - var verifySig signatureVerifyFunc + for _, curve := range curves { + t.Run(curve.Params().Name, func(t *testing.T) { + n := 4 - var signatureAlgorithms func([]*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) + var verifySig signatureVerifyFunc + var signatureAlgorithms func([]*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) - verifySig = verifySignatureECDSA - signatureAlgorithms = ecdsaKeygenAndSign + verifySig = getVerifySignature(curve) + signatureAlgorithms = func(loggers []*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) { + return ecdsaKeygenAndSign(curve, loggers) + } - testScheme(t, n, signatureAlgorithms, verifySig, false) + testScheme(t, n, signatureAlgorithms, verifySig, false) + }) + } } func TestFastThresholdBinanceECDSA(t *testing.T) { - n := 4 + curves := []elliptic.Curve{ + elliptic.P256(), + s256k1.S256(), + } - var verifySig signatureVerifyFunc + for _, curve := range curves { + t.Run(curve.Params().Name, func(t *testing.T) { + n := 4 - var signatureAlgorithms func([]*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) + var verifySig signatureVerifyFunc + var signatureAlgorithms func([]*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) - verifySig = verifySignatureECDSA - signatureAlgorithms = ecdsaKeygenAndSign + verifySig = getVerifySignature(curve) + signatureAlgorithms = func(loggers []*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) { + return ecdsaKeygenAndSign(curve, loggers) + } - testScheme(t, n, signatureAlgorithms, verifySig, true) + testScheme(t, n, signatureAlgorithms, verifySig, true) + }) + } } -func ecdsaKeygenAndSign(loggers []*commLogger) (func(id uint16) KeyGenerator, func(id uint16) Signer) { +func ecdsaKeygenAndSign(curve elliptic.Curve, loggers []*commLogger) (func(id uint16) KeyGenerator, func(id uint16) Signer) { kgf := func(id uint16) KeyGenerator { - return ecdsa_scheme.NewParty(id, loggers[id-1]) + return ecdsa_scheme.NewParty(id, curve, loggers[id-1]) } sf := func(id uint16) Signer { - return ecdsa_scheme.NewParty(id, loggers[id-1]) + return ecdsa_scheme.NewParty(id, curve, loggers[id-1]) } return kgf, sf } +func getVerifySignature(curve elliptic.Curve) func(pkBytes []byte, t *testing.T, msg string, signature []byte) { + switch curve.Params().Name { + case s256k1.S256().Params().Name: + return verifySignatureSecp256k1 + default: + return verifySignatureECDSA + } +} + +func verifySignatureSecp256k1(pkBytes []byte, t *testing.T, msg string, signature []byte) { + pk, err := s256k1.ParsePubKey(pkBytes) + assert.NoError(t, err) + + sig, err := btcecdsa.ParseDERSignature(signature) + assert.NoError(t, err) + + assert.True(t, sig.Verify(sha256Digest([]byte(msg)), pk)) +} + func verifySignatureECDSA(pkBytes []byte, t *testing.T, msg string, signature []byte) { pk, err := x509.ParsePKIXPublicKey(pkBytes) assert.NoError(t, err) diff --git a/test/go.mod b/test/go.mod index 2ea1437..a240b10 100644 --- a/test/go.mod +++ b/test/go.mod @@ -49,6 +49,8 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) +replace github.com/IBM/TSS/mpc/binance/ecdsa => ../mpc/binance/ecdsa + replace github.com/IBM/TSS/mpc/bls => ../mpc/bls replace github.com/agl/ed25519 => github.com/binance-chain/edwards25519 v0.0.0-20200305024217-f36fc4b53d43