Skip to content

Commit

Permalink
feat: add batching to UUID mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed May 16, 2022
1 parent b97f25f commit cb8509d
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package uuidmapping
import (
"database/sql"
"fmt"
"strings"
"time"

"github.com/gobuffalo/pop/v6"
Expand Down Expand Up @@ -61,14 +62,15 @@ var (

for _, rt := range relationTuples {
rt := rt
if err = migrateSubjectID(conn, &rt); err != nil {
return fmt.Errorf("could not migrate subject ID: %w", err)
fields := []*string{&rt.Object}
if rt.SubjectID.Valid {
fields = append(fields, &rt.SubjectID.String)
}
if err = migrateSubjectSetObject(conn, &rt); err != nil {
return fmt.Errorf("could not migrate subject set object: %w", err)
if rt.SubjectSetObject.Valid {
fields = append(fields, &rt.SubjectSetObject.String)
}
if err = migrateObject(conn, &rt); err != nil {
return fmt.Errorf("could not migrate object: %w", err)
if err := batchReplaceStrings(conn, &rt, fields); err != nil {
return fmt.Errorf("could not replace UUIDs: %w", err)
}
if err = conn.Update(&rt); err != nil {
return fmt.Errorf("failed to update relation tuple: %w", err)
Expand Down Expand Up @@ -127,64 +129,61 @@ var (
}
)

func hasMapping(conn *pop.Connection, id string) (bool, error) {
found, err := conn.Where("id = ?", id).Exists(&UUIDMapping{})
if err != nil {
return false, nil
func getRelationTuples(conn *pop.Connection, page int) (
res []RelationTuple, hasNext bool, err error,
) {
q := conn.Order("nid, shard_id").Paginate(page, 100)

if err := q.All(&res); err != nil {
return nil, false, sqlcon.HandleError(err)
}
return found, nil
return res, q.Paginator.Page < q.Paginator.TotalPages, nil
}

func migrateSubjectID(conn *pop.Connection, rt *RelationTuple) error {
if !rt.SubjectID.Valid || rt.SubjectID.String == "" {
return nil
}
skip, err := hasMapping(conn, rt.SubjectID.String)
if err != nil {
return err
}
if skip {
return nil
func removeNonUUIDs(fields []*string) []*string {
var res []*string
for _, f := range fields {
if f == nil || *f == "" {
continue
}
if _, err := uuid.FromString(*f); err != nil {
continue
}
res = append(res, f)
}

rt.SubjectID.String, err = addUUIDMapping(conn, rt.NetworkID, rt.SubjectID.String)
return err
return res
}

func migrateSubjectSetObject(conn *pop.Connection, rt *RelationTuple) error {
if !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" {
return nil
}
skip, err := hasMapping(conn, rt.SubjectSetObject.String)
if err != nil {
return err
}
if skip {
return nil
func removeEmpty(fields []*string) []*string {
var res []*string
for _, f := range fields {
if f == nil || *f == "" {
continue
}
res = append(res, f)
}

rt.SubjectSetObject.String, err = addUUIDMapping(conn, rt.NetworkID, rt.SubjectSetObject.String)
return err
return res
}

func migrateObject(conn *pop.Connection, rt *RelationTuple) error {
if rt.Object == "" {
return nil
}
skip, err := hasMapping(conn, rt.Object)
if err != nil {
return err
func batchReplaceStrings(conn *pop.Connection, rt *RelationTuple, fields []*string) (err error) {
fields = removeEmpty(fields)
if len(fields) == 0 {
return
}
if skip {
return nil
values := make([]string, len(fields))
for i, field := range fields {
values[i] = *field
}

rt.Object, err = addUUIDMapping(conn, rt.NetworkID, rt.Object)
return err
}

func addUUIDMapping(conn *pop.Connection, networkID uuid.UUID, value string) (uid string, err error) {
uid = uuid.NewV5(networkID, value).String()
uuids := make([]uuid.UUID, len(values))
placeholderArray := make([]string, len(values))
args := make([]interface{}, 0, len(values)*2)
for i, val := range values {
uuids[i] = uuid.NewV5(rt.NetworkID, val)
placeholderArray[i] = "(?, ?)"
args = append(args, uuids[i].String(), val)
}
placeholders := strings.Join(placeholderArray, ", ")

// We need to write manual SQL here because the INSERT should not fail if
// the UUID already exists, but we still want to return an error if anything
Expand All @@ -193,45 +192,22 @@ func addUUIDMapping(conn *pop.Connection, networkID uuid.UUID, value string) (ui
switch d := conn.Dialect.Name(); d {
case "mysql":
query = `
INSERT IGNORE INTO keto_uuid_mappings (id, string_representation)
VALUES (?, ?)`
INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ` + placeholders
default:
query = `
INSERT INTO keto_uuid_mappings (id, string_representation)
VALUES (?, ?)
VALUES ` + placeholders + `
ON CONFLICT (id) DO NOTHING`
}

err = sqlcon.HandleError(conn.RawQuery(query, uid, value).Exec())
if err != nil {
return "", fmt.Errorf("failed to add UUID mapping: %w", err)
}
return
}

func getRelationTuples(conn *pop.Connection, page int) (
res []RelationTuple, hasNext bool, err error,
) {
q := conn.Order("nid, shard_id").Paginate(page, 100)

if err := q.All(&res); err != nil {
return nil, false, sqlcon.HandleError(err)
if err = sqlcon.HandleError(conn.RawQuery(query, args...).Exec()); err != nil {
return err
}
return res, q.Paginator.Page < q.Paginator.TotalPages, nil
}

func removeNonUUIDs(fields []*string) []*string {
var res []*string
for _, f := range fields {
if f == nil || *f == "" {
continue
}
if _, err := uuid.FromString(*f); err != nil {
continue
}
res = append(res, f)
for i, field := range fields {
*field = uuids[i].String()
}
return res
return nil
}

func batchReplaceUUIDs(conn *pop.Connection, ids []*string) (err error) {
Expand Down
93 changes: 52 additions & 41 deletions internal/persistence/sql/uuid_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package sql
import (
"context"
"fmt"
"strings"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/ory/x/sqlcon"

Expand All @@ -28,9 +28,22 @@ func (UUIDMapping) TableName() string {
return "keto_uuid_mappings"
}

func (p *Persister) ToUUID(ctx context.Context, text string) (uuid.UUID, error) {
id := uuid.NewV5(p.NetworkID(ctx), text)
p.d.Logger().Trace("adding UUID mapping")
func (p *Persister) batchToUUIDs(ctx context.Context, values []string) (uuids []uuid.UUID, err error) {
if len(values) == 0 {
return
}

uuids = make([]uuid.UUID, len(values))
placeholderArray := make([]string, len(values))
args := make([]interface{}, 0, len(values)*2)
for i, val := range values {
uuids[i] = uuid.NewV5(p.NetworkID(ctx), val)
placeholderArray[i] = "(?, ?)"
args = append(args, uuids[i].String(), val)
}
placeholders := strings.Join(placeholderArray, ", ")

p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings")

// We need to write manual SQL here because the INSERT should not fail if
// the UUID already exists, but we still want to return an error if anything
Expand All @@ -39,21 +52,24 @@ func (p *Persister) ToUUID(ctx context.Context, text string) (uuid.UUID, error)
switch d := p.Connection(ctx).Dialect.Name(); d {
case "mysql":
query = `
INSERT IGNORE INTO keto_uuid_mappings (id, string_representation)
VALUES (?, ?)`
INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ` + placeholders
default:
query = `
INSERT INTO keto_uuid_mappings (id, string_representation)
VALUES (?, ?)
VALUES ` + placeholders + `
ON CONFLICT (id) DO NOTHING`
}

return id, sqlcon.HandleError(
p.Connection(ctx).RawQuery(query, id, text).Exec(),
return uuids, sqlcon.HandleError(
p.Connection(ctx).RawQuery(query, args...).Exec(),
)
}

func (p *Persister) FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) {
func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) {
if len(ids) == 0 {
return
}

p.d.Logger().Trace("looking up UUIDs")

// We need to paginate on the ids, because we want to get the exact chunk of
Expand Down Expand Up @@ -96,56 +112,51 @@ func (p *Persister) FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.Pag
return
}

func (p *Persister) replaceWithUUID(ctx context.Context, s *string) error {
if s == nil {
return nil
func filterFields(fields []*string) []*string {
res := make([]*string, 0, len(fields))
for _, field := range fields {
if field != nil && *field != "" {
res = append(res, field)
}
}
return res
}

func (p *Persister) MapFieldsToUUID(ctx context.Context, m relationtuple.UUIDMappable) error {
fields := filterFields(m.UUIDMappableFields())
values := make([]string, len(fields))

for i, field := range fields {
values[i] = *field
}
uuid, err := p.ToUUID(ctx, *s)
ids, err := p.batchToUUIDs(ctx, values)
if err != nil {
p.d.Logger().WithError(err).WithField("values", values).Error("could insert UUID mappings")
return err
}
*s = uuid.String()

for i, field := range fields {
*field = ids[i].String()
}
return nil
}

func (p *Persister) MapFieldsToUUID(ctx context.Context, m relationtuple.UUIDMappable) error {
return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
for _, s := range m.UUIDMappableFields() {
if s == nil || *s == "" {
continue
}
if err := p.replaceWithUUID(ctx, s); err != nil {
p.d.Logger().WithError(err).WithField("string", s).Error("got an error while mapping string to UUID")
return err
}
}
return nil
})
}

func (p *Persister) MapFieldsFromUUID(ctx context.Context, m relationtuple.UUIDMappable) error {
ids := make([]uuid.UUID, len(m.UUIDMappableFields()))
for i, field := range m.UUIDMappableFields() {
if field == nil {
continue
}
fields := filterFields(m.UUIDMappableFields())
ids := make([]uuid.UUID, len(fields))
for i, field := range fields {
id, err := uuid.FromString(*field)
if err != nil {
p.d.Logger().WithError(err).WithField("UUID", *field).Error("could not parse as UUID")
return err
}
ids[i] = id
}
reps, err := p.FromUUID(ctx, ids)
reps, err := p.batchFromUUIDs(ctx, ids)
if err != nil {
p.d.Logger().WithError(err).WithField("UUIDs", ids).Error("could fetch string mappings from DB")
return err
}
for i, field := range m.UUIDMappableFields() {
if field == nil {
continue
}
for i, field := range fields {
if reps[i] == "" {
p.d.Logger().WithError(err).WithField("string", reps[i]).Error("could not find the corresponding UUID")
return fmt.Errorf("failed to map %s", ids[i])
Expand Down
2 changes: 1 addition & 1 deletion internal/persistence/sql/uuid_mapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestUUIDMapping(t *testing.T) {
assertErr: assertCheckErr,
}, {
desc: "empty strings should fail on constraint",
mappings: &sql.UUIDMapping{uuid.Nil, ""},
mappings: &sql.UUIDMapping{ID: uuid.Nil},
assertErr: assertCheckErr,
}, {
desc: "single with string rep should succeed",
Expand Down
Loading

0 comments on commit cb8509d

Please sign in to comment.