Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix double-syncing error where devices receive entries from themselves #202 #204

Merged
merged 7 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions backend/server/internal/database/historyentries.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (db *DB) AllHistoryEntriesForUser(ctx context.Context, userID string) ([]*s

func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
var historyEntries []*shared.EncHistoryEntry
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ?", deviceID, limit).Find(&historyEntries)
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ? AND NOT is_from_same_device", deviceID, limit).Find(&historyEntries)

if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -52,12 +52,13 @@ func (db *DB) AddHistoryEntries(ctx context.Context, entries ...*shared.EncHisto
})
}

func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, devices []*Device, entries []*shared.EncHistoryEntry) error {
func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, sourceDeviceId string, devices []*Device, entries []*shared.EncHistoryEntry) error {
chunkSize := 1000
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, device := range devices {
for _, entry := range entries {
entry.DeviceId = device.DeviceId
entry.IsFromSameDevice = sourceDeviceId == device.DeviceId
}
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
for _, entriesChunk := range shared.Chunks(entries, chunkSize) {
Expand Down
13 changes: 7 additions & 6 deletions backend/server/internal/server/api_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
}
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))

err = s.db.AddHistoryEntriesForAllDevices(r.Context(), devices, entries)
sourceDeviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
err = s.db.AddHistoryEntriesForAllDevices(r.Context(), sourceDeviceId, devices, entries)
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
}
Expand All @@ -49,21 +50,20 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {

resp := shared.SubmitResponse{}

deviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
if deviceId != "" {
if sourceDeviceId != "" {
hv, err := shared.ParseVersionString(version)
if err != nil || hv.GreaterThan(shared.ParsedVersion{MinorVersion: 0, MajorVersion: 221}) {
// Note that if we fail to parse the version string, we do return dump and deletion requests. This is necessary
// since tests run with v0.Unknown which obviously fails to parse.
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, sourceDeviceId)
checkGormError(err)
resp.DumpRequests = dumpRequests

deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, sourceDeviceId)
checkGormError(err)
resp.DeletionRequests = deletionRequests

checkGormError(s.db.DeletionRequestInc(r.Context(), userId, deviceId))
checkGormError(s.db.DeletionRequestInc(r.Context(), userId, sourceDeviceId))
}
}

Expand All @@ -73,6 +73,7 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
// TODO: Update this to filter out duplicate entries
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
version := getHishtoryVersion(r)
Expand Down
155 changes: 39 additions & 116 deletions backend/server/internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestESubmitThenQuery(t *testing.T) {
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)

// Submit a few entries for different devices
// Submit an entry from device 1
entry := testutils.MakeFakeHistoryEntry("ls ~/")
encEntry, err := data.EncryptHistoryEntry("key", entry)
require.NoError(t, err)
Expand All @@ -85,7 +85,7 @@ func TestESubmitThenQuery(t *testing.T) {
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)

