diff --git a/UserRepository.go b/UserRepository.go index 84c8009..0cdd822 100644 --- a/UserRepository.go +++ b/UserRepository.go @@ -23,21 +23,26 @@ const ClientUserPrefix = "cu" const UserScanBatchSize = 500 -func (repository *UserRepository) SaveUser(clientUserId string, student *models.Student) (err error, hasChanges bool) { +func (repository *UserRepository) SaveUser(clientUserId string, student *models.Student) (error, bool) { previousStudent := repository.GetStudent(clientUserId) studentSerialized, err := proto.Marshal(student) ctx := context.Background() + hasChanges := previousStudent != student + if previousStudent != nil && student != nil { + hasChanges = previousStudent.Id != student.Id + } + if err == nil { clientUserKey := repository.getClientUserKey(clientUserId) pipe := repository.redis.TxPipeline() - if previousStudent.Id != student.Id && previousStudent.Id != 0 { + if hasChanges && previousStudent != nil { pipe.SRem(ctx, repository.getStudentKey(previousStudent.Id), clientUserId) } - if student.Id == 0 { + if student == nil || student.Id == 0 { pipe.Del(ctx, clientUserKey) } else { pipe.Set(ctx, clientUserKey, studentSerialized, UserExpiration) @@ -49,7 +54,7 @@ func (repository *UserRepository) SaveUser(clientUserId string, student *models. _, err = pipe.Exec(ctx) } - return err, previousStudent.Id != student.Id + return err, hasChanges } func (repository *UserRepository) Commit() error { @@ -68,16 +73,16 @@ func (repository *UserRepository) GetStudent(clientUserId string) *models.Studen UserExpiration, ).Bytes() - student := &models.Student{} if studentSerialized != nil && len(studentSerialized) > 0 { - _ = proto.Unmarshal(studentSerialized, student) - } - - if student.Id != 0 { - repository.redis.Expire(ctx, repository.getStudentKey(student.Id), UserExpiration) + student := models.Student{} + _ = proto.Unmarshal(studentSerialized, &student) + if student.Id != 0 { + repository.redis.Expire(ctx, repository.getStudentKey(student.Id), UserExpiration) + return &student + } } - return student + return nil } func (repository *UserRepository) GetClientUserIds(studentId uint) []string { diff --git a/UserRepository_test.go b/UserRepository_test.go index dcdb409..04301c8 100644 --- a/UserRepository_test.go +++ b/UserRepository_test.go @@ -252,7 +252,7 @@ func TestUserRepository_SaveUser(t *testing.T) { assert.NoError(t, err) actualStudent = userRepository.GetStudent(expectedClientUserId) - assert.Equal(t, emptyStudent.String(), actualStudent.String()) + assert.Nil(t, actualStudent) actualClientIds = userRepository.GetClientUserIds(uint(expectedStudent1.Id)) assert.Len(t, actualClientIds, 0) @@ -311,7 +311,7 @@ func TestUserRepository_GetStudent(t *testing.T) { redisMock.ExpectGetEx(clientUserId, UserExpiration).RedisNil() actualStudent := userRepository.GetStudent(clientUserId) - assertStudent(t, &models.Student{}, actualStudent) + assert.Nil(t, actualStudent) }) }