diff --git a/internal/backend/user.go b/internal/backend/user.go index 573a4698..2b60780a 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -24,6 +24,9 @@ type user struct { states map[int]*State statesLock sync.RWMutex nextStateID int + + updateStopCh chan struct{} + updateWG sync.WaitGroup } func newUser(ctx context.Context, userID string, client *ent.Client, remote *remote.User, store store.Store, delimiter string) (*user, error) { @@ -32,27 +35,35 @@ func newUser(ctx context.Context, userID string, client *ent.Client, remote *rem } user := &user{ - userID: userID, - remote: remote, - store: store, - delimiter: delimiter, - client: client, - states: make(map[int]*State), + userID: userID, + remote: remote, + store: store, + delimiter: delimiter, + client: client, + states: make(map[int]*State), + updateStopCh: make(chan struct{}), } if err := user.deleteAllMessagesMarkedDeleted(ctx); err != nil { return nil, err } + user.updateWG.Add(1) + go func() { - for update := range remote.GetUpdates() { - update := update - - if err := user.tx(context.Background(), func(tx *ent.Tx) error { - defer update.Done() - return user.apply(context.Background(), tx, update) - }); err != nil { - logrus.WithError(err).Error("Failed to apply update") + defer user.updateWG.Done() + + for { + select { + case update := <-remote.GetUpdates(): + if err := user.tx(context.Background(), func(tx *ent.Tx) error { + defer update.Done() + return user.apply(context.Background(), tx, update) + }); err != nil { + logrus.WithError(err).Errorf("Failed to apply update: %v", update) + } + case <-user.updateStopCh: + return } } }() @@ -99,6 +110,10 @@ func (user *user) tx(ctx context.Context, fn func(tx *ent.Tx) error) error { func (user *user) close(ctx context.Context) error { user.closeStates() + // Wait until the connector update go routine has finished. + close(user.updateStopCh) + user.updateWG.Wait() + if err := user.remote.CloseAndSerializeOperationQueue(); err != nil { return fmt.Errorf("failed to close user remote: %w", err) }