Skip to content

Commit 04b019f

Browse files
authored
[SDK] Fix closing server-side streams & Polish NotifyIncomingFunds (#519)
* SDK: fix server-side streams * Fixes to client * Fix clients * Remove unnecessary waitgroup from NotifyIncomingFunds
1 parent 3d99e67 commit 04b019f

File tree

4 files changed

+185
-205
lines changed

4 files changed

+185
-205
lines changed

pkg/client-sdk/client.go

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/hex"
66
"fmt"
77
"strings"
8-
"sync"
98
"time"
109

1110
"github.com/ark-network/ark/common"
@@ -157,27 +156,19 @@ func (a *arkClient) NotifyIncomingFunds(
157156
if err != nil {
158157
return nil, err
159158
}
159+
defer closeFn()
160160

161-
wg := &sync.WaitGroup{}
162-
wg.Add(1)
163-
incomingVtxos := make([]types.Vtxo, 0)
164-
go func() {
165-
defer wg.Done()
166-
for event := range eventCh {
167-
if event.Err != nil {
168-
err = event.Err
169-
} else {
170-
for _, vtxo := range event.NewVtxos {
171-
incomingVtxos = append(incomingVtxos, toTypesVtxo(vtxo))
172-
}
173-
}
174-
closeFn()
175-
// nolint:all
176-
return
177-
}
178-
}()
179-
wg.Wait()
161+
event := <-eventCh
162+
163+
if event.Err != nil {
164+
err = event.Err
165+
return nil, err
166+
}
180167

168+
incomingVtxos := make([]types.Vtxo, 0)
169+
for _, vtxo := range event.NewVtxos {
170+
incomingVtxos = append(incomingVtxos, toTypesVtxo(vtxo))
171+
}
181172
return incomingVtxos, nil
182173
}
183174

pkg/client-sdk/client/client.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ const (
1919
RestClient = "rest"
2020
)
2121

22+
var (
23+
ErrConnectionClosedByServer = fmt.Errorf("connection closed by server")
24+
)
25+
2226
type RoundEvent interface {
2327
isRoundEvent()
2428
}

pkg/client-sdk/client/grpc/client.go

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ import (
1616
"github.com/ark-network/ark/pkg/client-sdk/internal/utils"
1717
"github.com/sirupsen/logrus"
1818
"google.golang.org/grpc"
19+
"google.golang.org/grpc/codes"
1920
"google.golang.org/grpc/credentials"
2021
"google.golang.org/grpc/credentials/insecure"
22+
"google.golang.org/grpc/status"
2123
)
2224

2325
type service struct {
@@ -203,17 +205,12 @@ func (a *grpcClient) SubmitSignedForfeitTxs(
203205

204206
func (a *grpcClient) GetEventStream(
205207
ctx context.Context, requestID string,
206-
) (ch <-chan client.RoundEventChannel, closeFn func(), err error) {
207-
req := &arkv1.GetEventStreamRequest{}
208+
) (<-chan client.RoundEventChannel, func(), error) {
208209
ctx, cancel := context.WithCancel(ctx)
209-
defer func() {
210-
if err != nil {
211-
cancel()
212-
}
213-
}()
214210

215-
stream, err := a.svc.GetEventStream(ctx, req)
211+
stream, err := a.svc.GetEventStream(ctx, &arkv1.GetEventStreamRequest{})
216212
if err != nil {
213+
cancel()
217214
return nil, nil, err
218215
}
219216

@@ -223,28 +220,37 @@ func (a *grpcClient) GetEventStream(
223220
defer close(eventsCh)
224221

225222
for {
226-
select {
227-
case <-stream.Context().Done():
228-
return
229-
default:
230-
resp, err := stream.Recv()
231-
if err != nil {
232-
eventsCh <- client.RoundEventChannel{Err: err}
223+
resp, err := stream.Recv()
224+
if err != nil {
225+
if err == io.EOF {
226+
eventsCh <- client.RoundEventChannel{Err: client.ErrConnectionClosedByServer}
233227
return
234228
}
235-
236-
ev, err := event{resp}.toRoundEvent()
237-
if err != nil {
238-
eventsCh <- client.RoundEventChannel{Err: err}
229+
if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
239230
return
240231
}
232+
eventsCh <- client.RoundEventChannel{Err: err}
233+
return
234+
}
241235

242-
eventsCh <- client.RoundEventChannel{Event: ev}
236+
ev, err := event{resp}.toRoundEvent()
237+
if err != nil {
238+
eventsCh <- client.RoundEventChannel{Err: err}
239+
return
243240
}
241+
242+
eventsCh <- client.RoundEventChannel{Event: ev}
244243
}
245244
}()
246245

247-
return eventsCh, cancel, nil
246+
closeFn := func() {
247+
if err := stream.CloseSend(); err != nil {
248+
logrus.Warnf("failed to close event stream: %s", err)
249+
}
250+
cancel()
251+
}
252+
253+
return eventsCh, closeFn, nil
248254
}
249255

250256
func (a *grpcClient) Ping(
@@ -345,28 +351,36 @@ func (c *grpcClient) Close() {
345351
func (c *grpcClient) GetTransactionsStream(
346352
ctx context.Context,
347353
) (<-chan client.TransactionEvent, func(), error) {
354+
ctx, cancel := context.WithCancel(ctx)
355+
348356
stream, err := c.svc.GetTransactionsStream(ctx, &arkv1.GetTransactionsStreamRequest{})
349357
if err != nil {
358+
cancel()
350359
return nil, nil, err
351360
}
352361

353-
eventCh := make(chan client.TransactionEvent)
362+
eventsCh := make(chan client.TransactionEvent)
354363

355364
go func() {
356-
defer close(eventCh)
365+
defer close(eventsCh)
366+
357367
for {
358368
resp, err := stream.Recv()
359-
if err == io.EOF {
360-
return
361-
}
362369
if err != nil {
363-
eventCh <- client.TransactionEvent{Err: err}
370+
if err == io.EOF {
371+
eventsCh <- client.TransactionEvent{Err: client.ErrConnectionClosedByServer}
372+
return
373+
}
374+
if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
375+
return
376+
}
377+
eventsCh <- client.TransactionEvent{Err: err}
364378
return
365379
}
366380

367381
switch tx := resp.Tx.(type) {
368382
case *arkv1.GetTransactionsStreamResponse_Round:
369-
eventCh <- client.TransactionEvent{
383+
eventsCh <- client.TransactionEvent{
370384
Round: &client.RoundTransaction{
371385
Txid: tx.Round.Txid,
372386
SpentVtxos: vtxos(tx.Round.SpentVtxos).toVtxos(),
@@ -376,7 +390,7 @@ func (c *grpcClient) GetTransactionsStream(
376390
},
377391
}
378392
case *arkv1.GetTransactionsStreamResponse_Redeem:
379-
eventCh <- client.TransactionEvent{
393+
eventsCh <- client.TransactionEvent{
380394
Redeem: &client.RedeemTransaction{
381395
Txid: tx.Redeem.Txid,
382396
SpentVtxos: vtxos(tx.Redeem.SpentVtxos).toVtxos(),
@@ -390,11 +404,12 @@ func (c *grpcClient) GetTransactionsStream(
390404

391405
closeFn := func() {
392406
if err := stream.CloseSend(); err != nil {
393-
logrus.Warnf("failed to close stream: %v", err)
407+
logrus.Warnf("failed to close transaction stream: %v", err)
394408
}
409+
cancel()
395410
}
396411

397-
return eventCh, closeFn, nil
412+
return eventsCh, closeFn, nil
398413
}
399414

400415
func (a *grpcClient) SetNostrRecipient(
@@ -421,28 +436,36 @@ func (a *grpcClient) DeleteNostrRecipient(
421436
func (c *grpcClient) SubscribeForAddress(
422437
ctx context.Context, addr string,
423438
) (<-chan client.AddressEvent, func(), error) {
439+
ctx, cancel := context.WithCancel(ctx)
440+
424441
stream, err := c.svc.SubscribeForAddress(ctx, &arkv1.SubscribeForAddressRequest{
425442
Address: addr,
426443
})
427444
if err != nil {
445+
cancel()
428446
return nil, nil, err
429447
}
430448

431-
eventCh := make(chan client.AddressEvent)
449+
eventsCh := make(chan client.AddressEvent)
432450

433451
go func() {
434-
defer close(eventCh)
452+
defer close(eventsCh)
453+
435454
for {
436455
resp, err := stream.Recv()
437-
if err == io.EOF {
438-
return
439-
}
440456
if err != nil {
441-
eventCh <- client.AddressEvent{Err: err}
457+
if err == io.EOF {
458+
eventsCh <- client.AddressEvent{Err: client.ErrConnectionClosedByServer}
459+
return
460+
}
461+
if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
462+
return
463+
}
464+
eventsCh <- client.AddressEvent{Err: err}
442465
return
443466
}
444467

445-
eventCh <- client.AddressEvent{
468+
eventsCh <- client.AddressEvent{
446469
NewVtxos: vtxos(resp.NewVtxos).toVtxos(),
447470
SpentVtxos: vtxos(resp.SpentVtxos).toVtxos(),
448471
}
@@ -451,11 +474,12 @@ func (c *grpcClient) SubscribeForAddress(
451474

452475
closeFn := func() {
453476
if err := stream.CloseSend(); err != nil {
454-
logrus.Warnf("failed to close stream: %v", err)
477+
logrus.Warnf("failed to close address stream: %v", err)
455478
}
479+
cancel()
456480
}
457481

458-
return eventCh, closeFn, nil
482+
return eventsCh, closeFn, nil
459483
}
460484

461485
func signedVtxosToProto(vtxos []client.SignedVtxoOutpoint) []*arkv1.SignedVtxoOutpoint {

0 commit comments

Comments
 (0)