diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index faf951db14..e325fca56c 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -33,6 +33,8 @@ type httpclient interface { } var ( + errUnexpectedType = errors.New("unexpected type") + errUnexpectedParamLen = errors.New("unexpected params length") errCannotReadFromWebsocket = errors.New("cannot read message from websocket") errEmptyMethod = errors.New("empty method") ) @@ -163,25 +165,35 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (L wsconn: c, } - pA, ok := params.([]interface{}) - if !ok { - return nil, fmt.Errorf("unknown parameter type") - } - for _, param := range pA { - switch p := param.(type) { - case []interface{}: - for _, pp := range param.([]interface{}) { - data, ok := pp.(string) - if !ok { - return nil, fmt.Errorf("unknown parameter type") + // the following type checking/casting is needed in order to satisfy some + // websocket request field params eg.: + // "params": ["0x..."] or + // "params": [["0x...", "0x..."]] + switch filters := params.(type) { + case []interface{}: + for _, interfaceKey := range filters { + switch key := interfaceKey.(type) { + case string: + stgobs.filter[key] = []byte{} + case []string: + for _, k := range key { + stgobs.filter[k] = []byte{} + } + case []interface{}: + for _, k := range key { + k, ok := k.(string) + if !ok { + return nil, fmt.Errorf("%w: %T, expected type string", errUnexpectedType, k) + } + + stgobs.filter[k] = []byte{} } - stgobs.filter[data] = []byte{} + default: + return nil, fmt.Errorf("%w: %T, expected type string, []string, []interface{}", errUnexpectedType, interfaceKey) } - case string: - stgobs.filter[p] = []byte{} - default: - return nil, fmt.Errorf("unknown parameter type") } + default: + return nil, fmt.Errorf("%w: %T, expected type []interface{}", errUnexpectedType, params) } c.mu.Lock() @@ -269,14 +281,32 @@ func (c *WSConn) initAllBlocksListerner(reqID float64, _ interface{}) (Listener, } func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener, error) { - pA := params.([]interface{}) + var encodedExtrinsic string + + switch encodedHex := params.(type) { + case []string: + if len(encodedHex) != 1 { + return nil, fmt.Errorf("%w: expected 1 param, got: %d", errUnexpectedParamLen, len(encodedHex)) + } + encodedExtrinsic = encodedHex[0] + // the bellow case is needed to cover a interface{} slice containing one string + // as `[]interface{"a"}` is not the same as `[]string{"a"}` + case []interface{}: + if len(encodedHex) != 1 { + return nil, fmt.Errorf("%w: expected 1 param, got: %d", errUnexpectedParamLen, len(encodedHex)) + } - if len(pA) != 1 { - return nil, errors.New("expecting only one parameter") + var ok bool + encodedExtrinsic, ok = encodedHex[0].(string) + if !ok { + return nil, fmt.Errorf("%w: %T, expected type string", errUnexpectedType, encodedHex[0]) + } + default: + return nil, fmt.Errorf("%w: %T, expected type []string or []interface{}", errUnexpectedType, params) } // The passed parameter should be a HEX of a SCALE encoded extrinsic - extBytes, err := common.HexToBytes(pA[0].(string)) + extBytes, err := common.HexToBytes(encodedExtrinsic) if err != nil { return nil, err } diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 5ad9c6db4b..33dd91d612 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -45,7 +45,8 @@ func TestWSConn_HandleConn(t *testing.T) { res, err = wsconn.initStorageChangeListener(1, nil) require.Nil(t, res) require.Len(t, wsconn.Subscriptions, 0) - require.EqualError(t, err, "unknown parameter type") + require.ErrorIs(t, err, errUnexpectedType) + require.EqualError(t, err, "unexpected type: , expected type []interface{}") res, err = wsconn.initStorageChangeListener(2, []interface{}{}) require.NotNil(t, res) @@ -55,7 +56,8 @@ func TestWSConn_HandleConn(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":1,"id":2}`+"\n"), msg) - res, err = wsconn.initStorageChangeListener(3, []interface{}{"0x26aa"}) + var testFilter0 = []interface{}{"0x26aa"} + res, err = wsconn.initStorageChangeListener(3, testFilter0) require.NotNil(t, res) require.NoError(t, err) require.Len(t, wsconn.Subscriptions, 2) @@ -63,25 +65,26 @@ func TestWSConn_HandleConn(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":2,"id":3}`+"\n"), msg) - var testFilters = []interface{}{} - var testFilter1 = []interface{}{"0x26aa", "0x26a1"} - res, err = wsconn.initStorageChangeListener(4, append(testFilters, testFilter1)) - require.NotNil(t, res) + var testFilter1 = []interface{}{[]interface{}{"0x26aa", "0x26a1"}} + res, err = wsconn.initStorageChangeListener(4, testFilter1) require.NoError(t, err) + require.NotNil(t, res) require.Len(t, wsconn.Subscriptions, 3) _, msg, err = c.ReadMessage() require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","result":3,"id":4}`+"\n"), msg) - var testFilterWrongType = []interface{}{"0x26aa", 1} - res, err = wsconn.initStorageChangeListener(5, append(testFilters, testFilterWrongType)) - require.EqualError(t, err, "unknown parameter type") + var testFilterWrongType = []interface{}{[]int{123}} + res, err = wsconn.initStorageChangeListener(5, testFilterWrongType) + require.ErrorIs(t, err, errUnexpectedType) + require.EqualError(t, err, "unexpected type: []int, expected type string, []string, []interface{}") require.Nil(t, res) // keep subscriptions len == 3, no additions was made require.Len(t, wsconn.Subscriptions, 3) res, err = wsconn.initStorageChangeListener(6, []interface{}{1}) - require.EqualError(t, err, "unknown parameter type") + require.ErrorIs(t, err, errUnexpectedType) + require.EqualError(t, err, "unexpected type: int, expected type string, []string, []interface{}") require.Nil(t, res) require.Len(t, wsconn.Subscriptions, 3) @@ -207,7 +210,7 @@ func TestWSConn_HandleConn(t *testing.T) { wsconn.CoreAPI = modules.NewMockCoreAPI() wsconn.BlockAPI = nil wsconn.TxStateAPI = modules.NewMockTransactionStateAPI() - listner, err := wsconn.initExtrinsicWatch(0, []interface{}{"NotHex"}) + listner, err := wsconn.initExtrinsicWatch(0, []string{"NotHex"}) require.EqualError(t, err, "could not byteify non 0x prefixed string: NotHex") require.Nil(t, listner)