@@ -33,14 +33,27 @@ type Conn struct {
33
33
stream Stream
34
34
err error
35
35
pendingMu sync.Mutex // protects the pending map
36
- pending map [ID ]chan * Response
36
+ pending map [ID ]chan * wireResponse
37
37
handlingMu sync.Mutex // protects the handling map
38
38
handling map [ID ]handling
39
39
}
40
40
41
+ // Request is sent to a server to represent a Call or Notify operaton.
42
+ type Request struct {
43
+ conn * Conn
44
+
45
+ // Method is a string containing the method name to invoke.
46
+ Method string
47
+ // Params is either a struct or an array with the parameters of the method.
48
+ Params * json.RawMessage
49
+ // The id of this request, used to tie the response back to the request.
50
+ // Will be either a string or a number. If not set, the request is a notify,
51
+ // and no response is possible.
52
+ ID * ID
53
+ }
54
+
41
55
type queueEntry struct {
42
56
ctx context.Context
43
- c * Conn
44
57
r * Request
45
58
size int64
46
59
}
@@ -50,16 +63,14 @@ type queueEntry struct {
50
63
// call Reply on the Conn with the supplied request.
51
64
// Handlers are called synchronously, they should pass the work off to a go
52
65
// routine if they are going to take a long time.
53
- type Handler func (context.Context , * Conn , * Request )
66
+ type Handler func (context.Context , * Request )
54
67
55
68
// Canceler is an option you can pass to NewConn which is invoked for
56
69
// cancelled outgoing requests.
57
- // The request will have the ID filled in, which can be used to propagate the
58
- // cancel to the other process if needed.
59
70
// It is okay to use the connection to send notifications, but the context will
60
71
// be in the cancelled state, so you must do it with the background context
61
72
// instead.
62
- type Canceler func (context.Context , * Conn , * Request )
73
+ type Canceler func (context.Context , * Conn , ID )
63
74
64
75
type rpcStats struct {
65
76
server bool
@@ -133,17 +144,17 @@ func NewErrorf(code int64, format string, args ...interface{}) *Error {
133
144
func NewConn (s Stream ) * Conn {
134
145
conn := & Conn {
135
146
stream : s ,
136
- pending : make (map [ID ]chan * Response ),
147
+ pending : make (map [ID ]chan * wireResponse ),
137
148
handling : make (map [ID ]handling ),
138
149
}
139
150
// the default handler reports a method error
140
- conn .Handler = func (ctx context.Context , c * Conn , r * Request ) {
151
+ conn .Handler = func (ctx context.Context , r * Request ) {
141
152
if r .IsNotify () {
142
- c .Reply (ctx , r , nil , NewErrorf (CodeMethodNotFound , "method %q not found" , r .Method ))
153
+ r .Reply (ctx , nil , NewErrorf (CodeMethodNotFound , "method %q not found" , r .Method ))
143
154
}
144
155
}
145
- // the default canceller does nothing
146
- conn .Canceler = func (context.Context , * Conn , * Request ) {}
156
+ // the default canceler does nothing
157
+ conn .Canceler = func (context.Context , * Conn , ID ) {}
147
158
// the default logger does nothing
148
159
conn .Logger = func (Direction , * ID , time.Duration , string , * json.RawMessage , * Error ) {}
149
160
return conn
@@ -174,7 +185,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e
174
185
if err != nil {
175
186
return fmt .Errorf ("marshalling notify parameters: %v" , err )
176
187
}
177
- request := & Request {
188
+ request := & wireRequest {
178
189
Method : method ,
179
190
Params : jsonParams ,
180
191
}
@@ -200,7 +211,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
200
211
if err != nil {
201
212
return fmt .Errorf ("marshalling call parameters: %v" , err )
202
213
}
203
- request := & Request {
214
+ request := & wireRequest {
204
215
ID : & id ,
205
216
Method : method ,
206
217
Params : jsonParams ,
@@ -212,7 +223,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
212
223
}
213
224
// we have to add ourselves to the pending map before we send, otherwise we
214
225
// are racing the response
215
- rchan := make (chan * Response )
226
+ rchan := make (chan * wireResponse )
216
227
c .pendingMu .Lock ()
217
228
c .pending [id ] = rchan
218
229
c .pendingMu .Unlock ()
@@ -249,40 +260,48 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
249
260
return nil
250
261
case <- ctx .Done ():
251
262
// allow the handler to propagate the cancel
252
- c .Canceler (ctx , c , request )
263
+ c .Canceler (ctx , c , id )
253
264
return ctx .Err ()
254
265
}
255
266
}
256
267
268
+ // Conn returns the connection that created this request.
269
+ func (r * Request ) Conn () * Conn { return r .conn }
270
+
271
+ // IsNotify returns true if this request is a notification.
272
+ func (r * Request ) IsNotify () bool {
273
+ return r .ID == nil
274
+ }
275
+
257
276
// Reply sends a reply to the given request.
258
277
// It is an error to call this if request was not a call.
259
278
// You must call this exactly once for any given request.
260
279
// If err is set then result will be ignored.
261
- func (c * Conn ) Reply (ctx context.Context , req * Request , result interface {}, err error ) error {
262
- ctx , st := trace .StartSpan (ctx , req .Method + ":reply" , trace .WithSpanKind (trace .SpanKindClient ))
280
+ func (r * Request ) Reply (ctx context.Context , result interface {}, err error ) error {
281
+ ctx , st := trace .StartSpan (ctx , r .Method + ":reply" , trace .WithSpanKind (trace .SpanKindClient ))
263
282
defer st .End ()
264
283
265
- if req .IsNotify () {
284
+ if r .IsNotify () {
266
285
return fmt .Errorf ("reply not invoked with a valid call" )
267
286
}
268
- c .handlingMu .Lock ()
269
- handling , found := c . handling [* req .ID ]
287
+ r . conn .handlingMu .Lock ()
288
+ handling , found := r . conn . handling [* r .ID ]
270
289
if found {
271
- delete (c . handling , * req .ID )
290
+ delete (r . conn . handling , * r .ID )
272
291
}
273
- c .handlingMu .Unlock ()
292
+ r . conn .handlingMu .Unlock ()
274
293
if ! found {
275
- return fmt .Errorf ("not a call in progress: %v" , req .ID )
294
+ return fmt .Errorf ("not a call in progress: %v" , r .ID )
276
295
}
277
296
278
297
elapsed := time .Since (handling .start )
279
298
var raw * json.RawMessage
280
299
if err == nil {
281
300
raw , err = marshalToRaw (result )
282
301
}
283
- response := & Response {
302
+ response := & wireResponse {
284
303
Result : raw ,
285
- ID : req .ID ,
304
+ ID : r .ID ,
286
305
}
287
306
if err != nil {
288
307
if callErr , ok := err .(* Error ); ok {
@@ -295,8 +314,8 @@ func (c *Conn) Reply(ctx context.Context, req *Request, result interface{}, err
295
314
if err != nil {
296
315
return err
297
316
}
298
- c . Logger (Send , response .ID , elapsed , req .Method , response .Result , response .Error )
299
- n , err := c .stream .Write (ctx , data )
317
+ r . conn . Logger (Send , response .ID , elapsed , r .Method , response .Result , response .Error )
318
+ n , err := r . conn .stream .Write (ctx , data )
300
319
301
320
v := ctx .Value (rpcStatsKey )
302
321
if v != nil {
@@ -332,7 +351,7 @@ type combined struct {
332
351
}
333
352
334
353
func (c * Conn ) deliver (ctx context.Context , q chan queueEntry , request * Request , size int64 ) bool {
335
- e := queueEntry {ctx : ctx , c : c , r : request , size : size }
354
+ e := queueEntry {ctx : ctx , r : request , size : size }
336
355
if ! c .RejectIfOverloaded {
337
356
q <- e
338
357
return true
@@ -361,7 +380,7 @@ func (c *Conn) Run(ctx context.Context) error {
361
380
}
362
381
ctx , rpcStats := start (ctx , true , e .r .Method , e .r .ID )
363
382
rpcStats .received += e .size
364
- c .Handler (ctx , e .c , e . r )
383
+ c .Handler (ctx , e .r )
365
384
rpcStats .end (ctx , nil )
366
385
}
367
386
}()
@@ -385,6 +404,7 @@ func (c *Conn) Run(ctx context.Context) error {
385
404
case msg .Method != "" :
386
405
// if method is set it must be a request
387
406
request := & Request {
407
+ conn : c ,
388
408
Method : msg .Method ,
389
409
Params : msg .Params ,
390
410
ID : msg .ID ,
@@ -407,7 +427,7 @@ func (c *Conn) Run(ctx context.Context) error {
407
427
c .Logger (Receive , request .ID , - 1 , request .Method , request .Params , nil )
408
428
if ! c .deliver (reqCtx , q , request , n ) {
409
429
// queue is full, reject the message by directly replying
410
- c .Reply (ctx , request , nil , NewErrorf (CodeServerOverloaded , "no room in queue" ))
430
+ request .Reply (ctx , nil , NewErrorf (CodeServerOverloaded , "no room in queue" ))
411
431
}
412
432
}
413
433
case msg .ID != nil :
@@ -419,7 +439,7 @@ func (c *Conn) Run(ctx context.Context) error {
419
439
}
420
440
c .pendingMu .Unlock ()
421
441
// and send the reply to the channel
422
- response := & Response {
442
+ response := & wireResponse {
423
443
Result : msg .Result ,
424
444
Error : msg .Error ,
425
445
ID : msg .ID ,
0 commit comments