From cdda311f256948f9914100dac8b423601351b2a2 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 24 Sep 2024 16:57:35 -0700 Subject: [PATCH] Add `ListPackages()` to persistence service Refs #988 This continues the work toward #988 by adding a `ListPackages()` method to the persistence service that includes a count of the total search results returned (before paging), as well as allowing filtering, sorting, and paging of the search. Other changes: - Add a telemetry wrapper for the `ListPackages()` method to provide extra performance and debugging data - Add additional commentary for the `Filter.Page()` method - Add many unit tests for the `ListPackages()` functionality --- internal/persistence/ent/client/filter.go | 18 +- internal/persistence/ent/client/package.go | 55 +++ .../persistence/ent/client/package_test.go | 352 ++++++++++++++++++ internal/persistence/fake/mock_persistence.go | 40 ++ internal/persistence/filter_test.go | 2 +- internal/persistence/persistence.go | 1 + internal/persistence/telemetry.go | 13 + 7 files changed, 474 insertions(+), 7 deletions(-) diff --git a/internal/persistence/ent/client/filter.go b/internal/persistence/ent/client/filter.go index 242ed3bf..23ca342a 100644 --- a/internal/persistence/ent/client/filter.go +++ b/internal/persistence/ent/client/filter.go @@ -129,7 +129,18 @@ func orderFunc(field string, desc bool) func(sel *sql.Selector) { } } -// Page sets the limit and offset criteria. +// Page sets the query limit and offset criteria. +// +// The actual query limit will be set to a value between one and MaxPageSize +// based on the input limit (x) as follows: +// (x < 0) -> MaxPageSize +// (x == 0) -> DefaultPageSize +// (0 < x < MaxPageSize) -> x +// (x >= MaxPageSize) -> MaxPageSize +// +// The actual query offset will be set based on the input offset (y) as follows: +// (y <= 0) -> 0 +// (y > 0) -> y func (f *Filter[Q, O, P]) Page(limit, offset int) { f.addLimit(limit) @@ -139,11 +150,6 @@ func (f *Filter[Q, O, P]) Page(limit, offset int) { } // addLimit adds the page limit l to a filter. -// -// If l < 0 the page limit is set to the maximum page size. -// If l == 0 the page limit is set to the default page size. -// If l > 0 but less than the max page size, the page limit is set to l. -// If l > max page size the page limit is set to the max page size. func (f *Filter[Q, O, P]) addLimit(l int) { switch { case l == 0: diff --git a/internal/persistence/ent/client/package.go b/internal/persistence/ent/client/package.go index a68de1b9..1f94c018 100644 --- a/internal/persistence/ent/client/package.go +++ b/internal/persistence/ent/client/package.go @@ -8,6 +8,8 @@ import ( "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/persistence" + "github.com/artefactual-sdps/enduro/internal/persistence/ent/db" + "github.com/artefactual-sdps/enduro/internal/persistence/ent/db/pkg" ) // CreatePackage creates and persists a new package using the values from pkg @@ -130,3 +132,56 @@ func (c *client) UpdatePackage( return convertPkgToPackage(p), nil } + +// ListPackages returns a slice of packages filtered according to f. +func (c *client) ListPackages(ctx context.Context, f persistence.PackageFilter) ( + []*datatypes.Package, *persistence.Page, error, +) { + res := []*datatypes.Package{} + + page, whole := filterPackages(c.ent.Pkg.Query(), &f) + + r, err := page.All(ctx) + if err != nil { + return nil, nil, newDBError(err) + } + + for _, i := range r { + res = append(res, convertPkgToPackage(i)) + } + + total, err := whole.Count(ctx) + if err != nil { + return nil, nil, newDBError(err) + } + + pp := &persistence.Page{ + Limit: f.Limit, + Offset: f.Offset, + Total: total, + } + + return res, pp, err +} + +// filterPackages filters a package query based on filtering inputs. +func filterPackages(q *db.PkgQuery, f *persistence.PackageFilter) (page, whole *db.PkgQuery) { + h := NewFilter(q, SortableFields{ + pkg.FieldID: {Name: "ID", Default: true}, + }) + h.Equals(pkg.FieldName, f.Name) + h.Equals(pkg.FieldAipID, f.AIPID) + h.Equals(pkg.FieldLocationID, f.LocationID) + h.Equals(pkg.FieldStatus, f.Status) + h.AddDateRange(pkg.FieldStartedAt, f.StartedAt) + h.OrderBy(f.Sort) + h.Page(f.Limit, f.Offset) + + // Update the filter values with the actual values set on the query. E.g. + // calling `h.Page(0,0)` will set the query limit equal to the default page + // size. + f.Limit = h.limit + f.Offset = h.offset + + return h.Apply() +} diff --git a/internal/persistence/ent/client/package_test.go b/internal/persistence/ent/client/package_test.go index 188979d3..b906a4de 100644 --- a/internal/persistence/ent/client/package_test.go +++ b/internal/persistence/ent/client/package_test.go @@ -10,14 +10,26 @@ import ( "github.com/go-logr/logr" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" + "go.artefactual.dev/tools/ref" "gotest.tools/v3/assert" "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/enums" "github.com/artefactual-sdps/enduro/internal/persistence" + entclient "github.com/artefactual-sdps/enduro/internal/persistence/ent/client" "github.com/artefactual-sdps/enduro/internal/persistence/ent/db" + "github.com/artefactual-sdps/enduro/internal/timerange" ) +func nullUUIDToPtr(u uuid.NullUUID) *uuid.UUID { + var p *uuid.UUID + if u.Valid { + p = &u.UUID + } + + return p +} + func TestCreatePackage(t *testing.T) { t.Parallel() @@ -328,3 +340,343 @@ func TestUpdatePackage(t *testing.T) { }) } } + +func TestListPackages(t *testing.T) { + t.Parallel() + + runID := uuid.MustParse("c5f7c35a-d5a6-4e00-b4da-b036ce5b40bc") + runID2 := uuid.MustParse("c04d0191-d7ce-46dd-beff-92d6830082ff") + + aipID := uuid.NullUUID{ + UUID: uuid.MustParse("e2ace0da-8697-453d-9ea1-4c9b62309e54"), + Valid: true, + } + aipID2 := uuid.NullUUID{ + UUID: uuid.MustParse("7d085541-af56-4444-9ce2-d6401ff4c97b"), + Valid: true, + } + + locID := uuid.NullUUID{ + UUID: uuid.MustParse("146182ff-9923-4869-bca1-0bbc0f822025"), + Valid: true, + } + locID2 := uuid.NullUUID{ + UUID: uuid.MustParse("6e30694b-6497-439f-bf99-83af165e02c3"), + Valid: true, + } + + started := sql.NullTime{ + Time: func() time.Time { + t, _ := time.Parse(time.RFC3339, "2024-09-25T09:31:11Z") + return t + }(), + Valid: true, + } + started2 := sql.NullTime{ + Time: func() time.Time { + t, _ := time.Parse(time.RFC3339, "2024-09-25T10:03:42Z") + return t + }(), + Valid: true, + } + + completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} + completed2 := sql.NullTime{Time: started2.Time.Add(time.Second), Valid: true} + + testData := []*datatypes.Package{ + { + Name: "Test package 1", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusDone, + StartedAt: started, + CompletedAt: completed, + }, + { + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + StartedAt: started2, + CompletedAt: completed2, + }, + } + + type results struct { + data []*datatypes.Package + page *persistence.Page + } + tests := []struct { + name string + data []*datatypes.Package + args persistence.PackageFilter + want results + wantErr string + }{ + { + name: "Returns all packages", + data: testData, + args: persistence.PackageFilter{}, + want: results{ + data: []*datatypes.Package{ + { + ID: 1, + Name: "Test package 1", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusDone, + CreatedAt: time.Now(), + StartedAt: started, + CompletedAt: completed, + }, + { + ID: 2, + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + page: &persistence.Page{ + Limit: entclient.DefaultPageSize, + Total: 2, + }, + }, + }, + { + name: "Returns first page of packages", + data: testData, + args: persistence.PackageFilter{ + Page: persistence.Page{Limit: 1}, + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 1, + Name: "Test package 1", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusDone, + CreatedAt: time.Now(), + StartedAt: started, + CompletedAt: completed, + }, + }, + page: &persistence.Page{ + Limit: 1, + Total: 2, + }, + }, + }, + { + name: "Returns second page of packages", + data: testData, + args: persistence.PackageFilter{ + Page: persistence.Page{Limit: 1, Offset: 1}, + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 2, + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + page: &persistence.Page{ + Limit: 1, + Offset: 1, + Total: 2, + }, + }, + }, + { + name: "Returns packages filtered by name", + data: testData, + args: persistence.PackageFilter{ + Name: ref.New("Test package 2"), + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 2, + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + page: &persistence.Page{ + Limit: entclient.DefaultPageSize, + Total: 1, + }, + }, + }, + { + name: "Returns packages filtered by AIPID", + data: testData, + args: persistence.PackageFilter{ + AIPID: nullUUIDToPtr(aipID2), + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 2, + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + page: &persistence.Page{ + Limit: entclient.DefaultPageSize, + Total: 1, + }, + }, + }, + { + name: "Returns packages filtered by LocationID", + data: testData, + args: persistence.PackageFilter{ + LocationID: nullUUIDToPtr(locID2), + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 2, + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + page: &persistence.Page{ + Limit: entclient.DefaultPageSize, + Total: 1, + }, + }, + }, + { + name: "Returns packages filtered by status", + data: testData, + args: persistence.PackageFilter{ + Status: ref.New(enums.PackageStatusInProgress), + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 2, + Name: "Test package 2", + WorkflowID: "workflow-1", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + page: &persistence.Page{ + Limit: entclient.DefaultPageSize, + Total: 1, + }, + }, + }, + { + name: "Returns packages filtered by startedAt", + data: testData, + args: persistence.PackageFilter{ + StartedAt: func(t *testing.T) *timerange.Range { + r, err := timerange.New( + time.Date(2024, 9, 25, 9, 0, 0, 0, time.UTC), + time.Date(2024, 9, 25, 10, 0, 0, 0, time.UTC), + ) + if err != nil { + t.Fatalf("Error: %v", err) + } + return &r + }(t), + }, + want: results{ + data: []*datatypes.Package{ + { + ID: 1, + Name: "Test package 1", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusDone, + CreatedAt: time.Now(), + StartedAt: started, + CompletedAt: completed, + }, + }, + page: &persistence.Page{ + Limit: entclient.DefaultPageSize, + Total: 1, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, svc := setUpClient(t, logr.Discard()) + ctx := context.Background() + + if len(tt.data) > 0 { + for _, pkg := range tt.data { + err := svc.CreatePackage(ctx, pkg) + assert.NilError(t, err) + } + } + + got, pg, err := svc.ListPackages(ctx, tt.args) + assert.NilError(t, err) + + assert.DeepEqual(t, got, tt.want.data, + cmpopts.EquateApproxTime(time.Millisecond*100), + cmpopts.IgnoreUnexported(db.Pkg{}, db.PkgEdges{}), + ) + assert.DeepEqual(t, pg, tt.want.page) + }) + } +} diff --git a/internal/persistence/fake/mock_persistence.go b/internal/persistence/fake/mock_persistence.go index 10d20eca..0ecfd0ae 100644 --- a/internal/persistence/fake/mock_persistence.go +++ b/internal/persistence/fake/mock_persistence.go @@ -155,6 +155,46 @@ func (c *MockServiceCreatePreservationTaskCall) DoAndReturn(f func(context.Conte return c } +// ListPackages mocks base method. +func (m *MockService) ListPackages(arg0 context.Context, arg1 persistence.PackageFilter) ([]*datatypes.Package, *persistence.Page, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListPackages", arg0, arg1) + ret0, _ := ret[0].([]*datatypes.Package) + ret1, _ := ret[1].(*persistence.Page) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ListPackages indicates an expected call of ListPackages. +func (mr *MockServiceMockRecorder) ListPackages(arg0, arg1 any) *MockServiceListPackagesCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPackages", reflect.TypeOf((*MockService)(nil).ListPackages), arg0, arg1) + return &MockServiceListPackagesCall{Call: call} +} + +// MockServiceListPackagesCall wrap *gomock.Call +type MockServiceListPackagesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockServiceListPackagesCall) Return(arg0 []*datatypes.Package, arg1 *persistence.Page, arg2 error) *MockServiceListPackagesCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockServiceListPackagesCall) Do(f func(context.Context, persistence.PackageFilter) ([]*datatypes.Package, *persistence.Page, error)) *MockServiceListPackagesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockServiceListPackagesCall) DoAndReturn(f func(context.Context, persistence.PackageFilter) ([]*datatypes.Package, *persistence.Page, error)) *MockServiceListPackagesCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // UpdatePackage mocks base method. func (m *MockService) UpdatePackage(arg0 context.Context, arg1 int, arg2 persistence.PackageUpdater) (*datatypes.Package, error) { m.ctrl.T.Helper() diff --git a/internal/persistence/filter_test.go b/internal/persistence/filter_test.go index f27e914c..9121bea2 100644 --- a/internal/persistence/filter_test.go +++ b/internal/persistence/filter_test.go @@ -8,7 +8,7 @@ import ( "github.com/artefactual-sdps/enduro/internal/persistence" ) -func TestOrder(t *testing.T) { +func TestSort(t *testing.T) { got := persistence.NewSort().AddCol("id", false).AddCol("date", true) assert.DeepEqual(t, got, persistence.Sort{ diff --git a/internal/persistence/persistence.go b/internal/persistence/persistence.go index 4c0d92c9..50ec2142 100644 --- a/internal/persistence/persistence.go +++ b/internal/persistence/persistence.go @@ -29,6 +29,7 @@ type Service interface { // (e.g. ID, CreatedAt). CreatePackage(context.Context, *datatypes.Package) error UpdatePackage(context.Context, int, PackageUpdater) (*datatypes.Package, error) + ListPackages(context.Context, PackageFilter) ([]*datatypes.Package, *Page, error) CreatePreservationAction(context.Context, *datatypes.PreservationAction) error diff --git a/internal/persistence/telemetry.go b/internal/persistence/telemetry.go index d40a1e56..3a7e8f8d 100644 --- a/internal/persistence/telemetry.go +++ b/internal/persistence/telemetry.go @@ -58,6 +58,19 @@ func (w *wrapper) UpdatePackage(ctx context.Context, id int, updater PackageUpda return r, nil } +func (w *wrapper) ListPackages(ctx context.Context, f PackageFilter) ([]*datatypes.Package, *Page, error) { + ctx, span := w.tracer.Start(ctx, "ListPackages") + defer span.End() + + r, pg, err := w.wrapped.ListPackages(ctx, f) + if err != nil { + telemetry.RecordError(span, err) + return nil, nil, updateError(err, "ListPackages") + } + + return r, pg, nil +} + func (w *wrapper) CreatePreservationAction(ctx context.Context, pa *datatypes.PreservationAction) error { ctx, span := w.tracer.Start(ctx, "CreatePreservationAction") defer span.End()