// Query for device id 1
// Query for device id 1, no results returned
w = httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
s.apiQueryHandler(w, searchReq)
Expand All @@ -96,16 +96,9 @@ func TestESubmitThenQuery(t *testing.T) {
require.NoError(t, err)
var retrievedEntries []*shared.EncHistoryEntry
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
require.Equal(t, 1, len(retrievedEntries))
dbEntry := retrievedEntries[0]
require.Equal(t, devId1, dbEntry.DeviceId)
require.Equal(t, data.UserId("key"), dbEntry.UserId)
require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
require.NoError(t, err)
require.Equal(t, decEntry, entry)
require.Equal(t, 0, len(retrievedEntries))

// Same for device id 2
// Query for device id 2 and the entry is found
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
s.apiQueryHandler(w, searchReq)
Expand All @@ -114,20 +107,12 @@ func TestESubmitThenQuery(t *testing.T) {
respBody, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
dbEntry = retrievedEntries[0]
if dbEntry.DeviceId != devId2 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != data.UserId("key") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 0 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
decEntry, err = data.DecryptHistoryEntry("key", *dbEntry)
require.Len(t, retrievedEntries, 1)
dbEntry := retrievedEntries[0]
require.Equal(t, dbEntry.DeviceId, devId2)
require.Equal(t, dbEntry.UserId, data.UserId("key"))
require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
require.NoError(t, err)
require.Equal(t, decEntry, entry)

Expand All @@ -140,9 +125,7 @@ func TestESubmitThenQuery(t *testing.T) {
respBody, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
}
require.Len(t, retrievedEntries, 2)

// Assert that we aren't leaking connections
assertNoLeakedConnections(t, DB)
Expand Down Expand Up @@ -177,16 +160,10 @@ func TestDumpRequestAndResponse(t *testing.T) {
require.NoError(t, err)
var dumpRequests []*shared.DumpRequest
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
if len(dumpRequests) != 1 {
t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
}
require.Len(t, dumpRequests, 1)
dumpRequest := dumpRequests[0]
if dumpRequest.RequestingDeviceId != devId2 {
t.Fatalf("unexpected device ID")
}
if dumpRequest.UserId != userId {
t.Fatalf("unexpected user ID")
}
require.Equal(t, devId2, dumpRequest.RequestingDeviceId)
require.Equal(t, userId, dumpRequest.UserId)

// And one for otherUser
w = httptest.NewRecorder()
Expand All @@ -197,16 +174,10 @@ func TestDumpRequestAndResponse(t *testing.T) {
require.NoError(t, err)
dumpRequests = make([]*shared.DumpRequest, 0)
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
if len(dumpRequests) != 1 {
t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
}
require.Len(t, dumpRequests, 1)
dumpRequest = dumpRequests[0]
if dumpRequest.RequestingDeviceId != otherDev2 {
t.Fatalf("unexpected device ID")
}
if dumpRequest.UserId != otherUser {
t.Fatalf("unexpected user ID")
}
require.Equal(t, otherDev2, dumpRequest.RequestingDeviceId)
require.Equal(t, otherUser, dumpRequest.UserId)

// And none if we query for a user ID that doesn't exit
w = httptest.NewRecorder()
Expand Down Expand Up @@ -270,16 +241,10 @@ func TestDumpRequestAndResponse(t *testing.T) {
require.NoError(t, err)
dumpRequests = make([]*shared.DumpRequest, 0)
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
if len(dumpRequests) != 1 {
t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
}
require.Len(t, dumpRequests, 1)
dumpRequest = dumpRequests[0]
if dumpRequest.RequestingDeviceId != otherDev2 {
t.Fatalf("unexpected device ID")
}
if dumpRequest.UserId != otherUser {
t.Fatalf("unexpected user ID")
}
require.Equal(t, otherDev2, dumpRequest.RequestingDeviceId)
require.Equal(t, otherUser, dumpRequest.UserId)

// And finally, query to ensure that the dumped entries are in the DB
w = httptest.NewRecorder()
Expand All @@ -291,19 +256,11 @@ func TestDumpRequestAndResponse(t *testing.T) {
require.NoError(t, err)
var retrievedEntries []*shared.EncHistoryEntry
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
}
require.Len(t, retrievedEntries, 2)
for _, dbEntry := range retrievedEntries {
if dbEntry.DeviceId != devId2 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != userId {
t.Fatalf("Response contains an incorrect user ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 0 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
require.Equal(t, devId2, dbEntry.DeviceId)
require.Equal(t, userId, dbEntry.UserId)
require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
require.NoError(t, err)
require.True(t, assert.ObjectsAreEqual(decEntry, entry1Dec) || assert.ObjectsAreEqual(decEntry, entry2Dec))
Expand Down Expand Up @@ -345,7 +302,6 @@ func TestDeletionRequests(t *testing.T) {
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)

// And another entry for user1
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
Expand All @@ -359,7 +315,6 @@ func TestDeletionRequests(t *testing.T) {
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)

// And an entry for user2 that has the same timestamp as the previous entry
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
Expand All @@ -374,7 +329,6 @@ func TestDeletionRequests(t *testing.T) {
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)

// Query for device id 1
w = httptest.NewRecorder()
Expand All @@ -386,19 +340,11 @@ func TestDeletionRequests(t *testing.T) {
require.NoError(t, err)
var retrievedEntries []*shared.EncHistoryEntry
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
require.Len(t, retrievedEntries, 1)
for _, dbEntry := range retrievedEntries {
if dbEntry.DeviceId != devId1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != data.UserId("dkey") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 0 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
require.Equal(t, devId1, dbEntry.DeviceId)
require.Equal(t, data.UserId("dkey"), dbEntry.UserId)
require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
require.NoError(t, err)
require.True(t, assert.ObjectsAreEqual(decEntry, entry1) || assert.ObjectsAreEqual(decEntry, entry2))
Expand Down Expand Up @@ -428,19 +374,11 @@ func TestDeletionRequests(t *testing.T) {
respBody, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
require.Len(t, retrievedEntries, 1)
dbEntry := retrievedEntries[0]
if dbEntry.DeviceId != devId1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != data.UserId("dkey") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
require.Equal(t, devId1, dbEntry.DeviceId)
require.Equal(t, data.UserId("dkey"), dbEntry.UserId)
require.Equal(t, 1, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
require.NoError(t, err)
require.Equal(t, decEntry, entry2)
Expand All @@ -454,19 +392,11 @@ func TestDeletionRequests(t *testing.T) {
respBody, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
require.Len(t, retrievedEntries, 1)
dbEntry = retrievedEntries[0]
if dbEntry.DeviceId != otherDev1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != data.UserId("dOtherkey") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 0 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
require.Equal(t, otherDev1, dbEntry.DeviceId)
require.Equal(t, data.UserId("dOtherkey"), dbEntry.UserId)
require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err = data.DecryptHistoryEntry("dOtherkey", *dbEntry)
require.NoError(t, err)
require.Equal(t, decEntry, entry3)
Expand All @@ -481,7 +411,6 @@ func TestDeletionRequests(t *testing.T) {
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.NotEmpty(t, deserializeSubmitResponse(t, w).DeletionRequests)
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)

// Query for deletion requests
w = httptest.NewRecorder()
Expand All @@ -493,9 +422,7 @@ func TestDeletionRequests(t *testing.T) {
require.NoError(t, err)
var deletionRequests []*shared.DeletionRequest
require.NoError(t, json.Unmarshal(respBody, &deletionRequests))
if len(deletionRequests) != 1 {
t.Fatalf("received %d deletion requests, expected only one", len(deletionRequests))
}
require.Len(t, deletionRequests, 1)
deletionRequest := deletionRequests[0]
expected := shared.DeletionRequest{
UserId: data.UserId("dkey"),
Expand All @@ -518,16 +445,12 @@ func TestHealthcheck(t *testing.T) {
s := NewServer(DB, TrackUsageData(true))
w := httptest.NewRecorder()
s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
if w.Code != 200 {
t.Fatalf("expected 200 resp code for healthCheckHandler")
}
require.Equal(t, 200, w.Code)
res := w.Result()
defer res.Body.Close()
respBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
if string(respBody) != "OK" {
t.Fatalf("expected healthcheckHandler to return OK")
}
require.Equal(t, "OK", string(respBody))

// Assert that we aren't leaking connections
assertNoLeakedConnections(t, DB)
Expand Down
4 changes: 3 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2996,7 +2996,9 @@ func BenchmarkQuery(b *testing.B) {
// Benchmarked code:
b.StartTimer()
ctx := hctx.MakeContext()
_, err := lib.Search(ctx, hctx.GetDb(ctx), "echo", 100)
err := lib.RetrieveAdditionalEntriesFromRemote(ctx, "tui")
require.NoError(b, err)
_, err = lib.Search(ctx, hctx.GetDb(ctx), "echo", 100)
require.NoError(b, err)
b.StopTimer()
}
Expand Down
4 changes: 2 additions & 2 deletions client/lib/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ func ApiGet(ctx context.Context, path string) ([]byte, error) {
return nil, fmt.Errorf("failed to read response body from GET %s%s: %w", GetServerHostname(), path, err)
}
duration := time.Since(start)
hctx.GetLogger().Infof("ApiGet(%#v): %s\n", path, duration.String())
hctx.GetLogger().Infof("ApiGet(%#v): %d bytes - %s\n", GetServerHostname()+path, len(respBody), duration.String())
return respBody, nil
}

Expand Down Expand Up @@ -511,7 +511,7 @@ func ApiPost(ctx context.Context, path, contentType string, reqBody []byte) ([]b
return nil, fmt.Errorf("failed to read response body from POST %s: %w", GetServerHostname()+path, err)
}
duration := time.Since(start)
hctx.GetLogger().Infof("ApiPost(%#v): %s\n", GetServerHostname()+path, duration.String())
hctx.GetLogger().Infof("ApiPost(%#v): %d bytes - %s\n", GetServerHostname()+path, len(respBody), duration.String())
return respBody, nil
}

Expand Down
Loading
Loading