Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[wip] Tag Stealing #22

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/common/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ type ComparisonOperator int

const (
Equal ComparisonOperator = iota
IsNull
// Add more operators as needed, ie., gte, lte
)
1 change: 1 addition & 0 deletions pkg/repositories/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func GetRepository(repoType RepoConfig, dbConfig config.DbConfig, scope promutil
if err != nil {
panic(err)
}
db.LogMode(true)
return NewPostgresRepo(
db,
errors.NewPostgresErrorTransformer(),
Expand Down
2 changes: 1 addition & 1 deletion pkg/repositories/gormimpl/artifact_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func TestGetArtifact(t *testing.T) {
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123))) ORDER BY partitions.created_at ASC,"partitions"."dataset_uuid" ASC`).WithReply(expectedPartitionResponse)
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND ((("artifact_id","dataset_uuid") IN ((123,test-uuid)))) ORDER BY "tags"."dataset_project" ASC`).WithReply(expectedTagResponse)
`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND ((("artifact_id","dataset_uuid") IN ((123,test-uuid)))) ORDER BY "tags"."tag_name" ASC`).WithReply(expectedTagResponse)
getInput := models.ArtifactKey{
DatasetProject: artifact.DatasetProject,
DatasetDomain: artifact.DatasetDomain,
Expand Down
16 changes: 15 additions & 1 deletion pkg/repositories/gormimpl/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import (

// String formats for various GORM expression queries
const (
equalQuery = "%s.%s = ?"
equalQuery = "%s.%s = ?"
isNullQuery = "%s.%s IS NULL"
)

type gormValueFilterImpl struct {
Expand All @@ -27,7 +28,13 @@ func (g *gormValueFilterImpl) GetDBQueryExpression(tableName string) (models.DBQ
Query: fmt.Sprintf(equalQuery, tableName, g.field),
Args: g.value,
}, nil
case common.IsNull:
return models.DBQueryExpr{
Query: fmt.Sprintf(isNullQuery, tableName, g.field),
Args: g.value,
}, nil
}

return models.DBQueryExpr{}, errors.GetUnsupportedFilterExpressionErr(g.comparisonOperator)
}

Expand All @@ -39,3 +46,10 @@ func NewGormValueFilter(comparisonOperator common.ComparisonOperator, field stri
value: value,
}
}

func NewGormNullFilter(field string) models.ModelValueFilter {
return &gormValueFilterImpl{
comparisonOperator: common.IsNull,
field: field,
}
}
7 changes: 6 additions & 1 deletion pkg/repositories/gormimpl/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ func applyListModelsInput(tx *gorm.DB, sourceEntity common.Entity, in models.Lis
if err != nil {
return nil, err
}
tx = tx.Where(dbQueryExpr.Query, dbQueryExpr.Args)

if dbQueryExpr.Args == nil {
tx = tx.Where(dbQueryExpr.Query)
} else {
tx = tx.Where(dbQueryExpr.Query, dbQueryExpr.Args)
}
}
}

Expand Down
109 changes: 106 additions & 3 deletions pkg/repositories/gormimpl/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"

"github.com/jinzhu/gorm"
"github.com/lyft/datacatalog/pkg/common"
"github.com/lyft/datacatalog/pkg/repositories/errors"
"github.com/lyft/datacatalog/pkg/repositories/interfaces"
"github.com/lyft/datacatalog/pkg/repositories/models"
idl_datacatalog "github.com/lyft/datacatalog/protos/gen"
"github.com/lyft/flytestdlib/logger"
"github.com/lyft/flytestdlib/promutils"
)

Expand All @@ -25,14 +27,115 @@ func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope pro
}
}

