diff --git a/routing/dht/ext_test.go b/routing/dht/ext_test.go
index 772724956b2..539d55ccaf5 100644
--- a/routing/dht/ext_test.go
+++ b/routing/dht/ext_test.go
@@ -52,7 +52,7 @@ func TestGetFailures(t *testing.T) {
 			err = merr[0]
 		}
 
-		if err != context.DeadlineExceeded {
+		if err != context.DeadlineExceeded && err != context.Canceled {
 			t.Fatal("Got different error than we expected", err)
 		}
 	} else {
diff --git a/routing/dht/query.go b/routing/dht/query.go
index aacab106f16..3687bc85983 100644
--- a/routing/dht/query.go
+++ b/routing/dht/query.go
@@ -12,7 +12,8 @@ import (
 	pset "github.com/jbenet/go-ipfs/util/peerset"
 	todoctr "github.com/jbenet/go-ipfs/util/todocounter"
 
-	ctxgroup "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-ctxgroup"
+	process "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/goprocess"
+	ctxproc "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/goprocess/context"
 	context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context"
 )
 
@@ -52,11 +53,17 @@ type queryFunc func(context.Context, peer.ID) (*dhtQueryResult, error)
 
 // Run runs the query at hand. pass in a list of peers to use first.
 func (q *dhtQuery) Run(ctx context.Context, peers []peer.ID) (*dhtQueryResult, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	runner := newQueryRunner(ctx, q)
-	return runner.Run(peers)
+	runner := newQueryRunner(q)
+	return runner.Run(ctx, peers)
 }
 
 type dhtQueryRunner struct {
@@ -71,22 +78,24 @@ type dhtQueryRunner struct {
 	rateLimit      chan struct{}       // processing semaphore
 	log            eventlog.EventLogger
 
-	cg ctxgroup.ContextGroup
+	proc process.Process
 	sync.RWMutex
 }
 
-func newQueryRunner(ctx context.Context, q *dhtQuery) *dhtQueryRunner {
+func newQueryRunner(q *dhtQuery) *dhtQueryRunner {
+	proc := process.WithParent(process.Background())
+	ctx := ctxproc.WithProcessClosing(context.Background(), proc)
 	return &dhtQueryRunner{
 		query:          q,
 		peersToQuery:   queue.NewChanQueue(ctx, queue.NewXORDistancePQ(q.key)),
 		peersRemaining: todoctr.NewSyncCounter(),
 		peersSeen:      pset.New(),
 		rateLimit:      make(chan struct{}, q.concurrency),
-		cg:             ctxgroup.WithContext(ctx),
+		proc:           proc,
 	}
 }
 
-func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
+func (r *dhtQueryRunner) Run(ctx context.Context, peers []peer.ID) (*dhtQueryResult, error) {
 	r.log = log
 
 	if len(peers) == 0 {
@@ -101,22 +110,30 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
 
 	// add all the peers we got first.
 	for _, p := range peers {
-		r.addPeerToQuery(r.cg.Context(), p)
+		r.addPeerToQuery(p)
 	}
 
 	// go do this thing.
-	// do it as a child func to make sure Run exits
+	// do it as a child proc to make sure Run exits
 	// ONLY AFTER spawn workers has exited.
-	r.cg.AddChildFunc(r.spawnWorkers)
+	r.proc.Go(r.spawnWorkers)
 
 	// so workers are working.
 
 	// wait until they're done.
 	err := routing.ErrNotFound
 
+	// now, if the context finishes, close the proc.
+	// we have to do it here because the logic before is setup, which
+	// should run without closing the proc.
+	go func() {
+		<-ctx.Done()
+		r.proc.Close()
+	}()
+
 	select {
 	case <-r.peersRemaining.Done():
-		r.cg.Close()
+		r.proc.Close()
 		r.RLock()
 		defer r.RUnlock()
 
@@ -128,12 +145,10 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
 			err = r.errs[0]
 		}
 
-	case <-r.cg.Closed():
-		log.Debug("r.cg.Closed()")
-
+	case <-r.proc.Closed():
 		r.RLock()
 		defer r.RUnlock()
-		err = r.cg.Context().Err() // collect the error.
+		err = context.DeadlineExceeded
 	}
 
 	if r.result != nil && r.result.success {
@@ -143,7 +158,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
 	return nil, err
 }
 
-func (r *dhtQueryRunner) addPeerToQuery(ctx context.Context, next peer.ID) {
+func (r *dhtQueryRunner) addPeerToQuery(next peer.ID) {
 	// if new peer is ourselves...
 	if next == r.query.dht.self {
 		r.log.Debug("addPeerToQuery skip self")
@@ -157,18 +172,18 @@
 	r.peersRemaining.Increment(1)
 	select {
 	case r.peersToQuery.EnqChan <- next:
-	case <-ctx.Done():
+	case <-r.proc.Closing():
 	}
 }
 
-func (r *dhtQueryRunner) spawnWorkers(parent ctxgroup.ContextGroup) {
+func (r *dhtQueryRunner) spawnWorkers(proc process.Process) {
 	for {
 
 		select {
 		case <-r.peersRemaining.Done():
 			return
 
-		case <-r.cg.Closing():
+		case <-r.proc.Closing():
 			return
 
 		case p, more := <-r.peersToQuery.DeqChan:
@@ -178,24 +193,27 @@
 
 			// do it as a child func to make sure Run exits
 			// ONLY AFTER spawn workers has exited.
-			parent.AddChildFunc(func(cg ctxgroup.ContextGroup) {
-				r.queryPeer(cg, p)
+			proc.Go(func(proc process.Process) {
+				r.queryPeer(proc, p)
 			})
 		}
 	}
 }
 
-func (r *dhtQueryRunner) queryPeer(cg ctxgroup.ContextGroup, p peer.ID) {
+func (r *dhtQueryRunner) queryPeer(proc process.Process, p peer.ID) {
 	// make sure we rate limit concurrency.
 	select {
 	case <-r.rateLimit:
-	case <-cg.Closing():
+	case <-proc.Closing():
 		r.peersRemaining.Decrement(1)
 		return
 	}
 
 	// ok let's do this!
 
+	// create a context from our proc.
+	ctx := ctxproc.WithProcessClosing(context.Background(), proc)
+
 	// make sure we do this when we exit
 	defer func() {
 		// signal we're done proccessing peer p
@@ -212,10 +230,11 @@
 	r.rateLimit <- struct{}{}
 
 	pi := peer.PeerInfo{ID: p}
-	if err := r.query.dht.host.Connect(cg.Context(), pi); err != nil {
+
+	if err := r.query.dht.host.Connect(ctx, pi); err != nil {
 		log.Debugf("Error connecting: %s", err)
 
-		notif.PublishQueryEvent(cg.Context(), &notif.QueryEvent{
+		notif.PublishQueryEvent(ctx, &notif.QueryEvent{
 			Type:  notif.QueryError,
 			Extra: err.Error(),
 		})
@@ -231,7 +250,7 @@
 	}
 
 	// finally, run the query against this peer
-	res, err := r.query.qfunc(cg.Context(), p)
+	res, err := r.query.qfunc(ctx, p)
 
 	if err != nil {
 		log.Debugf("ERROR worker for: %v %v", p, err)
@@ -244,7 +263,7 @@
 		r.Lock()
 		r.result = res
 		r.Unlock()
-		go r.cg.Close() // signal to everyone that we're done.
+		go r.proc.Close() // signal to everyone that we're done.
 		// must be async, as we're one of the children, and Close blocks.
 
 	} else if len(res.closerPeers) > 0 {
@@ -257,7 +276,7 @@
 
 			// add their addresses to the dialer's peerstore
 			r.query.dht.peerstore.AddAddrs(next.ID, next.Addrs, peer.TempAddrTTL)
-			r.addPeerToQuery(cg.Context(), next.ID)
+			r.addPeerToQuery(next.ID)
 			log.Debugf("PEERS CLOSER -- worker for: %v added %v (%v)", p, next.ID, next.Addrs)
 		}
 	} else {
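
For reviewers unfamiliar with goprocess, the lifecycle pattern this diff moves to looks roughly like the sketch below. It is a minimal illustration, not code from this change: it assumes the unvendored import paths (github.com/jbenet/goprocess, github.com/jbenet/goprocess/context, golang.org/x/net/context) and a hypothetical worker body. It mirrors the three moves in query.go: newQueryRunner rooting a process tree with process.WithParent, queryPeer deriving a context via ctxproc.WithProcessClosing, and Run bridging the caller's context into the tree with `go func() { <-ctx.Done(); proc.Close() }()`.

```go
package main

import (
	"fmt"
	"time"

	process "github.com/jbenet/goprocess"
	ctxproc "github.com/jbenet/goprocess/context"
	context "golang.org/x/net/context"
)

func main() {
	// Root of the runner's process tree, as in newQueryRunner.
	proc := process.WithParent(process.Background())

	// Derive a context that is cancelled once proc begins closing.
	// This is what queryPeer uses for its Connect / qfunc calls.
	ctx := ctxproc.WithProcessClosing(context.Background(), proc)

	// Spawn a worker as a child; proc.Close() blocks until it returns.
	proc.Go(func(worker process.Process) {
		select {
		case <-worker.Closing(): // closing propagates parent -> child
			fmt.Println("worker: shutting down")
		case <-time.After(time.Second):
			fmt.Println("worker: finished")
		}
	})

	// Bridge a caller's context into the tree, as Run does: once the
	// caller gives up, tear down every worker.
	callerCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	go func() {
		<-callerCtx.Done()
		proc.Close() // blocks until the whole tree has exited
	}()

	<-ctx.Done() // the derived context observes the shutdown
	fmt.Println("tree closed:", ctx.Err())
}
```

The design point is that shutdown ordering now comes from the process tree instead of context plumbing: proc.Close() returns only after every child spawned with proc.Go has exited, which is what lets Run guarantee it returns ONLY AFTER spawnWorkers and all of its queryPeer children are gone.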
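The subtlest line in the diff is `go r.proc.Close()` in queryPeer, flagged by its own comment ("must be async, as we're one of the children, and Close blocks"). Close waits for every child to return, and queryPeer itself runs as one of those children, so an inline Close would deadlock. A small sketch of that shape, with a hypothetical worker body and the same assumed import path:

```go
package main

import (
	"fmt"

	process "github.com/jbenet/goprocess"
)

func main() {
	proc := process.WithParent(process.Background())

	proc.Go(func(child process.Process) {
		// Pretend this worker just found the query result. It must not
		// call proc.Close() inline: Close waits for every child to
		// return, and this worker is one of those children.
		go proc.Close() // async, as in queryPeer

		<-child.Closing() // exit once the tree starts shutting down
	})

	<-proc.Closed() // fires only after the worker above has returned
	fmt.Println("all workers done")
}
```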