Skip to content

Commit 82745ae

Browse files
Separate Handshake and Notification message interfaces and other tiny changes (#2696)
* Added checks to avoid decoding handshakes as messages * Use roles to check if a message is a handshake or not * break createNotificationsMessageHandle into smaller functions * fixing IsHandshake for block announce * Separated Handshake and NotificationMessage interfaces * remove IsValidHandshake from NotificationsMessageInterface * renamed IsHandshake to IsValid * removed Type() from Handshake
1 parent 8aa043b commit 82745ae

14 files changed

+112
-179
lines changed

dot/network/block_announce.go

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ import (
1515
"github.com/libp2p/go-libp2p-core/peer"
1616
)
1717

18-
var errInvalidRole = errors.New("invalid role")
1918
var (
2019
_ NotificationsMessage = &BlockAnnounceMessage{}
21-
_ NotificationsMessage = &BlockAnnounceHandshake{}
20+
_ Handshake = (*BlockAnnounceHandshake)(nil)
2221
)
2322

2423
// BlockAnnounceMessage is a state block header
@@ -31,9 +30,9 @@ type BlockAnnounceMessage struct {
3130
BestBlock bool
3231
}
3332

34-
// Type returns BlockAnnounceMsgType
33+
// Type returns blockAnnounceMsgType
3534
func (*BlockAnnounceMessage) Type() byte {
36-
return BlockAnnounceMsgType
35+
return blockAnnounceMsgType
3736
}
3837

3938
// string formats a BlockAnnounceMessage as a string
@@ -75,11 +74,6 @@ func (bm *BlockAnnounceMessage) Hash() (common.Hash, error) {
7574
return common.Blake2bHash(encMsg)
7675
}
7776

78-
// IsHandshake returns false
79-
func (*BlockAnnounceMessage) IsHandshake() bool {
80-
return false
81-
}
82-
8377
func decodeBlockAnnounceHandshake(in []byte) (Handshake, error) {
8478
hs := BlockAnnounceHandshake{}
8579
err := scale.Unmarshal(in, &hs)
@@ -133,25 +127,14 @@ func (hs *BlockAnnounceHandshake) Decode(in []byte) error {
133127
return nil
134128
}
135129

136-
// Type ...
137-
func (*BlockAnnounceHandshake) Type() byte {
138-
return 0
139-
}
140-
141-
// Hash returns blake2b hash of block announce handshake.
142-
func (hs *BlockAnnounceHandshake) Hash() (common.Hash, error) {
143-
// scale encode each extrinsic
144-
encMsg, err := hs.Encode()
145-
if err != nil {
146-
return common.Hash{}, fmt.Errorf("cannot encode handshake: %w", err)
130+
// IsValid returns true if handshakes's role is valid.
131+
func (hs *BlockAnnounceHandshake) IsValid() bool {
132+
switch hs.Roles {
133+
case common.AuthorityRole, common.FullNodeRole, common.LightClientRole:
134+
return true
135+
default:
136+
return false
147137
}
148-
149-
return common.Blake2bHash(encMsg)
150-
}
151-
152-
// IsHandshake returns true
153-
func (*BlockAnnounceHandshake) IsHandshake() bool {
154-
return true
155138
}
156139

157140
func (s *Service) getBlockAnnounceHandshake() (Handshake, error) {
@@ -188,7 +171,7 @@ func (s *Service) validateBlockAnnounceHandshake(from peer.ID, hs Handshake) err
188171
return errors.New("genesis hash mismatch")
189172
}
190173

191-
np, ok := s.notificationsProtocols[BlockAnnounceMsgType]
174+
np, ok := s.notificationsProtocols[blockAnnounceMsgType]
192175
if !ok {
193176
// this should never happen.
194177
return nil

dot/network/block_announce_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,11 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
158158

159159
nodeA := createTestService(t, configA)
160160
nodeA.noGossip = true
161-
nodeA.notificationsProtocols[BlockAnnounceMsgType] = &notificationsProtocol{
161+
nodeA.notificationsProtocols[blockAnnounceMsgType] = &notificationsProtocol{
162162
peersData: newPeersData(),
163163
}
164164
testPeerID := peer.ID("noot")
165-
nodeA.notificationsProtocols[BlockAnnounceMsgType].peersData.setInboundHandshakeData(testPeerID, &handshakeData{})
165+
nodeA.notificationsProtocols[blockAnnounceMsgType].peersData.setInboundHandshakeData(testPeerID, &handshakeData{})
166166

167167
err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
168168
Roles: common.FullNodeRole,

dot/network/errors.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ import (
1010
var (
1111
errCannotValidateHandshake = errors.New("failed to validate handshake")
1212
errMessageTypeNotValid = errors.New("message type is not valid")
13-
errMessageIsNotHandshake = errors.New("failed to convert message to Handshake")
1413
errInvalidHandshakeForPeer = errors.New("peer previously sent invalid handshake")
1514
errHandshakeTimeout = errors.New("handshake timeout reached")
1615
errBlockRequestFromNumberInvalid = errors.New("block request message From number is not valid")
1716
errInvalidStartingBlockType = errors.New("invalid StartingBlock in messsage")
17+
errInboundHanshakeExists = errors.New("an inbound handshake already exists for given peer")
18+
errInvalidRole = errors.New("invalid role")
1819
)

dot/network/gossip.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ func (g *gossip) hasSeen(msg NotificationsMessage) (bool, error) {
4040
_, ok := g.seenMap[msgHash]
4141
if !ok {
4242
// set message to has been seen
43-
if !msg.IsHandshake() {
44-
g.seenMap[msgHash] = struct{}{}
45-
}
43+
g.seenMap[msgHash] = struct{}{}
4644
return false, nil
4745
}
4846

dot/network/host_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
344344
_, err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
345345
require.NoError(t, err)
346346

347-
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]
347+
info := nodeA.notificationsProtocols[blockAnnounceMsgType]
348348

349349
// Set handshake data to received
350350
info.peersData.setInboundHandshakeData(nodeB.host.id(), &handshakeData{

dot/network/message.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ import (
1919

2020
// Message types for notifications protocol messages. Used internally to map message to protocol.
2121
const (
22-
BlockAnnounceMsgType byte = 3
23-
TransactionMsgType byte = 4
22+
blockAnnounceMsgType byte = 3
23+
transactionMsgType byte = 4
2424
ConsensusMsgType byte = 5
2525
)
2626

@@ -36,7 +36,6 @@ type NotificationsMessage interface {
3636
Message
3737
Type() byte
3838
Hash() (common.Hash, error)
39-
IsHandshake() bool
4039
}
4140

4241
//nolint:revive
@@ -381,8 +380,3 @@ func (cm *ConsensusMessage) Hash() (common.Hash, error) {
381380
}
382381
return common.Blake2bHash(encMsg)
383382
}
384-
385-
// IsHandshake returns false
386-
func (cm *ConsensusMessage) IsHandshake() bool {
387-
return false
388-
}

dot/network/message_cache.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package network
55

66
import (
7-
"errors"
87
"fmt"
98
"time"
109

@@ -65,10 +64,6 @@ func (m *messageCache) exists(peer peer.ID, msg NotificationsMessage) bool {
6564
}
6665

6766
func generateCacheKey(peer peer.ID, msg NotificationsMessage) ([]byte, error) {
68-
if msg.IsHandshake() {
69-
return nil, errors.New("cache does not support handshake messages")
70-
}
71-
7267
msgHash, err := msg.Hash()
7368
if err != nil {
7469
return nil, fmt.Errorf("cannot hash notification message: %w", err)

dot/network/message_cache_test.go

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,30 +51,3 @@ func TestMessageCache(t *testing.T) {
5151
ok = msgCache.exists(peerID, msg)
5252
require.False(t, ok)
5353
}
54-
55-
func TestMessageCacheError(t *testing.T) {
56-
t.Parallel()
57-
58-
cacheSize := 64 << 20 // 64 MB
59-
msgCache, err := newMessageCache(ristretto.Config{
60-
NumCounters: int64(float64(cacheSize) * 0.05 * 2),
61-
MaxCost: int64(float64(cacheSize) * 0.95),
62-
BufferItems: 64,
63-
Cost: func(value interface{}) int64 {
64-
return int64(1)
65-
},
66-
}, 800*time.Millisecond)
67-
require.NoError(t, err)
68-
69-
peerID := peer.ID("gossamer")
70-
msg := &BlockAnnounceHandshake{
71-
Roles: 4,
72-
BestBlockNumber: 77,
73-
BestBlockHash: common.Hash{1},
74-
GenesisHash: common.Hash{2},
75-
}
76-
77-
ok, err := msgCache.put(peerID, msg)
78-
require.Error(t, err, "cache does not support handshake messages")
79-
require.False(t, ok)
80-
}

dot/network/notifications.go

Lines changed: 65 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ const handshakeTimeout = time.Second * 10
2020

2121
// Handshake is the interface all handshakes for notifications protocols must implement
2222
type Handshake interface {
23-
NotificationsMessage
23+
Message
24+
IsValid() bool
2425
}
2526

2627
// the following are used for RegisterNotificationsProtocol
@@ -91,6 +92,9 @@ func newHandshakeData(received, validated bool, stream network.Stream) *handshak
9192
}
9293
}
9394

95+
// createDecoder combines the notification message decoder and the handshake decoder. The combined
96+
// decoder decodes using the handshake decoder if we already have handshake data stored for a given
97+
// peer, otherwise it decodes using the notification message decoder.
9498
func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecoder,
9599
messageDecoder MessageDecoder) messageDecoder {
96100
return func(in []byte, peer peer.ID, inbound bool) (Message, error) {
@@ -130,6 +134,18 @@ func (s *Service) createNotificationsMessageHandler(
130134
peer = stream.Conn().RemotePeer()
131135
)
132136

137+
hs, ok := m.(Handshake)
138+
if ok {
139+
if !hs.IsValid() {
140+
return errInvalidRole
141+
}
142+
err := s.handleHandshake(info, stream, hs, peer)
143+
if err != nil {
144+
return fmt.Errorf("handling handshake: %w", err)
145+
}
146+
return nil
147+
}
148+
133149
if msg, ok = m.(NotificationsMessage); !ok {
134150
return fmt.Errorf("%w: expected %T but got %T", errMessageTypeNotValid, (NotificationsMessage)(nil), msg)
135151
}
@@ -148,62 +164,6 @@ func (s *Service) createNotificationsMessageHandler(
148164
return nil
149165
}
150166

151-
if msg.IsHandshake() {
152-
logger.Tracef("received handshake on notifications sub-protocol %s from peer %s, message is: %s",
153-
info.protocolID, stream.Conn().RemotePeer(), msg)
154-
155-
hs, ok := msg.(Handshake)
156-
if !ok {
157-
// NOTE: As long as, Handshake interface and NotificationMessage interfaces are same,
158-
// this error would never happen.
159-
return errMessageIsNotHandshake
160-
}
161-
162-
// if we are the receiver and haven't received the handshake already, validate it
163-
// note: if this function is being called, it's being called via SetStreamHandler,
164-
// ie it is an inbound stream and we only send the handshake over it.
165-
// we do not send any other data over this stream, we would need to open a new outbound stream.
166-
hsData := info.peersData.getInboundHandshakeData(peer)
167-
if hsData == nil {
168-
logger.Tracef("receiver: validating handshake using protocol %s", info.protocolID)
169-
170-
hsData = newHandshakeData(true, false, stream)
171-
info.peersData.setInboundHandshakeData(peer, hsData)
172-
173-
err := info.handshakeValidator(peer, hs)
174-
if err != nil {
175-
logger.Tracef(
176-
"failed to validate handshake from peer %s using protocol %s: %s",
177-
peer, info.protocolID, err)
178-
return errCannotValidateHandshake
179-
}
180-
181-
hsData.validated = true
182-
info.peersData.setInboundHandshakeData(peer, hsData)
183-
184-
// once validated, send back a handshake
185-
resp, err := info.getHandshake()
186-
if err != nil {
187-
logger.Warnf("failed to get handshake using protocol %s: %s", info.protocolID, err)
188-
return err
189-
}
190-
191-
err = s.host.writeToStream(stream, resp)
192-
if err != nil {
193-
logger.Tracef("failed to send handshake to peer %s using protocol %s: %s", peer, info.protocolID, err)
194-
return err
195-
}
196-
197-
logger.Tracef("receiver: sent handshake to peer %s using protocol %s", peer, info.protocolID)
198-
199-
if err := stream.CloseWrite(); err != nil {
200-
logger.Tracef("failed to close stream for writing: %s", err)
201-
}
202-
}
203-
204-
return nil
205-
}
206-
207167
logger.Tracef("received message on notifications sub-protocol %s from peer %s, message is: %s",
208168
info.protocolID, stream.Conn().RemotePeer(), msg)
209169

@@ -226,6 +186,54 @@ func (s *Service) createNotificationsMessageHandler(
226186
}
227187
}
228188

189+
func (s *Service) handleHandshake(info *notificationsProtocol, stream network.Stream,
190+
hs Handshake, peer peer.ID) error {
191+
logger.Tracef("received handshake on notifications sub-protocol %s from peer %s, message is: %s",
192+
info.protocolID, stream.Conn().RemotePeer(), hs)
193+
194+
// if we are the receiver and haven't received the handshake already, validate it
195+
// note: if this function is being called, it's being called via SetStreamHandler,
196+
// ie it is an inbound stream and we only send the handshake over it.
197+
// we do not send any other data over this stream, we would need to open a new outbound stream.
198+
hsData := info.peersData.getInboundHandshakeData(peer)
199+
if hsData != nil {
200+
return fmt.Errorf("%w: for peer id %s", errInboundHanshakeExists, peer)
201+
}
202+
203+
logger.Tracef("receiver: validating handshake using protocol %s", info.protocolID)
204+
205+
hsData = newHandshakeData(true, false, stream)
206+
info.peersData.setInboundHandshakeData(peer, hsData)
207+
208+
err := info.handshakeValidator(peer, hs)
209+
if err != nil {
210+
return fmt.Errorf("%w from peer %s using protocol %s: %s",
211+
errCannotValidateHandshake, peer, info.protocolID, err)
212+
}
213+
214+
hsData.validated = true
215+
info.peersData.setInboundHandshakeData(peer, hsData)
216+
217+
// once validated, send back a handshake
218+
resp, err := info.getHandshake()
219+
if err != nil {
220+
return fmt.Errorf("failed to get handshake using protocol %s: %s", info.protocolID, err)
221+
}
222+
223+
err = s.host.writeToStream(stream, resp)
224+
if err != nil {
225+
return fmt.Errorf("failed to send handshake to peer %s using protocol %s: %w", peer, info.protocolID, err)
226+
}
227+
228+
logger.Tracef("receiver: sent handshake to peer %s using protocol %s", peer, info.protocolID)
229+
230+
if err := stream.CloseWrite(); err != nil {
231+
return fmt.Errorf("failed to close stream for writing: %s", err)
232+
}
233+
234+
return nil
235+
}
236+
229237
func closeOutboundStream(info *notificationsProtocol, peerID peer.ID, stream network.Stream) {
230238
logger.Debugf(
231239
"cleaning up outbound handshake data for protocol=%s, peer=%s",

dot/network/notifications_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
186186

187187
// try invalid handshake
188188
testHandshake := &BlockAnnounceHandshake{
189-
Roles: 4,
189+
Roles: common.AuthorityRole,
190190
BestBlockNumber: 77,
191191
BestBlockHash: common.Hash{1},
192192
// we are using a different genesis here, thus this
@@ -195,7 +195,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
195195
}
196196

197197
err = handler(stream, testHandshake)
198-
require.Equal(t, errCannotValidateHandshake, err)
198+
require.ErrorIs(t, err, errCannotValidateHandshake)
199199
data := info.peersData.getInboundHandshakeData(testPeerID)
200200
require.NotNil(t, data)
201201
require.True(t, data.received)

0 commit comments

Comments
 (0)