Skip to content

Commit

Permalink
feat: support FieldMask in Updates (#1197)
Browse files Browse the repository at this point in the history
* feat: support FieldMask in Updates

* fix index out of bounds

* simplify

* refactor for oneofs
  • Loading branch information
noahdietz authored Sep 1, 2022
1 parent 3ac40ce commit cdb4ce6
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 93 deletions.
44 changes: 8 additions & 36 deletions server/services/identity_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"sync"

"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/googleapis/gapic-showcase/server"
Expand Down Expand Up @@ -107,16 +108,10 @@ func (s *identityServerImpl) GetUser(_ context.Context, in *pb.GetUserRequest) (

// Updates a user.
func (s *identityServerImpl) UpdateUser(_ context.Context, in *pb.UpdateUserRequest) (*pb.User, error) {
mask := in.GetUpdateMask()
if mask != nil && len(mask.GetPaths()) > 0 {
return nil, status.Error(
codes.Unimplemented,
"Field masks are currently not supported.")
}

s.mu.Lock()
defer s.mu.Unlock()

mask := in.GetUpdateMask()
u := in.GetUser()
i, ok := s.keys[u.GetName()]
if !ok || s.users[i].deleted {
Expand All @@ -129,36 +124,13 @@ func (s *identityServerImpl) UpdateUser(_ context.Context, in *pb.UpdateUserRequ
if err != nil {
return nil, err
}
entry := s.users[i]
// Update store.
updated := &pb.User{
Name: u.GetName(),
DisplayName: u.GetDisplayName(),
Email: u.GetEmail(),
CreateTime: entry.user.GetCreateTime(),
UpdateTime: ptypes.TimestampNow(),
Age: entry.user.Age,
EnableNotifications: entry.user.EnableNotifications,
HeightFeet: entry.user.HeightFeet,
Nickname: entry.user.Nickname,
}

// Use direct field access to avoid unwrapping and rewrapping the pointer value.
//
// TODO: if field_mask is implemented, do a direct update if included,
// regardless of if the optional field is nil.
if u.Age != nil {
updated.Age = u.Age
}
if u.EnableNotifications != nil {
updated.EnableNotifications = u.EnableNotifications
}
if u.HeightFeet != nil {
updated.HeightFeet = u.HeightFeet
}
if u.Nickname != nil {
updated.Nickname = u.Nickname
}
// Update store.
existing := s.users[i].user
updated := proto.Clone(existing).(*pb.User)
applyFieldMask(u.ProtoReflect(), updated.ProtoReflect(), mask.GetPaths())
updated.CreateTime = existing.GetCreateTime()
updated.UpdateTime = ptypes.TimestampNow()

s.users[i] = userEntry{user: updated}
return updated, nil
Expand Down
33 changes: 21 additions & 12 deletions server/services/identity_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ package services
import (
"context"
"encoding/base64"
"strings"
"testing"

"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/gapic-showcase/server"
pb "github.com/googleapis/gapic-showcase/server/genproto"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)

func Test_User_lifecycle(t *testing.T) {
Expand Down Expand Up @@ -257,19 +259,26 @@ func Test_Get_deleted(t *testing.T) {
}

func Test_Update_fieldmask(t *testing.T) {
first := &pb.User{DisplayName: "Ekko", Email: "[email protected]"}
second := &pb.User{DisplayName: "Foo", Email: "[email protected]"}
paths := []string{"display_name"}
s := NewIdentityServer()
_, err := s.UpdateUser(
created, err := s.CreateUser(
context.Background(),
&pb.UpdateUserRequest{
User: nil,
UpdateMask: &field_mask.FieldMask{Paths: []string{"email"}},
})
status, _ := status.FromError(err)
if status.Code() != codes.Unimplemented {
t.Errorf(
"Update: Want error code %d got %d",
codes.Unimplemented,
status.Code())
&pb.CreateUserRequest{User: first})
if err != nil {
t.Errorf("Create: unexpected err %+v", err)
}
second.Name = created.GetName()

got, err := s.UpdateUser(
context.Background(),
&pb.UpdateUserRequest{User: second, UpdateMask: &fieldmaskpb.FieldMask{Paths: paths}})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got.GetDisplayName(), second.GetDisplayName()); diff != "" {
t.Errorf("Using update_mask %s, got(-),want(+):\n%s", strings.Join(paths, ","), diff)
}
}

Expand Down
35 changes: 13 additions & 22 deletions server/services/messaging_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,8 @@ func (s *messagingServerImpl) GetRoom(ctx context.Context, in *pb.GetRoomRequest

// Updates a room.
func (s *messagingServerImpl) UpdateRoom(ctx context.Context, in *pb.UpdateRoomRequest) (*pb.Room, error) {
f := in.GetUpdateMask()
mask := in.GetUpdateMask()
r := in.GetRoom()
if f != nil && len(f.GetPaths()) > 0 {
return nil, status.Error(
codes.Unimplemented,
"Field masks are currently not supported.")
}

s.roomMu.Lock()
defer s.roomMu.Unlock()
Expand All @@ -181,11 +176,11 @@ func (s *messagingServerImpl) UpdateRoom(ctx context.Context, in *pb.UpdateRoomR
codes.NotFound,
"A room with name %s not found.", r.GetName())
}
existing := s.rooms[i].room

entry := s.rooms[i]
// Validate Unique Fields.
uniqName := func(x *pb.Room) bool {
return x != entry.room && (r.GetDisplayName() == x.GetDisplayName())
return x != existing && (r.GetDisplayName() == x.GetDisplayName())
}
if s.anyRoom(uniqName) {
return nil, status.Errorf(
Expand All @@ -195,13 +190,11 @@ func (s *messagingServerImpl) UpdateRoom(ctx context.Context, in *pb.UpdateRoomR
}

// Update store.
updated := &pb.Room{
Name: r.GetName(),
DisplayName: r.GetDisplayName(),
Description: r.GetDescription(),
CreateTime: entry.room.GetCreateTime(),
UpdateTime: ptypes.TimestampNow(),
}
updated := proto.Clone(existing).(*pb.Room)
applyFieldMask(r.ProtoReflect(), updated.ProtoReflect(), mask.GetPaths())
updated.CreateTime = existing.GetCreateTime()
updated.UpdateTime = ptypes.TimestampNow()

s.rooms[i] = roomEntry{room: updated}
return updated, nil
}
Expand Down Expand Up @@ -350,15 +343,10 @@ func (s *messagingServerImpl) GetBlurb(ctx context.Context, in *pb.GetBlurbReque

// Updates a blurb.
func (s *messagingServerImpl) UpdateBlurb(ctx context.Context, in *pb.UpdateBlurbRequest) (*pb.Blurb, error) {
if in.GetUpdateMask() != nil && len(in.GetUpdateMask().GetPaths()) > 0 {
return nil, status.Error(
codes.Unimplemented,
"Field masks are currently not supported.")
}

s.blurbMu.Lock()
defer s.blurbMu.Unlock()

mask := in.GetUpdateMask()
b := in.GetBlurb()
i, ok := s.blurbKeys[b.GetName()]
if !ok || s.blurbs[i.row][i.col].deleted {
Expand All @@ -371,7 +359,10 @@ func (s *messagingServerImpl) UpdateBlurb(ctx context.Context, in *pb.UpdateBlur
return nil, err
}
// Update store.
updated := proto.Clone(b).(*pb.Blurb)
existing := s.blurbs[i.row][i.col].blurb
updated := proto.Clone(existing).(*pb.Blurb)
applyFieldMask(b.ProtoReflect(), updated.ProtoReflect(), mask.GetPaths())
updated.CreateTime = existing.GetCreateTime()
updated.UpdateTime = ptypes.TimestampNow()
s.blurbs[i.row][i.col] = blurbEntry{blurb: updated}

Expand Down
67 changes: 44 additions & 23 deletions server/services/messaging_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import (
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/gapic-showcase/server"
pb "github.com/googleapis/gapic-showcase/server/genproto"
"golang.org/x/sync/errgroup"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)

func Test_Room_lifecycle(t *testing.T) {
Expand Down Expand Up @@ -208,19 +209,27 @@ func Test_GetRoom_deleted(t *testing.T) {
}

func Test_UpdateRoom_fieldmask(t *testing.T) {
first := &pb.Room{DisplayName: "Living Room"}
second := &pb.Room{DisplayName: "Dining Room"}
paths := []string{"display_name"}
s := NewMessagingServer(NewIdentityServer())
_, err := s.UpdateRoom(
created, err := s.CreateRoom(
context.Background(),
&pb.CreateRoomRequest{Room: first})
if err != nil {
t.Errorf("Create: unexpected err %+v", err)
}
second.Name = created.GetName()
got, err := s.UpdateRoom(
context.Background(),
&pb.UpdateRoomRequest{
Room: nil,
UpdateMask: &field_mask.FieldMask{Paths: []string{"email"}},
})
status, _ := status.FromError(err)
if status.Code() != codes.Unimplemented {
t.Errorf(
"Update: Want error code %d got %d",
codes.Unimplemented,
status.Code())
Room: second,
UpdateMask: &fieldmaskpb.FieldMask{Paths: paths}})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got.GetDisplayName(), second.GetDisplayName()); diff != "" {
t.Errorf("Using update_mask %s, got(-),want(+):\n%s", strings.Join(paths, ","), diff)
}
}

Expand Down Expand Up @@ -579,19 +588,31 @@ func Test_GetBlurb_deleted(t *testing.T) {
}

func Test_UpdateBlurb_fieldmask(t *testing.T) {
s := NewMessagingServer(NewIdentityServer())
_, err := s.UpdateBlurb(
first := &pb.Blurb{
User: "users/rumble",
Content: &pb.Blurb_Text{Text: "woof"},
}
second := &pb.Blurb{
Content: &pb.Blurb_Text{Text: "bark"},
}
paths := []string{"text"}
s := NewMessagingServer(&mockIdentityServer{})
created, err := s.CreateBlurb(
context.Background(),
&pb.UpdateBlurbRequest{
Blurb: nil,
UpdateMask: &field_mask.FieldMask{Paths: []string{"email"}},
})
status, _ := status.FromError(err)
if status.Code() != codes.Unimplemented {
t.Errorf(
"Update: Want error code %d got %d",
codes.Unimplemented,
status.Code())
&pb.CreateBlurbRequest{Blurb: first})
if err != nil {
t.Errorf("Create: unexpected err %+v", err)
}
second.Name = created.GetName()
second.User = created.GetUser()
got, err := s.UpdateBlurb(
context.Background(),
&pb.UpdateBlurbRequest{Blurb: second, UpdateMask: &fieldmaskpb.FieldMask{Paths: paths}})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got.GetText(), second.GetText()); diff != "" {
t.Errorf("Using update_mask %s, got(-),want(+):\n%s", strings.Join(paths, ","), diff)
}
}

Expand Down
74 changes: 74 additions & 0 deletions server/services/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package services

import (
"google.golang.org/protobuf/reflect/protoreflect"
)

func strContains(haystack []string, needle string) bool {
for _, s := range haystack {
if s == needle {
return true
}
}
return false
}

// applyFieldMask applies the values from the src message to the values of the
// dst message according to the contents of the given field mask paths.
// If paths is empty/nil, or contains *, it is considered a full update.
//
// TODO: Does not support nested message paths. Currently only used with flat
// resource messages.
func applyFieldMask(src, dst protoreflect.Message, paths []string) {
fullUpdate := len(paths) == 0 || strContains(paths, "*")

fields := dst.Descriptor().Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
isOneof := field.ContainingOneof() != nil && !field.ContainingOneof().IsSynthetic()

// Set field in dst with value from src, skipping true oneofs, while
// handling proto3_optional fields represented as synthetic oneofs.
if (fullUpdate || strContains(paths, string(field.Name()))) && !isOneof {
dst.Set(field, src.Get(field))
}
}

oneofs := dst.Descriptor().Oneofs()
for i := 0; i < oneofs.Len(); i++ {
oneof := oneofs.Get(i)
// Skip proto3_optional synthetic oneofs.
if oneof.IsSynthetic() {
continue
}

setOneof := src.WhichOneof(oneof)
if setOneof == nil && fullUpdate {
// Full update with no field set in this oneof of
// src means clear all fields for this oneof in dst.
fields := oneof.Fields()
for j := 0; j < fields.Len(); j++ {
dst.Clear(fields.Get(j))
}
} else if setOneof != nil && (fullUpdate || strContains(paths, string(setOneof.Name()))) {
// Full update or targeted updated with a field set in this oneof of
// src means set that field for the same oneof in dst, which implicitly
// clears any previously set field for this oneof.
dst.Set(setOneof, src.Get(setOneof))
}
}
}

0 comments on commit cdb4ce6

Please sign in to comment.