diff --git a/graphsync.go b/graphsync.go index 83aa8202..831f1737 100644 --- a/graphsync.go +++ b/graphsync.go @@ -289,6 +289,9 @@ type GraphExchange interface { // RegisterPersistenceOption registers an alternate loader/storer combo that can be substituted for the default RegisterPersistenceOption(name string, loader ipld.Loader, storer ipld.Storer) error + // UnregisterPersistenceOption unregisters an alternate loader/storer combo + UnregisterPersistenceOption(name string) error + // RegisterIncomingRequestHook adds a hook that runs when a request is received RegisterIncomingRequestHook(hook OnIncomingRequestHook) UnregisterHookFunc diff --git a/impl/graphsync.go b/impl/graphsync.go index bff7070f..9828b6ba 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -160,6 +160,15 @@ func (gs *GraphSync) RegisterPersistenceOption(name string, loader ipld.Loader, return gs.persistenceOptions.Register(name, loader) } +// UnregisterPersistenceOption unregisters an alternate loader/storer combo +func (gs *GraphSync) UnregisterPersistenceOption(name string) error { + err := gs.asyncLoader.UnregisterPersistenceOption(name) + if err != nil { + return err + } + return gs.persistenceOptions.Unregister(name) +} + // RegisterOutgoingBlockHook registers a hook that runs after each block is sent in a response func (gs *GraphSync) RegisterOutgoingBlockHook(hook graphsync.OnOutgoingBlockHook) graphsync.UnregisterHookFunc { return gs.outgoingBlockHooks.Register(hook) diff --git a/requestmanager/asyncloader/asyncloader.go b/requestmanager/asyncloader/asyncloader.go index 9e4a0069..73f38219 100644 --- a/requestmanager/asyncloader/asyncloader.go +++ b/requestmanager/asyncloader/asyncloader.go @@ -79,34 +79,26 @@ func (al *AsyncLoader) RegisterPersistenceOption(name string, loader ipld.Loader return errors.New("Persistence option must have a name") } response := make(chan error, 1) - select { - case <-al.ctx.Done(): - return errors.New("context closed") - case al.incomingMessages <- ®isterPersistenceOptionMessage{name, loader, storer, response}: - } - select { - case <-al.ctx.Done(): - return errors.New("context closed") - case err := <-response: - return err + err := al.sendSyncMessage(®isterPersistenceOptionMessage{name, loader, storer, response}, response) + return err +} + +// UnregisterPersistenceOption unregisters an existing loader/storer option for processing requests +func (al *AsyncLoader) UnregisterPersistenceOption(name string) error { + if name == "" { + return errors.New("Persistence option must have a name") } + response := make(chan error, 1) + err := al.sendSyncMessage(&unregisterPersistenceOptionMessage{name, response}, response) + return err } // StartRequest indicates the given request has started and the manager should // continually attempt to load links for this request as new responses come in func (al *AsyncLoader) StartRequest(requestID graphsync.RequestID, persistenceOption string) error { response := make(chan error, 1) - select { - case <-al.ctx.Done(): - return errors.New("context closed") - case al.incomingMessages <- &startRequestMessage{requestID, persistenceOption, response}: - } - select { - case <-al.ctx.Done(): - return errors.New("context closed") - case err := <-response: - return err - } + err := al.sendSyncMessage(&startRequestMessage{requestID, persistenceOption, response}, response) + return err } // ProcessResponse injests new responses and completes asynchronous loads as @@ -123,17 +115,12 @@ func (al *AsyncLoader) ProcessResponse(responses map[graphsync.RequestID]metadat // for errors -- only one message will be sent over either. func (al *AsyncLoader) AsyncLoad(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult { resultChan := make(chan types.AsyncLoadResult, 1) - response := make(chan struct{}, 1) + response := make(chan error, 1) lr := loadattemptqueue.NewLoadRequest(requestID, link, resultChan) - select { - case <-al.ctx.Done(): - resultChan <- types.AsyncLoadResult{Data: nil, Err: errors.New("Context closed")} + err := al.sendSyncMessage(&loadRequestMessage{response, requestID, lr}, response) + if err != nil { + resultChan <- types.AsyncLoadResult{Data: nil, Err: err} close(resultChan) - case al.incomingMessages <- &loadRequestMessage{response, requestID, lr}: - } - select { - case <-al.ctx.Done(): - case <-response: } return resultChan } @@ -158,8 +145,22 @@ func (al *AsyncLoader) CleanupRequest(requestID graphsync.RequestID) { } } +func (al *AsyncLoader) sendSyncMessage(message loaderMessage, response chan error) error { + select { + case <-al.ctx.Done(): + return errors.New("Context Closed") + case al.incomingMessages <- message: + } + select { + case <-al.ctx.Done(): + return errors.New("Context Closed") + case err := <-response: + return err + } +} + type loadRequestMessage struct { - response chan struct{} + response chan error requestID graphsync.RequestID loadRequest loadattemptqueue.LoadRequest } @@ -176,6 +177,11 @@ type registerPersistenceOptionMessage struct { response chan error } +type unregisterPersistenceOptionMessage struct { + name string + response chan error +} + type startRequestMessage struct { requestID graphsync.RequestID persistenceOption string @@ -247,7 +253,7 @@ func (lrm *loadRequestMessage) handle(al *AsyncLoader) { loadAttemptQueue.AttemptLoad(lrm.loadRequest, retry) select { case <-al.ctx.Done(): - case lrm.response <- struct{}{}: + case lrm.response <- nil: } } @@ -269,6 +275,28 @@ func (rpom *registerPersistenceOptionMessage) handle(al *AsyncLoader) { } } +func (upom *unregisterPersistenceOptionMessage) unregister(al *AsyncLoader) error { + _, ok := al.alternateQueues[upom.name] + if !ok { + return errors.New("Unknown persistence option") + } + for _, requestQueue := range al.requestQueues { + if upom.name == requestQueue { + return errors.New("cannot unregister while requests are in progress") + } + } + delete(al.alternateQueues, upom.name) + return nil +} + +func (upom *unregisterPersistenceOptionMessage) handle(al *AsyncLoader) { + err := upom.unregister(al) + select { + case <-al.ctx.Done(): + case upom.response <- err: + } +} + func (srm *startRequestMessage) startRequest(al *AsyncLoader) error { if srm.persistenceOption != "" { _, ok := al.alternateQueues[srm.persistenceOption] diff --git a/requestmanager/asyncloader/asyncloader_test.go b/requestmanager/asyncloader/asyncloader_test.go index debc54f6..c39076df 100644 --- a/requestmanager/asyncloader/asyncloader_test.go +++ b/requestmanager/asyncloader/asyncloader_test.go @@ -188,6 +188,35 @@ func TestAsyncLoadTwiceLoadsLocallySecondTime(t *testing.T) { }) } +func TestRegisterUnregister(t *testing.T) { + st := newStore() + otherSt := newStore() + blocks := testutil.GenerateBlocksOfSize(3, 100) + link1 := otherSt.Store(t, blocks[0]) + withLoader(st, func(ctx context.Context, asyncLoader *AsyncLoader) { + + requestID1 := graphsync.RequestID(rand.Int31()) + err := asyncLoader.StartRequest(requestID1, "other") + require.EqualError(t, err, "Unknown persistence option") + + err = asyncLoader.RegisterPersistenceOption("other", otherSt.loader, otherSt.storer) + requestID2 := graphsync.RequestID(rand.Int31()) + err = asyncLoader.StartRequest(requestID2, "other") + require.NoError(t, err) + resultChan1 := asyncLoader.AsyncLoad(requestID2, link1) + assertSuccessResponse(ctx, t, resultChan1) + err = asyncLoader.UnregisterPersistenceOption("other") + require.EqualError(t, err, "cannot unregister while requests are in progress") + asyncLoader.CompleteResponsesFor(requestID2) + asyncLoader.CleanupRequest(requestID2) + err = asyncLoader.UnregisterPersistenceOption("other") + require.NoError(t, err) + + requestID3 := graphsync.RequestID(rand.Int31()) + err = asyncLoader.StartRequest(requestID3, "other") + require.EqualError(t, err, "Unknown persistence option") + }) +} func TestRequestSplittingLoadLocallyFromBlockstore(t *testing.T) { st := newStore() otherSt := newStore() diff --git a/responsemanager/persistenceoptions/persistenceoptions.go b/responsemanager/persistenceoptions/persistenceoptions.go index 7701dc14..a082d27f 100644 --- a/responsemanager/persistenceoptions/persistenceoptions.go +++ b/responsemanager/persistenceoptions/persistenceoptions.go @@ -32,6 +32,18 @@ func (po *PersistenceOptions) Register(name string, loader ipld.Loader) error { return nil } +// Unregister unregisters a loader for the response manager +func (po *PersistenceOptions) Unregister(name string) error { + po.persistenceOptionsLk.Lock() + defer po.persistenceOptionsLk.Unlock() + _, ok := po.persistenceOptions[name] + if !ok { + return errors.New("persistence option is not registered") + } + delete(po.persistenceOptions, name) + return nil +} + // GetLoader returns the loader for the named persistence option func (po *PersistenceOptions) GetLoader(name string) (ipld.Loader, bool) { po.persistenceOptionsLk.RLock()