diff --git a/connector/connector.go b/connector/connector.go index a858f3bd..fdbae1b0 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -15,6 +15,10 @@ type Connector interface { // CreateMailbox creates a mailbox with the given name. CreateMailbox(ctx context.Context, name []string) (imap.Mailbox, error) + // GetMessageLiteral is intended to be used by Gluon when, for some reason, the local cached data no longer exists. + // Note: this can get called from different go routines. + GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) + // IsMailboxVisible can be used to hide mailboxes from connected clients. IsMailboxVisible(ctx context.Context, mboxID imap.MailboxID) bool diff --git a/connector/dummy.go b/connector/dummy.go index e9c7e8e7..d61a5733 100644 --- a/connector/dummy.go +++ b/connector/dummy.go @@ -158,6 +158,10 @@ func (conn *Dummy) DeleteMailbox(ctx context.Context, mboxID imap.MailboxID) err return nil } +func (conn *Dummy) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) { + return conn.state.tryGetLiteral(id) +} + func (conn *Dummy) CreateMessage(ctx context.Context, mboxID imap.MailboxID, literal []byte, flags imap.FlagSet, date time.Time) (imap.Message, []byte, error) { // NOTE: We are only recording this here since it was the easiest command to verify the data has been record properly // in the context, as APPEND will always require a communication with the remote connector. diff --git a/connector/dummy_state.go b/connector/dummy_state.go index b5dea8a4..b0af5a94 100644 --- a/connector/dummy_state.go +++ b/connector/dummy_state.go @@ -171,6 +171,18 @@ func (state *dummyState) getLiteral(messageID imap.MessageID) []byte { return state.messages[messageID].literal } +func (state *dummyState) tryGetLiteral(messageID imap.MessageID) ([]byte, error) { + state.lock.Lock() + defer state.lock.Unlock() + + v, ok := state.messages[messageID] + if !ok { + return nil, ErrNoSuchMessage + } + + return v.literal, nil +} + func (state *dummyState) createMessage( mboxID imap.MailboxID, literal []byte, diff --git a/internal/backend/state_connector_impl.go b/internal/backend/state_connector_impl.go index d854dbdf..9cc15c06 100644 --- a/internal/backend/state_connector_impl.go +++ b/internal/backend/state_connector_impl.go @@ -75,6 +75,12 @@ func (sc *stateConnectorImpl) CreateMessage( return imap.NewInternalMessageID(), msg, newLiteral, nil } +func (sc *stateConnectorImpl) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) { + ctx = sc.newContextWithMetadata(ctx) + + return sc.connector.GetMessageLiteral(ctx, id) +} + func (sc *stateConnectorImpl) AddMessagesToMailbox( ctx context.Context, messageIDs []imap.MessageID, diff --git a/internal/backend/state_user_interface_impl.go b/internal/backend/state_user_interface_impl.go index 0710483b..65c20a31 100644 --- a/internal/backend/state_user_interface_impl.go +++ b/internal/backend/state_user_interface_impl.go @@ -38,7 +38,7 @@ func (s *StateUserInterfaceImpl) GetRemote() state.Connector { return s.c } -func (s *StateUserInterfaceImpl) GetStore() store.Store { +func (s *StateUserInterfaceImpl) GetStore() *store.WriteControlledStore { return s.u.store } diff --git a/internal/backend/user.go b/internal/backend/user.go index bab3301f..61e07549 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -25,7 +25,7 @@ type user struct { connector connector.Connector updateInjector *updateInjector - store store.Store + store *store.WriteControlledStore delimiter string db *db.DB @@ -49,7 +49,7 @@ func newUser( userID string, database *db.DB, conn connector.Connector, - store store.Store, + st store.Store, delimiter string, imapLimits limits.IMAP, ) (*user, error) { @@ -79,7 +79,7 @@ func newUser( connector: conn, updateInjector: newUpdateInjector(conn, userID), - store: store, + store: store.NewWriteControlledStore(st), delimiter: delimiter, db: database, diff --git a/internal/state/actions.go b/internal/state/actions.go index 9a85a5c6..ee8fdfd7 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -145,7 +145,7 @@ func (state *State) actionCreateMessage( return 0, fmt.Errorf("failed to set internal ID: %w", err) } - if err := state.user.GetStore().Set(internalID, literalWithHeader); err != nil { + if err := state.user.GetStore().SetUnchecked(internalID, literalWithHeader); err != nil { return 0, fmt.Errorf("failed to store message literal: %w", err) } @@ -196,7 +196,7 @@ func (state *State) actionCreateRecoveredMessage( return err } - if err := state.user.GetStore().Set(internalID, literal); err != nil { + if err := state.user.GetStore().SetUnchecked(internalID, literal); err != nil { return fmt.Errorf("failed to store message literal: %w", err) } @@ -343,7 +343,7 @@ func (state *State) actionImportRecoveredMessage( return ids.MessageIDPair{}, false, fmt.Errorf("failed to set internal ID: %w", err) } - if err := state.user.GetStore().Set(internalID, literalWithHeader); err != nil { + if err := state.user.GetStore().SetUnchecked(internalID, literalWithHeader); err != nil { return ids.MessageIDPair{}, false, fmt.Errorf("failed to store message literal: %w", err) } diff --git a/internal/state/connector.go b/internal/state/connector.go index 8c495398..196f9601 100644 --- a/internal/state/connector.go +++ b/internal/state/connector.go @@ -41,6 +41,10 @@ type Connector interface { date time.Time, ) (imap.InternalMessageID, imap.Message, []byte, error) + // GetMessageLiteral retrieves the message literal from the connector. + // Note: this can get called from different go routines. + GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) + // AddMessagesToMailbox adds the message with the given ID to the mailbox with the given ID. AddMessagesToMailbox( ctx context.Context, diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index 94e5b473..33264694 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -111,7 +111,7 @@ func (m *Mailbox) Fetch(ctx context.Context, seq *proto.SequenceSet, attributes var literal []byte if needsLiteral { - l, err := m.state.getLiteral(msg.ID.InternalID) + l, err := m.state.getLiteral(ctx, msg.ID) if err != nil { return err } diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index fd6ae5d4..ab7d4f70 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -97,7 +97,7 @@ func buildSearchData(ctx context.Context, m *Mailbox, op *buildSearchOpResult, m } if op.needsLiteral { - l, err := m.state.getLiteral(message.ID.InternalID) + l, err := m.state.getLiteral(ctx, message.ID) if err != nil { return searchData{}, err } diff --git a/internal/state/state.go b/internal/state/state.go index f400b5c6..a21054ec 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -621,8 +621,32 @@ func (state *State) endIdle() { state.idleCh = nil } -func (state *State) getLiteral(messageID imap.InternalMessageID) ([]byte, error) { - return state.user.GetStore().Get(messageID) +func (state *State) getLiteral(ctx context.Context, messageID ids.MessageIDPair) ([]byte, error) { + var literal []byte + + storeLiteral, firstErr := state.user.GetStore().Get(messageID.InternalID) + if firstErr != nil { + logrus.Debugf("Failed load %v from store, attempting to download from connector", messageID.InternalID.ShortID()) + + connectorLiteral, err := state.user.GetRemote().GetMessageLiteral(ctx, messageID.RemoteID) + if err != nil { + logrus.Errorf("Failed to download message from connector: %v", err) + return nil, fmt.Errorf("message failed to load from cache (%v), failed to download from connector: %w", firstErr, err) + } + + if err := state.user.GetStore().Set(messageID.InternalID, connectorLiteral); err != nil { + logrus.Errorf("Failed to store download message from connector: %v", err) + return nil, fmt.Errorf("message failed to load from cache (%v), failed to store new downloaded message: %w", firstErr, err) + } + + logrus.Debugf("Message %v downloaded and stored ", messageID.InternalID.ShortID()) + + literal = connectorLiteral + } else { + literal = storeLiteral + } + + return literal, nil } func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]response.Response, error) { diff --git a/internal/state/user_interface.go b/internal/state/user_interface.go index 58ee1b12..1df384ed 100644 --- a/internal/state/user_interface.go +++ b/internal/state/user_interface.go @@ -22,7 +22,7 @@ type UserInterface interface { GetRemote() Connector - GetStore() store.Store + GetStore() *store.WriteControlledStore QueueOrApplyStateUpdate(ctx context.Context, tx *ent.Tx, update ...Update) error diff --git a/store/disk.go b/store/disk.go index b10ef8e0..f746a1b4 100644 --- a/store/disk.go +++ b/store/disk.go @@ -95,6 +95,10 @@ func (c *onDiskStore) Set(messageID imap.InternalMessageID, b []byte) error { b = enc } + if err := os.MkdirAll(c.path, 0o700); err != nil { + return err + } + return os.WriteFile( filepath.Join(c.path, messageID.String()), c.gcm.Seal(nonce, nonce, b, nil), diff --git a/store/write_controlled_store.go b/store/write_controlled_store.go new file mode 100644 index 00000000..3a4099d7 --- /dev/null +++ b/store/write_controlled_store.go @@ -0,0 +1,140 @@ +package store + +import ( + "github.com/ProtonMail/gluon/imap" + "sync" + "sync/atomic" +) + +type syncRef struct { + lock sync.RWMutex + counter int32 +} + +// WriteControlledStore ensures that a given file on disk can safely be accessed by multiple readers and only +// one writer. Internally we maintain a list of RWLocks per message ID. +type WriteControlledStore struct { + impl Store + + lock sync.Mutex + entryTable map[imap.InternalMessageID]*syncRef + lockPool []*syncRef +} + +func NewWriteControlledStore(impl Store) *WriteControlledStore { + return &WriteControlledStore{ + impl: impl, + entryTable: make(map[imap.InternalMessageID]*syncRef), + } +} + +func (w *WriteControlledStore) acquireSyncRef(id imap.InternalMessageID) *syncRef { + w.lock.Lock() + defer w.lock.Unlock() + + v, ok := w.entryTable[id] + if !ok { + var s *syncRef + + if len(w.lockPool) != 0 { + s = w.lockPool[0] + s.counter = 1 + w.lockPool = w.lockPool[1:] + } else { + s = &syncRef{counter: 1} + } + + w.entryTable[id] = s + + return s + } + + atomic.AddInt32(&v.counter, 1) + + return v +} + +func (w *WriteControlledStore) releaseSyncRef(id imap.InternalMessageID, ref *syncRef) { + if atomic.AddInt32(&ref.counter, -1) <= 0 { + w.lock.Lock() + defer w.lock.Unlock() + + if atomic.LoadInt32(&ref.counter) <= 0 { + delete(w.entryTable, id) + w.lockPool = append(w.lockPool, ref) + } + } +} + +func (w *WriteControlledStore) Get(messageID imap.InternalMessageID) ([]byte, error) { + syncRef := w.acquireSyncRef(messageID) + defer w.releaseSyncRef(messageID, syncRef) + + syncRef.lock.RLock() + defer syncRef.lock.RUnlock() + + return w.impl.Get(messageID) +} + +func (w *WriteControlledStore) Set(messageID imap.InternalMessageID, literal []byte) error { + syncRef := w.acquireSyncRef(messageID) + defer w.releaseSyncRef(messageID, syncRef) + + syncRef.lock.Lock() + defer syncRef.lock.Unlock() + + return w.impl.Set(messageID, literal) +} + +// SetUnchecked allows the user to bypass lock access. This will only work if you can guarantee that the data being +// set does not previously exit (e.g: New message). +func (w *WriteControlledStore) SetUnchecked(messageID imap.InternalMessageID, literal []byte) error { + return w.impl.Set(messageID, literal) +} + +func (w *WriteControlledStore) Delete(messageID ...imap.InternalMessageID) error { + for _, id := range messageID { + if err := func() error { + syncRef := w.acquireSyncRef(id) + defer w.releaseSyncRef(id, syncRef) + + syncRef.lock.Lock() + defer syncRef.lock.Unlock() + + return w.impl.Delete(messageID...) + }(); err != nil { + return err + } + } + + return nil +} + +func (w *WriteControlledStore) Close() error { + return w.impl.Close() +} + +func (w *WriteControlledStore) List() ([]imap.InternalMessageID, error) { + return w.impl.List() +} + +type WriteControlledStoreBuilder struct { + builder Builder +} + +func NewWriteControlledStoreBuilder(builder Builder) *WriteControlledStoreBuilder { + return &WriteControlledStoreBuilder{builder: builder} +} + +func (w *WriteControlledStoreBuilder) New(dir, userID string, passphrase []byte) (Store, error) { + impl, err := w.builder.New(dir, userID, passphrase) + if err != nil { + return nil, err + } + + return NewWriteControlledStore(impl), nil +} + +func (w *WriteControlledStoreBuilder) Delete(dir, userID string) error { + return w.builder.Delete(dir, userID) +} diff --git a/store/write_controlled_store_test.go b/store/write_controlled_store_test.go new file mode 100644 index 00000000..642fd345 --- /dev/null +++ b/store/write_controlled_store_test.go @@ -0,0 +1,63 @@ +package store + +import ( + "bytes" + "github.com/ProtonMail/gluon/imap" + "github.com/stretchr/testify/require" + "sync" + "testing" +) + +func TestWriteControlledStore(t *testing.T) { + id1 := imap.NewInternalMessageID() + id2 := imap.NewInternalMessageID() + id3 := imap.NewInternalMessageID() + + st, err := NewOnDiskStore( + t.TempDir(), + []byte("pass"), + WithCompressor(nil), + ) + require.NoError(t, err) + + st = NewWriteControlledStore(st) + + wg := sync.WaitGroup{} + + for i := 0; i < 256; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + var id imap.InternalMessageID + + switch i % 3 { + case 0: + require.NoError(t, st.Set(id1, []byte("literal1"))) + id = id1 + case 1: + require.NoError(t, st.Set(id2, []byte("literal2"))) + id = id2 + case 2: + require.NoError(t, st.Set(id3, []byte("literal3"))) + id = id3 + } + + require.NotEmpty(t, id, imap.InternalMessageID{}) + + // It's not guaranteed which version of the literal will be available on disk, but it should be + // match one of the following + literal, err := st.Get(id) + require.NoError(t, err) + + isEqual := bytes.Equal([]byte("literal1"), literal) || + bytes.Equal([]byte("literal2"), literal) || + bytes.Equal([]byte("literal3"), literal) + + require.True(t, isEqual) + }(i) + } + + wg.Wait() +} diff --git a/tests/cache_reset_test.go b/tests/cache_reset_test.go new file mode 100644 index 00000000..68b7d3d1 --- /dev/null +++ b/tests/cache_reset_test.go @@ -0,0 +1,61 @@ +package tests + +import ( + goimap "github.com/emersion/go-imap" + "github.com/emersion/go-imap/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "os" + "testing" + "time" +) + +func TestFetchWhenFileDeletedFromCache(t *testing.T) { + runOneToOneTestClientWithAuth(t, defaultServerOptions(t), func(client *client.Client, s *testSession) { + // create message + require.NoError(t, doAppendWithClientFromFile(t, client, "INBOX", "testdata/afternoon-meeting.eml", time.Now())) + + // delete message from cache + require.NoError(t, os.RemoveAll(s.options.dataDir)) + + status, err := client.Select("INBOX", false) + require.NoError(t, err) + assert.Equal(t, uint32(1), status.Messages) + + // Load message + fullMessageBytes, err := os.ReadFile("testdata/afternoon-meeting.eml") + require.NoError(t, err) + fullMessage := string(fullMessageBytes) + + newFetchCommand(t, client).withItems(goimap.FetchRFC822).fetch("1").forSeqNum(1, func(validator *validatorBuilder) { + validator.ignoreFlags() + validator.wantSectionString(goimap.FetchRFC822, func(t testing.TB, literal string) { + messageFromSection := skipGLUONHeader(literal) + require.Equal(t, fullMessage, messageFromSection) + }) + }).checkAndRequireMessageCount(1) + }) +} + +func TestSearchWhenFileDeletedFromCache(t *testing.T) { + runOneToOneTestClientWithAuth(t, defaultServerOptions(t), func(client *client.Client, s *testSession) { + // create message + require.NoError(t, doAppendWithClientFromFile(t, client, "INBOX", "testdata/afternoon-meeting.eml", time.Now())) + + // delete message from cache + require.NoError(t, os.RemoveAll(s.options.dataDir)) + + status, err := client.Select("INBOX", false) + require.NoError(t, err) + assert.Equal(t, uint32(1), status.Messages) + + searchCriteria := goimap.NewSearchCriteria() + searchCriteria.Text = append(searchCriteria.Text, "3:30") + + seqs, err := client.Search(searchCriteria) + require.NoError(t, err) + require.Equal(t, 1, len(seqs)) + require.Equal(t, uint32(1), seqs[0]) + + }) +}