@@ -35,12 +35,26 @@ type Conn struct {
35
35
pendingMu sync.Mutex // protects the pending map
36
36
pending map [ID ]chan * wireResponse
37
37
handlingMu sync.Mutex // protects the handling map
38
- handling map [ID ]handling
38
+ handling map [ID ]* Request
39
39
}
40
40
41
+ type requestState int
42
+
43
+ const (
44
+ requestWaiting = requestState (iota )
45
+ requestSerial
46
+ requestParallel
47
+ requestReplied
48
+ requestDone
49
+ )
50
+
41
51
// Request is sent to a server to represent a Call or Notify operaton.
42
52
type Request struct {
43
- conn * Conn
53
+ conn * Conn
54
+ cancel context.CancelFunc
55
+ start time.Time
56
+ state requestState
57
+ nextRequest chan struct {}
44
58
45
59
// Method is a string containing the method name to invoke.
46
60
Method string
@@ -52,12 +66,6 @@ type Request struct {
52
66
ID * ID
53
67
}
54
68
55
- type queueEntry struct {
56
- ctx context.Context
57
- r * Request
58
- size int64
59
- }
60
-
61
69
// Handler is an option you can pass to NewConn to handle incoming requests.
62
70
// If the request returns false from IsNotify then the Handler must eventually
63
71
// call Reply on the Conn with the supplied request.
@@ -75,7 +83,6 @@ type Canceler func(context.Context, *Conn, ID)
75
83
type rpcStats struct {
76
84
server bool
77
85
method string
78
- ctx context.Context
79
86
span trace.Span
80
87
start time.Time
81
88
received int64
@@ -87,13 +94,15 @@ type statsKeyType string
87
94
const rpcStatsKey = statsKeyType ("rpcStatsKey" )
88
95
89
96
func start (ctx context.Context , server bool , method string , id * ID ) (context.Context , * rpcStats ) {
97
+ if method == "" {
98
+ panic ("no method in rpc stats" )
99
+ }
90
100
s := & rpcStats {
91
101
server : server ,
92
102
method : method ,
93
- ctx : ctx ,
94
103
start : time .Now (),
95
104
}
96
- s . ctx = context .WithValue (s . ctx , rpcStatsKey , s )
105
+ ctx = context .WithValue (ctx , rpcStatsKey , s )
97
106
tags := make ([]tag.Mutator , 0 , 4 )
98
107
tags = append (tags , tag .Upsert (telemetry .KeyMethod , method ))
99
108
mode := telemetry .Outbound
@@ -106,10 +115,10 @@ func start(ctx context.Context, server bool, method string, id *ID) (context.Con
106
115
if id != nil {
107
116
tags = append (tags , tag .Upsert (telemetry .KeyRPCID , id .String ()))
108
117
}
109
- s . ctx , s .span = trace .StartSpan (ctx , method , trace .WithSpanKind (spanKind ))
110
- s . ctx , _ = tag .New (s . ctx , tags ... )
111
- stats .Record (s . ctx , telemetry .Started .M (1 ))
112
- return s . ctx , s
118
+ ctx , s .span = trace .StartSpan (ctx , method , trace .WithSpanKind (spanKind ))
119
+ ctx , _ = tag .New (ctx , tags ... )
120
+ stats .Record (ctx , telemetry .Started .M (1 ))
121
+ return ctx , s
113
122
}
114
123
115
124
func (s * rpcStats ) end (ctx context.Context , err * error ) {
@@ -145,11 +154,11 @@ func NewConn(s Stream) *Conn {
145
154
conn := & Conn {
146
155
stream : s ,
147
156
pending : make (map [ID ]chan * wireResponse ),
148
- handling : make (map [ID ]handling ),
157
+ handling : make (map [ID ]* Request ),
149
158
}
150
159
// the default handler reports a method error
151
160
conn .Handler = func (ctx context.Context , r * Request ) {
152
- if r .IsNotify () {
161
+ if ! r .IsNotify () {
153
162
r .Reply (ctx , nil , NewErrorf (CodeMethodNotFound , "method %q not found" , r .Method ))
154
163
}
155
164
}
@@ -273,28 +282,38 @@ func (r *Request) IsNotify() bool {
273
282
return r .ID == nil
274
283
}
275
284
285
+ // Parallel indicates that the system is now allowed to process other requests
286
+ // in parallel with this one.
287
+ // It is safe to call any number of times, but must only be called from the
288
+ // request handling go routine.
289
+ // It is implied by both reply and by the handler returning.
290
+ func (r * Request ) Parallel () {
291
+ if r .state >= requestParallel {
292
+ return
293
+ }
294
+ r .state = requestParallel
295
+ close (r .nextRequest )
296
+ }
297
+
276
298
// Reply sends a reply to the given request.
277
299
// It is an error to call this if request was not a call.
278
300
// You must call this exactly once for any given request.
301
+ // It should only be called from the handler go routine.
279
302
// If err is set then result will be ignored.
280
303
func (r * Request ) Reply (ctx context.Context , result interface {}, err error ) error {
281
- ctx , st := trace . StartSpan ( ctx , r . Method + ":reply" , trace . WithSpanKind ( trace . SpanKindClient ))
282
- defer st . End ( )
283
-
304
+ if r . state >= requestReplied {
305
+ return fmt . Errorf ( "reply invoked more than once" )
306
+ }
284
307
if r .IsNotify () {
285
308
return fmt .Errorf ("reply not invoked with a valid call" )
286
309
}
287
- r .conn .handlingMu .Lock ()
288
- handling , found := r .conn .handling [* r .ID ]
289
- if found {
290
- delete (r .conn .handling , * r .ID )
291
- }
292
- r .conn .handlingMu .Unlock ()
293
- if ! found {
294
- return fmt .Errorf ("not a call in progress: %v" , r .ID )
295
- }
310
+ ctx , st := trace .StartSpan (ctx , r .Method + ":reply" , trace .WithSpanKind (trace .SpanKindClient ))
311
+ defer st .End ()
312
+
313
+ r .Parallel ()
314
+ r .state = requestReplied
296
315
297
- elapsed := time .Since (handling .start )
316
+ elapsed := time .Since (r .start )
298
317
var raw * json.RawMessage
299
318
if err == nil {
300
319
raw , err = marshalToRaw (result )
@@ -319,10 +338,9 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro
319
338
320
339
v := ctx .Value (rpcStatsKey )
321
340
if v != nil {
322
- s := v .(* rpcStats )
323
- s .sent += n
341
+ v .(* rpcStats ).sent += n
324
342
} else {
325
- // panic("no stats available in reply")
343
+ panic ("no stats available in reply" )
326
344
}
327
345
328
346
if err != nil {
@@ -333,10 +351,17 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro
333
351
return nil
334
352
}
335
353
336
- type handling struct {
337
- request * Request
338
- cancel context.CancelFunc
339
- start time.Time
354
+ func (c * Conn ) setHandling (r * Request , active bool ) {
355
+ if r .ID == nil {
356
+ return
357
+ }
358
+ r .conn .handlingMu .Lock ()
359
+ defer r .conn .handlingMu .Unlock ()
360
+ if active {
361
+ r .conn .handling [* r .ID ] = r
362
+ } else {
363
+ delete (r .conn .handling , * r .ID )
364
+ }
340
365
}
341
366
342
367
// combined has all the fields of both Request and Response.
@@ -350,40 +375,13 @@ type combined struct {
350
375
Error * Error `json:"error,omitempty"`
351
376
}
352
377
353
- func (c * Conn ) deliver (ctx context.Context , q chan queueEntry , request * Request , size int64 ) bool {
354
- e := queueEntry {ctx : ctx , r : request , size : size }
355
- if ! c .RejectIfOverloaded {
356
- q <- e
357
- return true
358
- }
359
- select {
360
- case q <- e :
361
- return true
362
- default :
363
- return false
364
- }
365
- }
366
-
367
378
// Run blocks until the connection is terminated, and returns any error that
368
379
// caused the termination.
369
380
// It must be called exactly once for each Conn.
370
381
// It returns only when the reader is closed or there is an error in the stream.
371
382
func (c * Conn ) Run (ctx context.Context ) error {
372
- q := make (chan queueEntry , c .Capacity )
373
- defer close (q )
374
- // start the queue processor
375
- go func () {
376
- // TODO: idle notification?
377
- for e := range q {
378
- if e .ctx .Err () != nil {
379
- continue
380
- }
381
- ctx , rpcStats := start (ctx , true , e .r .Method , e .r .ID )
382
- rpcStats .received += e .size
383
- c .Handler (ctx , e .r )
384
- rpcStats .end (ctx , nil )
385
- }
386
- }()
383
+ nextRequest := make (chan struct {})
384
+ close (nextRequest )
387
385
for {
388
386
// get the data for a message
389
387
data , n , err := c .stream .Read (ctx )
@@ -403,33 +401,36 @@ func (c *Conn) Run(ctx context.Context) error {
403
401
switch {
404
402
case msg .Method != "" :
405
403
// if method is set it must be a request
406
- request := & Request {
407
- conn : c ,
408
- Method : msg .Method ,
409
- Params : msg .Params ,
410
- ID : msg .ID ,
411
- }
412
- if request .IsNotify () {
413
- c .Logger (Receive , request .ID , - 1 , request .Method , request .Params , nil )
414
- // we have a Notify, add to the processor queue
415
- c .deliver (ctx , q , request , n )
416
- //TODO: log when we drop a message?
417
- } else {
418
- // we have a Call, add to the processor queue
419
- reqCtx , cancelReq := context .WithCancel (ctx )
420
- c .handlingMu .Lock ()
421
- c .handling [* request .ID ] = handling {
422
- request : request ,
423
- cancel : cancelReq ,
424
- start : time .Now (),
425
- }
426
- c .handlingMu .Unlock ()
427
- c .Logger (Receive , request .ID , - 1 , request .Method , request .Params , nil )
428
- if ! c .deliver (reqCtx , q , request , n ) {
429
- // queue is full, reject the message by directly replying
430
- request .Reply (ctx , nil , NewErrorf (CodeServerOverloaded , "no room in queue" ))
431
- }
404
+ reqCtx , cancelReq := context .WithCancel (ctx )
405
+ reqCtx , rpcStats := start (reqCtx , true , msg .Method , msg .ID )
406
+ rpcStats .received += n
407
+ thisRequest := nextRequest
408
+ nextRequest = make (chan struct {})
409
+ req := & Request {
410
+ conn : c ,
411
+ cancel : cancelReq ,
412
+ nextRequest : nextRequest ,
413
+ start : time .Now (),
414
+ Method : msg .Method ,
415
+ Params : msg .Params ,
416
+ ID : msg .ID ,
432
417
}
418
+ c .setHandling (req , true )
419
+ go func () {
420
+ <- thisRequest
421
+ req .state = requestSerial
422
+ defer func () {
423
+ c .setHandling (req , false )
424
+ if ! req .IsNotify () && req .state < requestReplied {
425
+ req .Reply (reqCtx , nil , NewErrorf (CodeInternalError , "method %q did not reply" , req .Method ))
426
+ }
427
+ req .Parallel ()
428
+ rpcStats .end (reqCtx , nil )
429
+ cancelReq ()
430
+ }()
431
+ c .Logger (Receive , req .ID , - 1 , req .Method , req .Params , nil )
432
+ c .Handler (reqCtx , req )
433
+ }()
433
434
case msg .ID != nil :
434
435
// we have a response, get the pending entry from the map
435
436
c .pendingMu .Lock ()
0 commit comments