diff --git a/internal/state/actions.go b/internal/state/actions.go index 9b11c8f7..a9aad84d 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -59,16 +59,16 @@ func (state *State) actionCreateMailbox(ctx context.Context, tx *ent.Tx, name st return db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity) } -func (state *State) actionDeleteMailbox(ctx context.Context, tx *ent.Tx, mboxID ids.MailboxIDPair) error { +func (state *State) actionDeleteMailbox(ctx context.Context, tx *ent.Tx, mboxID ids.MailboxIDPair) ([]Update, error) { if err := state.user.GetRemote().DeleteMailbox(ctx, mboxID.RemoteID); err != nil { - return err + return nil, err } if err := db.DeleteMailboxWithRemoteID(ctx, tx, mboxID.RemoteID); err != nil { - return err + return nil, err } - return state.user.QueueOrApplyStateUpdate(ctx, tx, NewMailboxDeletedStateUpdate(mboxID.InternalID)) + return []Update{NewMailboxDeletedStateUpdate(mboxID.InternalID)}, nil } func (state *State) actionUpdateMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.MailboxID, newName string) error { @@ -92,17 +92,17 @@ func (state *State) actionCreateMessage( date time.Time, isSelectedMailbox bool, cameFromDrafts bool, -) (imap.UID, error) { +) ([]Update, imap.UID, error) { internalID, res, newLiteral, err := state.user.GetRemote().CreateMessage(ctx, mboxID.RemoteID, literal, flags, date) if err != nil { - return 0, err + return nil, 0, err } { // Handle the case where duplicate messages can return the same remote ID. knownInternalID, knownErr := db.GetMessageIDFromRemoteID(ctx, tx.Client(), res.ID) if knownErr != nil && !ent.IsNotFound(knownErr) { - return 0, knownErr + return nil, 0, knownErr } if knownErr == nil { // Try to collect the original message date. @@ -121,37 +121,37 @@ func (state *State) actionCreateMessage( logrus.Errorf("Append to drafts must not return an existing RemoteID (Remote=%v, Internal=%v)", res.ID, knownInternalID) - return 0, fmt.Errorf("append to drafts returned an existing remote ID") + return nil, 0, fmt.Errorf("append to drafts returned an existing remote ID") } logrus.Debugf("Deduped message detected, adding existing %v message to mailbox instead.", knownInternalID.ShortID()) - result, err := state.actionAddMessagesToMailbox(ctx, + updates, result, err := state.actionAddMessagesToMailbox(ctx, tx, []ids.MessageIDPair{{InternalID: knownInternalID, RemoteID: res.ID}}, mboxID, isSelectedMailbox, ) if err != nil { - return 0, err + return nil, 0, err } - return result[0].UID, nil + return updates, result[0].UID, nil } } parsedMessage, err := imap.NewParsedMessage(newLiteral) if err != nil { - return 0, err + return nil, 0, err } literalWithHeader, literalSize, err := rfc822.SetHeaderValueNoMemCopy(newLiteral, ids.InternalIDKey, internalID.String()) if err != nil { - return 0, fmt.Errorf("failed to set internal ID: %w", err) + return nil, 0, fmt.Errorf("failed to set internal ID: %w", err) } if err := state.user.GetStore().SetUnchecked(internalID, literalWithHeader); err != nil { - return 0, fmt.Errorf("failed to store message literal: %w", err) + return nil, 0, fmt.Errorf("failed to store message literal: %w", err) } req := db.CreateMessageReq{ @@ -165,7 +165,7 @@ func (state *State) actionCreateMessage( messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, mboxID.InternalID, &req) if err != nil { - return 0, err + return nil, 0, err } // We can append to non-selected mailboxes. @@ -175,15 +175,14 @@ func (state *State) actionCreateMessage( st = state } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, newExistsStateUpdateWithExists( + updates := []Update{newExistsStateUpdateWithExists( mboxID.InternalID, []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: res.ID}, messageUID, flagSet)}, st, - )); err != nil { - return 0, err + ), } - return messageUID, nil + return updates, messageUID, nil } func (state *State) actionCreateRecoveredMessage( @@ -192,23 +191,23 @@ func (state *State) actionCreateRecoveredMessage( literal []byte, flags imap.FlagSet, date time.Time, -) (bool, error) { +) ([]Update, bool, error) { internalID := imap.NewInternalMessageID() remoteID := ids.NewRecoveredRemoteMessageID(internalID) parsedMessage, err := imap.NewParsedMessage(literal) if err != nil { - return false, err + return nil, false, err } alreadyKnown, err := state.user.GetRecoveredMessageHashesMap().Insert(internalID, literal) if err == nil && alreadyKnown { // Message is already known to us, so we ignore it. - return true, nil + return nil, true, nil } if err := state.user.GetStore().SetUnchecked(internalID, bytes.NewReader(literal)); err != nil { - return false, fmt.Errorf("failed to store message literal: %w", err) + return nil, false, fmt.Errorf("failed to store message literal: %w", err) } req := db.CreateMessageReq{ @@ -228,18 +227,17 @@ func (state *State) actionCreateRecoveredMessage( messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, recoveryMBoxID.InternalID, &req) if err != nil { - return false, err + return nil, false, err } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, newExistsStateUpdateWithExists( + var updates = []Update{newExistsStateUpdateWithExists( recoveryMBoxID.InternalID, []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: remoteID}, messageUID, flagSet)}, nil, - )); err != nil { - return false, err + ), } - return false, nil + return updates, false, nil } func (state *State) actionAddMessagesToMailbox( @@ -248,26 +246,31 @@ func (state *State) actionAddMessagesToMailbox( messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, isMailboxSelected bool, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { + var allUpdates []Update + { haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) if err != nil { - return nil, err + return nil, nil, err } if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { return slices.Contains(haveMessageIDs, messageID.InternalID) }); len(remMessageIDs) > 0 { - if err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxID); err != nil { - return nil, err + updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxID) + if err != nil { + return nil, nil, err } + + allUpdates = append(allUpdates, updates...) } } internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) if err := state.user.GetRemote().AddMessagesToMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { - return nil, err + return nil, nil, err } // Messages can be added to a mailbox that is not selected. @@ -278,14 +281,12 @@ func (state *State) actionAddMessagesToMailbox( messageUIDs, update, err := AddMessagesToMailbox(ctx, tx, mboxID.InternalID, internalIDs, st, state.imapLimits) if err != nil { - return nil, err + return nil, nil, err } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, update); err != nil { - return nil, err - } + allUpdates = append(allUpdates, update) - return messageUIDs, nil + return allUpdates, messageUIDs, nil } func (state *State) actionAddRecoveredMessagesToMailbox( @@ -382,14 +383,14 @@ func (state *State) actionCopyMessagesOutOfRecoveryMailbox( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { ids := make([]ids.MessageIDPair, 0, len(messageIDs)) // Import messages to remote. for _, id := range messageIDs { id, _, err := state.actionImportRecoveredMessage(ctx, tx, id.InternalID, mboxID.RemoteID) if err != nil { - return nil, err + return nil, nil, err } ids = append(ids, id) @@ -398,14 +399,10 @@ func (state *State) actionCopyMessagesOutOfRecoveryMailbox( // Label messages in destination. uidWithFlags, update, err := state.actionAddRecoveredMessagesToMailbox(ctx, tx, ids, mboxID) if err != nil { - return nil, err - } - - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, update); err != nil { - return nil, err + return nil, nil, err } - return uidWithFlags, nil + return []Update{update}, uidWithFlags, nil } func (state *State) actionMoveMessagesOutOfRecoveryMailbox( @@ -413,7 +410,7 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { ids := make([]ids.MessageIDPair, 0, len(messageIDs)) oldInternalIDs := make([]imap.InternalMessageID, 0, len(messageIDs)) @@ -421,12 +418,12 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( for _, id := range messageIDs { newID, deduped, err := state.actionImportRecoveredMessage(ctx, tx, id.InternalID, mboxID.RemoteID) if err != nil { - return nil, err + return nil, nil, err } if !deduped { if err := db.MarkMessageAsDeleted(ctx, tx, id.InternalID); err != nil { - return nil, err + return nil, nil, err } } @@ -439,7 +436,7 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( { removeUpdates, err := RemoveMessagesFromMailbox(ctx, tx, state.user.GetRecoveryMailboxID().InternalID, oldInternalIDs) if err != nil { - return nil, err + return nil, nil, err } state.user.GetRecoveredMessageHashesMap().Erase(oldInternalIDs...) @@ -450,17 +447,12 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( // Label messages in destination. uidWithFlags, update, err := state.actionAddRecoveredMessagesToMailbox(ctx, tx, ids, mboxID) if err != nil { - return nil, err + return nil, nil, err } - // Publish all updates in unison. updates = append(updates, update) - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, updates...); err != nil { - return nil, err - } - - return uidWithFlags, nil + return updates, uidWithFlags, nil } // actionRemoveMessagesFromMailboxUnchecked is similar to actionRemoveMessagesFromMailbox, but it does not validate @@ -471,23 +463,18 @@ func (state *State) actionRemoveMessagesFromMailboxUnchecked( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) error { +) ([]Update, error) { internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) if mboxID.InternalID != state.user.GetRecoveryMailboxID().InternalID { if err := state.user.GetRemote().RemoveMessagesFromMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { - return err + return nil, err } } else { state.user.GetRecoveredMessageHashesMap().Erase(internalIDs...) } - updates, err := RemoveMessagesFromMailbox(ctx, tx, mboxID.InternalID, internalIDs) - if err != nil { - return err - } - - return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) + return RemoveMessagesFromMailbox(ctx, tx, mboxID.InternalID, internalIDs) } func (state *State) actionRemoveMessagesFromMailbox( @@ -495,10 +482,10 @@ func (state *State) actionRemoveMessagesFromMailbox( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) error { +) ([]Update, error) { haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) if err != nil { - return err + return nil, err } messageIDs = xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { @@ -506,7 +493,7 @@ func (state *State) actionRemoveMessagesFromMailbox( }) if len(messageIDs) == 0 { - return nil + return nil, nil } return state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, messageIDs, mboxID) @@ -517,31 +504,41 @@ func (state *State) actionMoveMessages( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxFromID, mboxToID ids.MailboxIDPair, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { + var allUpdates []Update + if mboxFromID.InternalID == mboxToID.InternalID { internalIDs, _ := ids.SplitMessageIDPairSlice(messageIDs) - return db.BumpMailboxUIDsForMessage(ctx, tx, internalIDs, mboxToID.InternalID) + uid, err := db.BumpMailboxUIDsForMessage(ctx, tx, internalIDs, mboxToID.InternalID) + if err != nil { + return nil, nil, err + } + + return nil, uid, nil } { messageIDsToAdd, err := db.FilterMailboxContains(ctx, tx.Client(), mboxToID.InternalID, messageIDs) if err != nil { - return nil, err + return nil, nil, err } if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { return slices.Contains(messageIDsToAdd, messageID.InternalID) }); len(remMessageIDs) > 0 { - if err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxToID); err != nil { - return nil, err + updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxToID) + if err != nil { + return nil, nil, err } + + allUpdates = append(allUpdates, updates...) } } messageInFromMBox, err := db.FilterMailboxContains(ctx, tx.Client(), mboxFromID.InternalID, messageIDs) if err != nil { - return nil, err + return nil, nil, err } messagesIDsToMove := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { @@ -552,19 +549,17 @@ func (state *State) actionMoveMessages( shouldRemoveOldMessages, err := state.user.GetRemote().MoveMessagesFromMailbox(ctx, remoteIDs, mboxFromID.RemoteID, mboxToID.RemoteID) if err != nil { - return nil, err + return nil, nil, err } messageUIDs, updates, err := MoveMessagesFromMailbox(ctx, tx, mboxFromID.InternalID, mboxToID.InternalID, internalIDs, state, state.imapLimits, shouldRemoveOldMessages) if err != nil { - return nil, err + return nil, nil, err } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, updates...); err != nil { - return nil, err - } + allUpdates = append(allUpdates, updates...) - return messageUIDs, nil + return allUpdates, messageUIDs, nil } func (state *State) actionAddMessageFlags( @@ -572,16 +567,12 @@ func (state *State) actionAddMessageFlags( tx *ent.Tx, messages []snapMsgWithSeq, addFlags imap.FlagSet, -) error { +) ([]Update, error) { internalMessageIDs := xslices.Map(messages, func(sm snapMsgWithSeq) imap.InternalMessageID { return sm.ID.InternalID }) - if err := state.applyMessageFlagsAdded(ctx, tx, internalMessageIDs, addFlags); err != nil { - return err - } - - return nil + return state.applyMessageFlagsAdded(ctx, tx, internalMessageIDs, addFlags) } func (state *State) actionRemoveMessageFlags( @@ -589,21 +580,20 @@ func (state *State) actionRemoveMessageFlags( tx *ent.Tx, messages []snapMsgWithSeq, remFlags imap.FlagSet, -) error { +) ([]Update, error) { internalMessageIDs := xslices.Map(messages, func(sm snapMsgWithSeq) imap.InternalMessageID { return sm.ID.InternalID }) - if err := state.applyMessageFlagsRemoved(ctx, tx, internalMessageIDs, remFlags); err != nil { - return err - } - - return nil + return state.applyMessageFlagsRemoved(ctx, tx, internalMessageIDs, remFlags) } -func (state *State) actionSetMessageFlags(ctx context.Context, tx *ent.Tx, messages []snapMsgWithSeq, setFlags imap.FlagSet) error { +func (state *State) actionSetMessageFlags(ctx context.Context, + tx *ent.Tx, + messages []snapMsgWithSeq, + setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("recent flag is read-only") + return nil, fmt.Errorf("recent flag is read-only") } internalMessageIDs := xslices.Map(messages, func(sm snapMsgWithSeq) imap.InternalMessageID { diff --git a/internal/state/mailbox.go b/internal/state/mailbox.go index f6135b52..f3ce0cc8 100644 --- a/internal/state/mailbox.go +++ b/internal/state/mailbox.go @@ -85,19 +85,19 @@ func (m *Mailbox) Count() int { } func (m *Mailbox) Flags(ctx context.Context) (imap.FlagSet, error) { - return db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { return db.GetMailboxFlags(ctx, client, m.id.InternalID) }) } func (m *Mailbox) PermanentFlags(ctx context.Context) (imap.FlagSet, error) { - return db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { return db.GetMailboxPermanentFlags(ctx, client, m.id.InternalID) }) } func (m *Mailbox) Attributes(ctx context.Context) (imap.FlagSet, error) { - return db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { return db.GetMailboxAttributes(ctx, client, m.id.InternalID) }) } @@ -147,7 +147,7 @@ func (m *Mailbox) GetMessagesWithoutFlagCount(flag string) int { } func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap.FlagSet, date time.Time) (imap.UID, error) { - if err := m.state.db().Read(ctx, func(ctx context.Context, client *ent.Client) error { + if err := stateDBRead(ctx, m.state, func(ctx context.Context, client *ent.Client) error { if messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, client, m.snap.mboxID.InternalID); err != nil { return err } else { @@ -185,7 +185,7 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. return 0, err } - if message, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { + if message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { message, err := db.GetMessageWithIDWithDeletedFlag(ctx, client, msgID) if err != nil { if ent.IsNotFound(err) { @@ -201,10 +201,10 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. } else if !message.Deleted { logrus.Debugf("Appending duplicate message with Internal ID:%v", msgID.ShortID()) // Only shuffle around messages that haven't been marked for deletion. - if res, err := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]db.UIDWithFlags, error) { + if res, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { remoteID, err := db.GetMessageRemoteIDFromID(ctx, tx.Client(), msgID) if err != nil { - return nil, err + return nil, nil, err } return m.state.actionAddMessagesToMailbox(ctx, tx, @@ -229,7 +229,7 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. } } - return db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (imap.UID, error) { + return stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.UID, error) { return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date, m.snap == m.state.snap, appendIntoDrafts) }) } @@ -245,7 +245,7 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet } // Failed to append to mailbox attempt to insert into recovery mailbox. - knownMessage, recoverErr := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (bool, error) { + knownMessage, recoverErr := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, bool, error) { return m.state.actionCreateRecoveredMessage(ctx, tx, literal, flags, date) }) if recoverErr != nil && !knownMessage { @@ -269,7 +269,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -290,7 +290,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { return m.state.actionCopyMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) } else { @@ -320,7 +320,7 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -341,7 +341,7 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { return m.state.actionMoveMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) } else { @@ -369,25 +369,19 @@ func (m *Mailbox) Store(ctx context.Context, seqSet []command.SeqRange, action c return err } - return m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { switch action { case command.StoreActionAddFlags: - if err := m.state.actionAddMessageFlags(ctx, tx, messages, flags); err != nil { - return err - } + return m.state.actionAddMessageFlags(ctx, tx, messages, flags) case command.StoreActionRemFlags: - if err := m.state.actionRemoveMessageFlags(ctx, tx, messages, flags); err != nil { - return err - } + return m.state.actionRemoveMessageFlags(ctx, tx, messages, flags) case command.StoreActionSetFlags: - if err := m.state.actionSetMessageFlags(ctx, tx, messages, flags); err != nil { - return err - } + return m.state.actionSetMessageFlags(ctx, tx, messages, flags) } - return nil + return nil, fmt.Errorf("unknown flag action") }) } @@ -411,7 +405,7 @@ func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { msgIDs = m.snap.getAllMessagesIDsMarkedDelete() } - return m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID) }) } diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index 0fa379e0..fc2e8134 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -118,7 +118,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons defer async.HandlePanic(m.state.panicHandler) msg := snapMessages[i] - message, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { + message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { return db.GetMessage(ctx, client, msg.ID.InternalID) }) if err != nil { @@ -175,7 +175,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons }) if len(msgsToBeMarkedSeen) != 0 { - if err := m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { return m.state.actionAddMessageFlags(ctx, tx, msgsToBeMarkedSeen, imap.NewFlagSet(imap.FlagSeen)) }); err != nil { return err diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 6a4cce41..038a18d6 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -88,7 +88,7 @@ func buildSearchData(ctx context.Context, m *Mailbox, op *buildSearchOpResult, m data := searchData{message: message} if op.needsMessage { - dbm, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { + dbm, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { return db.GetMessageDateAndSize(ctx, client, message.ID.InternalID) }) if err != nil { diff --git a/internal/state/state.go b/internal/state/state.go index d113fde0..3b559bab 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -76,12 +76,8 @@ func (state *State) UserID() string { return state.user.GetUserID() } -func (state *State) db() *db.DB { - return state.user.GetDB() -} - func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn func(map[string]Match) error) error { - return state.db().Read(ctx, func(ctx context.Context, client *ent.Client) error { + return stateDBRead(ctx, state, func(ctx context.Context, client *ent.Client) error { mailboxes, err := db.GetAllMailboxes(ctx, client) if err != nil { return err @@ -168,7 +164,7 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn } func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -181,15 +177,15 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } } - snap, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { return err } - if err := state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - return db.ClearRecentFlags(ctx, tx, mbox.ID) + if err := stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + return nil, db.ClearRecentFlags(ctx, tx, mbox.ID) }); err != nil { return err } @@ -201,7 +197,7 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -214,7 +210,7 @@ func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) } } - snap, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -251,13 +247,13 @@ func (state *State) Create(ctx context.Context, name string) error { } } - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { client := tx.Client() if mailboxCount, err := db.GetMailboxCount(ctx, client); err != nil { - return err + return nil, err } else if err := state.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { - return err + return nil, err } var mboxesToCreate []string @@ -268,14 +264,14 @@ func (state *State) Create(ctx context.Context, name string) error { } if exists, err := db.MailboxExistsWithName(ctx, client, name); err != nil { - return err + return nil, err } else if exists { - return ErrExistingMailbox + return nil, ErrExistingMailbox } for _, superior := range listSuperiors(name, state.delimiter) { if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { - return err + return nil, err } else if exists { continue } @@ -287,11 +283,11 @@ func (state *State) Create(ctx context.Context, name string) error { for _, mboxName := range mboxesToCreate { if err := state.actionCreateMailbox(ctx, tx, mboxName, uidValidity); err != nil { - return err + return nil, err } } - return nil + return nil, nil }) } @@ -301,13 +297,18 @@ func (state *State) Delete(ctx context.Context, name string) (bool, error) { return false, ErrOperationNotAllowed } - mboxID, err := db.WriteResult(ctx, state.db(), func(ctx context.Context, tx *ent.Tx) (imap.InternalMailboxID, error) { + mboxID, err := stateDBWriteResult(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.InternalMailboxID, error) { mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) if err != nil { - return 0, ErrNoSuchMailbox + return nil, 0, ErrNoSuchMailbox + } + + update, err := state.actionDeleteMailbox(ctx, tx, ids.NewMailboxIDPair(mbox)) + if err != nil { + return nil, 0, err } - return mbox.ID, state.actionDeleteMailbox(ctx, tx, ids.NewMailboxIDPair(mbox)) + return update, mbox.ID, nil }) if err != nil { return false, err @@ -326,27 +327,27 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { return ErrOperationNotAllowed } - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { client := tx.Client() mbox, err := db.GetMailboxByName(ctx, client, oldName) if err != nil { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } if exists, err := db.MailboxExistsWithName(ctx, client, newName); err != nil { - return err + return nil, err } else if exists { - return ErrExistingMailbox + return nil, ErrExistingMailbox } var mboxesToCreate []string for _, superior := range listSuperiors(newName, state.delimiter) { if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { - return err + return nil, err } else if exists { if superior == oldName { - return ErrExistingMailbox + return nil, ErrExistingMailbox } continue } @@ -357,16 +358,16 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { for _, m := range mboxesToCreate { uidValidity, err := state.user.GenerateUIDValidity() if err != nil { - return err + return nil, err } res, err := state.user.GetRemote().CreateMailbox(ctx, strings.Split(m, state.delimiter)) if err != nil { - return err + return nil, err } if err := db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity); err != nil { - return err + return nil, err } } @@ -375,13 +376,13 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } if err := state.actionUpdateMailbox(ctx, tx, mbox.RemoteID, newName); err != nil { - return err + return nil, err } // Locally update all inferiors so we don't wait for update mailboxes, err := db.GetAllMailboxes(ctx, tx.Client()) if err != nil { - return err + return nil, err } inferiors := listInferiors(oldName, state.delimiter, xslices.Map(mailboxes, func(mailbox *ent.Mailbox) string { @@ -391,54 +392,54 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { for _, inferior := range inferiors { mbox, err := db.GetMailboxByName(ctx, tx.Client(), inferior) if err != nil { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } newInferior := newName + strings.TrimPrefix(inferior, oldName) if err := db.RenameMailboxWithRemoteID(ctx, tx, mbox.RemoteID, newInferior); err != nil { - return err + return nil, err } } - return nil + return nil, nil }) } func (state *State) Subscribe(ctx context.Context, name string) error { - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) if err != nil { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } if mbox.Subscribed { - return ErrAlreadySubscribed + return nil, ErrAlreadySubscribed } - return mbox.Update().SetSubscribed(true).Exec(ctx) + return nil, mbox.Update().SetSubscribed(true).Exec(ctx) }) } func (state *State) Unsubscribe(ctx context.Context, name string) error { - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) if err != nil { // If mailbox does not exist, check that if it is present in the deleted subscription table if count, err := db.RemoveDeletedSubscriptionWithName(ctx, tx, name); err != nil { - return err + return nil, err } else if count == 0 { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } else { - return nil + return nil, nil } } if !mbox.Subscribed { - return ErrAlreadyUnsubscribed + return nil, ErrAlreadyUnsubscribed } - return mbox.Update().SetSubscribed(false).Exec(ctx) + return nil, mbox.Update().SetSubscribed(false).Exec(ctx) }) } @@ -454,7 +455,7 @@ func (state *State) Idle(ctx context.Context, fn func([]response.Response, chan } func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -465,7 +466,7 @@ func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) return fn(newMailbox(mbox, state, state.snap)) } - snap, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -483,7 +484,7 @@ func (state *State) AppendOnlyMailbox(ctx context.Context, name string, fn func( return ErrOperationNotAllowed } - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -504,7 +505,7 @@ func (state *State) Selected(ctx context.Context, fn func(*Mailbox) error) error return ErrSessionNotSelected } - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByID(ctx, client, state.snap.mboxID.InternalID) }) if err != nil { @@ -571,7 +572,7 @@ func (state *State) ApplyUpdate(ctx context.Context, update Update) error { return nil } - if err := state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { return update.Apply(ctx, tx, state) }); err != nil { reporter.MessageWithContext(ctx, @@ -598,29 +599,30 @@ func (state *State) markInvalid() { } // renameInbox creates a new mailbox and moves everything there. -func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mailbox, newName string) error { +func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mailbox, newName string) ([]Update, error) { uidValidity, err := state.user.GenerateUIDValidity() if err != nil { - return err + return nil, err } mbox, err := state.actionCreateAndGetMailbox(ctx, tx, newName, uidValidity) if err != nil { - return err + return nil, err } messageIDs, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), inbox.ID) if err != nil { - return err + return nil, err } mboxIDPair := ids.NewMailboxIDPair(mbox) - if _, err := state.actionMoveMessages(ctx, tx, messageIDs, ids.NewMailboxIDPair(inbox), mboxIDPair); err != nil { - return err + updates, _, err := state.actionMoveMessages(ctx, tx, messageIDs, ids.NewMailboxIDPair(inbox), mboxIDPair) + if err != nil { + return nil, err } - return nil + return updates, nil } func (state *State) beginIdle(ctx context.Context) ([]response.Response, error) { @@ -708,7 +710,7 @@ func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]r } } - if err := state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { for _, update := range dbUpdates { if err := update.apply(ctx, tx); err != nil { return err @@ -830,3 +832,59 @@ func (state *State) close() error { return nil } + +func stateDBRead(ctx context.Context, state *State, fn func(context.Context, *ent.Client) error) error { + return state.user.GetDB().Read(ctx, fn) +} + +func stateDBReadResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Client) (T, error)) (T, error) { + return db.ReadResult(ctx, state.user.GetDB(), fn) +} + +func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, error)) error { + var updates []Update + + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + up, err := fn(ctx, tx) + updates = up + return err + }); err != nil { + return err + } + + // need to create a separate transaction for the state updates so that import changes get written first. + if len(updates) != 0 { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) + }); err != nil { + return err + } + } + + return nil +} + +func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, T, error)) (T, error) { + var updates []Update + + result, err := db.WriteResult(ctx, state.user.GetDB(), func(ctx context.Context, tx *ent.Tx) (T, error) { + up, val, err := fn(ctx, tx) + updates = up + return val, err + }) + if err != nil { + var t T + return t, err + } + + // need to create a separate transaction for the state updates so that import changes get written first. + if len(updates) != 0 { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) + }); err != nil { + return result, err + } + } + + return result, nil +} diff --git a/internal/state/updates.go b/internal/state/updates.go index 8dcb89db..d8e0b783 100644 --- a/internal/state/updates.go +++ b/internal/state/updates.go @@ -98,9 +98,9 @@ func (u *messageFlagsAddedStateUpdate) String() string { func (state *State) applyMessageFlagsAdded(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, - addFlags imap.FlagSet) error { + addFlags imap.FlagSet) ([]Update, error) { if addFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("the recent flag is read-only") + return nil, fmt.Errorf("the recent flag is read-only") } // Since DB state can be more up to date then the flag state we should only emit add flag updates for values @@ -110,7 +110,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) if err != nil { - return err + return nil, err } // If setting messages as seen, only set those messages that aren't currently seen. @@ -125,7 +125,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesSeen(ctx, messagesToApply, true); err != nil { - return err + return nil, err } } } @@ -142,7 +142,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesFlagged(ctx, messagesToApply, true); err != nil { - return err + return nil, err } } } @@ -151,7 +151,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, if addFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, true); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(newMessageFlagsAddedStateUpdate(imap.NewFlagSet(imap.FlagDeleted), state.snap.mboxID, messageIDs, state.StateID)) @@ -170,17 +170,13 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, } if err := db.AddMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(newMessageFlagsAddedStateUpdate(remainingFlags, state.snap.mboxID, messagesToFlag, state.StateID)) } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, flagStateUpdate); err != nil { - return err - } - - return nil + return []Update{flagStateUpdate}, nil } type messageFlagsRemovedStateUpdate struct { @@ -229,16 +225,19 @@ func (u *messageFlagsRemovedStateUpdate) String() string { } // applyMessageFlagsRemoved removes the flags from the given messages. -func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, remFlags imap.FlagSet) error { +func (state *State) applyMessageFlagsRemoved(ctx context.Context, + tx *ent.Tx, + messageIDs []imap.InternalMessageID, + remFlags imap.FlagSet) ([]Update, error) { if remFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("the recent flag is read-only") + return nil, fmt.Errorf("the recent flag is read-only") } client := tx.Client() curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) if err != nil { - return err + return nil, err } // If setting messages as unseen, only set those messages that are currently seen. if remFlags.ContainsUnchecked(imap.FlagSeenLowerCase) { @@ -252,7 +251,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesSeen(ctx, messagesToApply, false); err != nil { - return err + return nil, err } } } @@ -269,7 +268,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesFlagged(ctx, messagesToApply, false); err != nil { - return err + return nil, err } } } @@ -278,7 +277,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me if remFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, false); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(NewMessageFlagsRemovedStateUpdate(imap.NewFlagSet(imap.FlagDeleted), state.snap.mboxID, messageIDs, state.StateID)) @@ -297,17 +296,13 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me } if err := db.RemoveMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(NewMessageFlagsRemovedStateUpdate(remainingFlags, state.snap.mboxID, messagesToFlag, state.StateID)) } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, flagStateUpdate); err != nil { - return err - } - - return nil + return []Update{flagStateUpdate}, nil } type messageFlagsSetStateUpdate struct { @@ -356,18 +351,21 @@ func (u *messageFlagsSetStateUpdate) String() string { } // applyMessageFlagsSet sets the flags of the given messages. -func (state *State) applyMessageFlagsSet(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) error { +func (state *State) applyMessageFlagsSet(ctx context.Context, + tx *ent.Tx, + messageIDs []imap.InternalMessageID, + setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("the recent flag is read-only") + return nil, fmt.Errorf("the recent flag is read-only") } if state.snap == nil { - return nil + return nil, nil } curFlags, err := db.GetMessageFlags(ctx, tx.Client(), messageIDs) if err != nil { - return err + return nil, err } // If setting messages as seen, only set those messages that aren't currently seen, and vice versa. @@ -381,7 +379,7 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, tx *ent.Tx, messag for seen, messageIDs := range setSeen { if err := state.user.GetRemote().SetMessagesSeen(ctx, messageIDs, seen); err != nil { - return err + return nil, err } } @@ -396,23 +394,19 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, tx *ent.Tx, messag for flagged, messageIDs := range setFlagged { if err := state.user.GetRemote().SetMessagesFlagged(ctx, messageIDs, flagged); err != nil { - return err + return nil, err } } if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, setFlags.Contains(imap.FlagDeleted)); err != nil { - return err + return nil, err } if err := db.SetMessageFlags(ctx, tx, messageIDs, setFlags.Remove(imap.FlagDeleted)); err != nil { - return err - } - - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, NewMessageFlagsSetStateUpdate(setFlags, state.snap.mboxID, messageIDs, state.StateID)); err != nil { - return err + return nil, err } - return nil + return []Update{NewMessageFlagsSetStateUpdate(setFlags, state.snap.mboxID, messageIDs, state.StateID)}, nil } type mailboxRemoteIDUpdateStateUpdate struct { diff --git a/internal/state/updates_remote.go b/internal/state/updates_remote.go index 8aa4aebb..2f4e3d5f 100644 --- a/internal/state/updates_remote.go +++ b/internal/state/updates_remote.go @@ -7,7 +7,6 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" ) type RemoteAddMessageFlagsStateUpdate struct { @@ -54,21 +53,3 @@ type RemoteMessageDeletedStateUpdate struct { MessageIDStateFilter remoteID imap.MessageID } - -func NewRemoteMessageDeletedStateUpdate(messageID imap.InternalMessageID, remoteID imap.MessageID) Update { - return &RemoteMessageDeletedStateUpdate{ - MessageIDStateFilter: MessageIDStateFilter{MessageID: messageID}, - remoteID: remoteID, - } -} - -func (u *RemoteMessageDeletedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { - return s.actionRemoveMessagesFromMailbox(ctx, tx, []ids.MessageIDPair{{ - InternalID: u.MessageID, - RemoteID: u.remoteID, - }}, s.snap.mboxID) -} - -func (u *RemoteMessageDeletedStateUpdate) String() string { - return fmt.Sprintf("RemoteMessageDeletedStateUpdate %v remote ID = %v", u.MessageIDStateFilter.String(), u.remoteID) -}