// A tag is associated with a single artifact for each partition combination
// When creating a tag, we remove the tag from any artifacts of the same partition
// Then add the tag to the new artifact
func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error {
timer := h.repoMetrics.CreateDuration.Start(ctx)
defer timer.Stop()

db := h.db.Create(&tag)
// There are several steps that need to be done in a transaction in order for tag stealing to occur
tx := h.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()

if db.Error != nil {
return h.errorTransformer.ToDataCatalogError(db.Error)
// 1. Find the set of partitions this artifact belongs to
var artifactToTag models.Artifact
tx.Preload("Partitions").Find(&artifactToTag, models.Artifact{
ArtifactKey: models.ArtifactKey{ArtifactID: tag.ArtifactID},
})

// 2. List artifacts in the partitions that are currently tagged
modelFilters := make([]models.ModelFilter, 0, len(artifactToTag.Partitions)+2)
for _, partition := range artifactToTag.Partitions {
modelFilters = append(modelFilters, models.ModelFilter{
Entity: common.Partition,
ValueFilters: []models.ModelValueFilter{
NewGormValueFilter(common.Equal, "key", partition.Key),
NewGormValueFilter(common.Equal, "value", partition.Value),
},
JoinCondition: NewGormJoinCondition(common.Artifact, common.Partition),
})
}

modelFilters = append(modelFilters, models.ModelFilter{
Entity: common.Tag,
ValueFilters: []models.ModelValueFilter{
NewGormValueFilter(common.Equal, "tag_name", tag.TagName),
NewGormNullFilter("deleted_at"),
},
JoinCondition: NewGormJoinCondition(common.Artifact, common.Tag),
})

listTaggedInput := models.ListModelsInput{
ModelFilters: modelFilters,
Limit: 100,
}

listArtifactsScope, err := applyListModelsInput(tx, common.Artifact, listTaggedInput)
if err != nil {
logger.Errorf(ctx, "Unable to construct artiact list, rolling back, tag: [%v], err [%v]", tag, tx.Error)
tx.Rollback()
return h.errorTransformer.ToDataCatalogError(err)

}

var artifacts []models.Artifact
if err := listArtifactsScope.Find(&artifacts).Error; err != nil {
logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, err)
tx.Rollback()
return h.errorTransformer.ToDataCatalogError(err)
}

// 3. Remove the tags from the currently tagged artifacts
if len(artifacts) != 0 {
// Soft-delete the existing tags on the artifacts that are currently tagged
for _, artifact := range artifacts {

// if the artifact to tag is already tagged, no need to remove it
if artifactToTag.ArtifactID != artifact.ArtifactID {
oldTag := models.Tag{
TagKey: models.TagKey{TagName: tag.TagName},
ArtifactID: artifact.ArtifactID,
DatasetUUID: artifact.DatasetUUID,
}
deleteScope := tx.NewScope(&models.Tag{}).DB().Delete(&models.Tag{}, oldTag)
if deleteScope.Error != nil {
logger.Errorf(ctx, "Unable to delete previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, deleteScope.Error)
tx.Rollback()
return h.errorTransformer.ToDataCatalogError(deleteScope.Error)
}
}
}
}

// 4. If the artifact was ever previously tagged with this tag, we need to
// un-delete the record because we cannot tag the artifact again since
// the primary keys are the same.
undeleteScope := tx.Unscoped().Model(&tag).Update("deleted_at", gorm.Expr("NULL")) // unscope will ignore deletedAt
if undeleteScope.Error != nil {
logger.Errorf(ctx, "Unable to undelete tag tag, rolling back, tag: [%v], err [%v]", tag, tx.Error)
tx.Rollback()
return h.errorTransformer.ToDataCatalogError(tx.Error)
}

// 5. Tag the new artifact, if it didn't previously exist
if undeleteScope.RowsAffected == 0 {
if err := tx.Create(&tag).Error; err != nil {
logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, err)
tx.Rollback()
return h.errorTransformer.ToDataCatalogError(err)
}
}

tx = tx.Commit()
if tx.Error != nil {
logger.Errorf(ctx, "Unable to commit transaction, rolling back, tag: [%v], err [%v]", tag, tx.Error)
tx.Rollback()
return h.errorTransformer.ToDataCatalogError(tx.Error)
}
return nil
}
Expand Down
57 changes: 54 additions & 3 deletions pkg/repositories/gormimpl/tag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,72 @@ func getTestTag() models.Tag {
}
}

func TestCreateTag(t *testing.T) {
func TestCreateTagNew(t *testing.T) {
tagCreated := false
GlobalMock := mocket.Catcher.Reset()
GlobalMock.Logging = true

newArtifact := getTestArtifact()

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(newArtifact))

GlobalMock.NewMock().WithQuery(
`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBPartitionResponse(newArtifact))

GlobalMock.NewMock().WithQuery(
`SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at IS NULL)) LIMIT 100 OFFSET 0`).WithReply([]map[string]interface{}{})

GlobalMock.NewMock().WithQuery(
`INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback(
func(s string, values []driver.NamedValue) {
tagCreated = true
},
)

newTag := getTestTag()
newTag.ArtifactID = newArtifact.ArtifactID

tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope())
err := tagRepo.Create(context.Background(), getTestTag())
err := tagRepo.Create(context.Background(), newTag)

assert.NoError(t, err)
assert.True(t, tagCreated)
}

func TestStealOldTag(t *testing.T) {
tagCreated := false
GlobalMock := mocket.Catcher.Reset()
GlobalMock.Logging = true

oldArtifact := getTestArtifact()
newArtifact := getTestArtifact()
newArtifact.ArtifactID = "111"

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 111))`).WithReply(getDBArtifactResponse(newArtifact))

GlobalMock.NewMock().WithQuery(
`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (111)))`).WithReply(getDBPartitionResponse(newArtifact))

GlobalMock.NewMock().WithQuery(
`SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at IS NULL)) LIMIT 100 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact))

GlobalMock.NewMock().WithQuery(
`INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback(
func(s string, values []driver.NamedValue) {
tagCreated = true
},
)

newTag := getTestTag()
newTag.ArtifactID = newArtifact.ArtifactID

tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope())
err := tagRepo.Create(context.Background(), newTag)

assert.NoError(t, err)
assert.True(t, tagCreated)
}
Expand All @@ -71,7 +122,7 @@ func TestGetTag(t *testing.T) {

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag)) ORDER BY tags.created_at DESC,"tags"."dataset_project" ASC LIMIT 1`).WithReply(getDBTagResponse(artifact))
`SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag)) ORDER BY tags.created_at DESC,"tags"."tag_name" ASC LIMIT 1`).WithReply(getDBTagResponse(artifact))
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND ((("dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id") IN ((testProject,testName,testDomain,testVersion,123))))`).WithReply(getDBArtifactResponse(artifact))
GlobalMock.NewMock().WithQuery(
Expand Down
10 changes: 5 additions & 5 deletions pkg/repositories/models/tag.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package models

type TagKey struct {
DatasetProject string `gorm:"primary_key"`
DatasetName string `gorm:"primary_key"`
DatasetDomain string `gorm:"primary_key"`
DatasetVersion string `gorm:"primary_key"`
DatasetProject string
DatasetName string
DatasetDomain string
DatasetVersion string
TagName string `gorm:"primary_key"`
}

type Tag struct {
BaseModel
TagKey
ArtifactID string
ArtifactID string `gorm:"primary_key"`
DatasetUUID string `gorm:"type:uuid;index:tags_dataset_uuid_idx"`
Artifact Artifact `gorm:"association_foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID;foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID"`
}