Skip to content

Improve readability of musig2 test #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/bitcointree/musig2.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ type SignerSession interface {

type CoordinatorSession interface {
AddNonce(*btcec.PublicKey, TreeNonces)
AddSig(*btcec.PublicKey, TreePartialSigs)
AddSignatures(*btcec.PublicKey, TreePartialSigs)
AggregateNonces() (TreeNonces, error)
// SignTree combines the signatures and add them to the tree's psbts
SignTree() (tree.VtxoTree, error)
Expand Down Expand Up @@ -420,7 +420,7 @@ func (t *treeCoordinatorSession) AddNonce(pubkey *btcec.PublicKey, nonce TreeNon
t.nonces[hex.EncodeToString(schnorr.SerializePubKey(pubkey))] = nonce
}

func (t *treeCoordinatorSession) AddSig(pubkey *btcec.PublicKey, sig TreePartialSigs) {
func (t *treeCoordinatorSession) AddSignatures(pubkey *btcec.PublicKey, sig TreePartialSigs) {
t.sigs[hex.EncodeToString(schnorr.SerializePubKey(pubkey))] = sig
}

Expand Down
310 changes: 184 additions & 126 deletions common/bitcointree/musig2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
"github.com/ark-network/ark/common/bitcointree"
"github.com/ark-network/ark/common/tree"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
)

Expand All @@ -24,10 +22,10 @@ const (

var (
vtxoTreeExpiry = common.RelativeLocktime{Type: common.LocktimeTypeBlock, Value: 144}
testTxid, _ = chainhash.NewHashFromStr("49f8664acc899be91902f8ade781b7eeb9cbe22bdd9efbc36e56195de21bcd12")
serverPrivKey, _ = secp256k1.GeneratePrivateKey()
rootInput, _ = wire.NewOutPointFromString("49f8664acc899be91902f8ade781b7eeb9cbe22bdd9efbc36e56195de21bcd12:0")
serverPrivKey, _ = btcec.NewPrivateKey()
sweepScript, _ = (&tree.CSVMultisigClosure{
MultisigClosure: tree.MultisigClosure{PubKeys: []*secp256k1.PublicKey{serverPrivKey.PubKey()}},
MultisigClosure: tree.MultisigClosure{PubKeys: []*btcec.PublicKey{serverPrivKey.PubKey()}},
Locktime: vtxoTreeExpiry,
}).Script()
sweepRoot = txscript.NewBaseTapLeaf(sweepScript).TapHash()
Expand All @@ -37,125 +35,214 @@ var (
func TestBuildAndSignVtxoTree(t *testing.T) {
t.Parallel()

for _, tc := range generateTestCases(t) {
t.Run(tc.name, func(t *testing.T) {
sharedOutputScript, sharedOutputAmount, err := bitcointree.CraftSharedOutput(
tc.receivers,
minRelayFee,
sweepRoot[:],
testVectors, err := makeTestVectors()
require.NoError(t, err)
require.NotEmpty(t, testVectors)

for _, v := range testVectors {
t.Run(v.name, func(t *testing.T) {
sharedOutScript, sharedOutAmount, err := bitcointree.CraftSharedOutput(
v.receivers, minRelayFee, sweepRoot[:],
)
require.NoError(t, err)
require.NotNil(t, sharedOutputScript)
require.NotNil(t, sharedOutScript)
require.NotZero(t, sharedOutAmount)

vtxoTree, err := bitcointree.BuildVtxoTree(
&wire.OutPoint{
Hash: *testTxid,
Index: 0,
},
tc.receivers,
minRelayFee,
sweepRoot[:],
vtxoTreeExpiry,
rootInput, v.receivers, minRelayFee, sweepRoot[:], vtxoTreeExpiry,
)
require.NoError(t, err)
require.NotNil(t, vtxoTree)

serverCoordinator, err := bitcointree.NewTreeCoordinatorSession(
sharedOutputAmount,
vtxoTree,
sweepRoot[:],
coordinator, err := bitcointree.NewTreeCoordinatorSession(
sharedOutAmount, vtxoTree, sweepRoot[:],
)
require.NoError(t, err)
require.NotNil(t, coordinator)

// Cceate signer sessions for each receivers
signerSessions := make(map[*btcec.PublicKey]bitcointree.SignerSession)
for _, prvkey := range tc.privKeys {
session := bitcointree.NewTreeSignerSession(prvkey)
err := session.Init(sweepRoot[:], sharedOutputAmount, vtxoTree)
require.NoError(t, err)
signerSessions[prvkey.PubKey()] = session
}
signers, err := makeCosigners(v.privKeys, sharedOutAmount, vtxoTree)
require.NoError(t, err)
require.NotNil(t, signers)

// Create server's signer session
serverSession := bitcointree.NewTreeSignerSession(serverPrivKey)
err = serverSession.Init(sweepRoot[:], sharedOutputAmount, vtxoTree)
err = makeAggregatedNonces(signers, coordinator, checkNoncesRoundtrip(t))
require.NoError(t, err)
signerSessions[serverPrivKey.PubKey()] = serverSession

// generate nonces from all signers
for pubkey, session := range signerSessions {
nonces, err := session.GetNonces()
require.NoError(t, err)
var encodedNonces bytes.Buffer
err = nonces.Encode(&encodedNonces)
require.NoError(t, err)
decodedNonces, err := bitcointree.DecodeNonces(&encodedNonces)
require.NoError(t, err)
for i, nonceRow := range nonces {
for j, nonce := range nonceRow {
require.Equal(t, nonce, decodedNonces[i][j])
}
}

serverCoordinator.AddNonce(pubkey, nonces)
}
signedTree, err := makeAggregatedSignatures(signers, coordinator, checkSigsRoundtrip(t))
require.NoError(t, err)
require.NotNil(t, signedTree)

aggregatedNonce, err := serverCoordinator.AggregateNonces()
// validate signatures
err = bitcointree.ValidateTreeSigs(sweepRoot[:], sharedOutAmount, signedTree)
require.NoError(t, err)
})
}
}

func checkNoncesRoundtrip(t *testing.T) func(nonces bitcointree.TreeNonces) {
return func(nonces bitcointree.TreeNonces) {
var encodedNonces bytes.Buffer
err := nonces.Encode(&encodedNonces)
require.NoError(t, err)

// set the aggregated nonces for all signers sessions
for _, session := range signerSessions {
session.SetAggregatedNonces(aggregatedNonce)
decodedNonces, err := bitcointree.DecodeNonces(&encodedNonces)
require.NoError(t, err)
for i, nonceRow := range nonces {
for j, nonce := range nonceRow {
require.Equal(t, nonce, decodedNonces[i][j])
}
}
}
}

// get signatures from all signers sessions
for pubkey, session := range signerSessions {
sig, err := session.Sign()
require.NoError(t, err)
require.NotNil(t, sig)
var encodedSig bytes.Buffer
err = sig.Encode(&encodedSig)
require.NoError(t, err)
decodedSig, err := bitcointree.DecodeSignatures(&encodedSig)
require.NoError(t, err)
for i, sigRow := range sig {
for j, sig := range sigRow {
if sig == nil {
require.Nil(t, decodedSig[i][j])
} else {
require.Equal(t, sig.S, decodedSig[i][j].S)
}
}
func checkSigsRoundtrip(t *testing.T) func(sigs bitcointree.TreePartialSigs) {
return func(sigs bitcointree.TreePartialSigs) {
var encodedSig bytes.Buffer
err := sigs.Encode(&encodedSig)
require.NoError(t, err)
decodedSig, err := bitcointree.DecodeSignatures(&encodedSig)
require.NoError(t, err)
for i, sigRow := range sigs {
for j, sig := range sigRow {
if sig == nil {
require.Nil(t, decodedSig[i][j])
} else {
require.Equal(t, sig.S, decodedSig[i][j].S)
}

serverCoordinator.AddSig(pubkey, sig)
}
}
}
}

// aggregate signatures
signedTree, err := serverCoordinator.SignTree()
require.NoError(t, err)
require.NotNil(t, signedTree)
// validate signatures
err = bitcointree.ValidateTreeSigs(
sweepRoot[:],
sharedOutputAmount,
signedTree,
)
require.NoError(t, err)
})
func makeCosigners(
keys []*btcec.PrivateKey, sharedOutAmount int64, vtxoTree tree.VtxoTree,
) (map[string]bitcointree.SignerSession, error) {
signers := make(map[string]bitcointree.SignerSession)
for _, prvkey := range keys {
session := bitcointree.NewTreeSignerSession(prvkey)
if err := session.Init(sweepRoot[:], sharedOutAmount, vtxoTree); err != nil {
return nil, err
}
signers[keyToStr(prvkey)] = session
}

// create signer session for the server itself
serverSession := bitcointree.NewTreeSignerSession(serverPrivKey)
if err := serverSession.Init(sweepRoot[:], sharedOutAmount, vtxoTree); err != nil {
return nil, err
}
signers[keyToStr(serverPrivKey)] = serverSession
return signers, nil
}

func makeAggregatedNonces(
signers map[string]bitcointree.SignerSession, coordinator bitcointree.CoordinatorSession,
checkNoncesRoundtrip func(bitcointree.TreeNonces),
) error {
for pk, session := range signers {
buf, err := hex.DecodeString(pk)
if err != nil {
return err
}
pubkey, err := btcec.ParsePubKey(buf)
if err != nil {
return err
}

nonces, err := session.GetNonces()
if err != nil {
return err
}
checkNoncesRoundtrip(nonces)

coordinator.AddNonce(pubkey, nonces)
}

aggregatedNonce, err := coordinator.AggregateNonces()
if err != nil {
return err
}

// set the aggregated nonces for all signers sessions
for _, session := range signers {
session.SetAggregatedNonces(aggregatedNonce)
}
return nil
}

func makeAggregatedSignatures(
signers map[string]bitcointree.SignerSession, coordinator bitcointree.CoordinatorSession,
checkSigsRoundtrip func(bitcointree.TreePartialSigs),
) (tree.VtxoTree, error) {
for pk, session := range signers {
buf, err := hex.DecodeString(pk)
if err != nil {
return nil, err
}
pubkey, err := btcec.ParsePubKey(buf)
if err != nil {
return nil, err
}

sigs, err := session.Sign()
if err != nil {
return nil, err
}
checkSigsRoundtrip(sigs)

coordinator.AddSignatures(pubkey, sigs)
}

// aggregate signatures
return coordinator.SignTree()
}

type testCase struct {
name string
receivers []tree.VtxoLeaf
privKeys []*secp256k1.PrivateKey
privKeys []*btcec.PrivateKey
}

func makeTestVectors() ([]testCase, error) {
vectors := make([]testCase, 0, len(receiverCounts))
for _, count := range receiverCounts {
receivers, privKeys, err := generateMockedReceivers(count)
if err != nil {
return nil, err
}

// add mixed types test case if count is between 2 and 32
if count > 1 && count < 32 {
vectors = append(vectors, testCase{
name: fmt.Sprintf("%d receivers Mixed Signing Types", len(receivers)),
receivers: withMixedSigningTypes(receivers),
privKeys: privKeys,
})
}

// add SignAll test case if count is less than 32
if count < 32 {
vectors = append(vectors, testCase{
name: fmt.Sprintf("%d receivers SignAll", len(receivers)),
receivers: withSigningType(tree.SignAll, receivers),
privKeys: privKeys,
})
}

// always add SignBranch test case
vectors = append(vectors, testCase{
name: fmt.Sprintf("%d receivers SignBranch", len(receivers)),
receivers: withSigningType(tree.SignBranch, receivers),
privKeys: privKeys,
})
}
return vectors, nil
}

func generateReceiversFixture(count int) ([]tree.VtxoLeaf, []*secp256k1.PrivateKey, error) {
receivers := make([]tree.VtxoLeaf, 0, count)
privKeys := make([]*secp256k1.PrivateKey, 0, count)
for i := 0; i < count; i++ {
prvkey, err := secp256k1.GeneratePrivateKey()
func generateMockedReceivers(num int) ([]tree.VtxoLeaf, []*btcec.PrivateKey, error) {
receivers := make([]tree.VtxoLeaf, 0, num)
privKeys := make([]*btcec.PrivateKey, 0, num)
for i := 0; i < num; i++ {
prvkey, err := btcec.NewPrivateKey()
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -196,35 +283,6 @@ func withMixedSigningTypes(receivers []tree.VtxoLeaf) []tree.VtxoLeaf {
return append(first, second...)
}

func generateTestCases(t *testing.T) []testCase {
testCases := make([]testCase, 0)
for _, count := range receiverCounts {
receivers, privKeys, err := generateReceiversFixture(count)
require.NoError(t, err)
// add mixed types test case if count is between 2 and 32
if count > 1 && count < 32 {
testCases = append(testCases, testCase{
name: fmt.Sprintf("%d receivers Mixed Signing Types", len(receivers)),
receivers: withMixedSigningTypes(receivers),
privKeys: privKeys,
})
}

// add SignAll test case if count is less than 32
if count < 32 {
testCases = append(testCases, testCase{
name: fmt.Sprintf("%d receivers SignAll", len(receivers)),
receivers: withSigningType(tree.SignAll, receivers),
privKeys: privKeys,
})
}

// always add SignBranch test case
testCases = append(testCases, testCase{
name: fmt.Sprintf("%d receivers SignBranch", len(receivers)),
receivers: withSigningType(tree.SignBranch, receivers),
privKeys: privKeys,
})
}
return testCases
func keyToStr(key *btcec.PrivateKey) string {
return hex.EncodeToString(key.PubKey().SerializeCompressed())
}
Loading
Loading