diff --git a/go.mod b/go.mod index b8e8e85d..cdedf3c4 100644 --- a/go.mod +++ b/go.mod @@ -60,6 +60,7 @@ require ( goa.design/plugins/v3 v3.15.2 gocloud.dev v0.39.0 golang.org/x/crypto v0.26.0 + golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 gotest.tools/v3 v3.5.1 @@ -165,7 +166,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect go4.org v0.0.0-20200411211856-f5505b9728dd // indirect - golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect golang.org/x/oauth2 v0.22.0 // indirect diff --git a/internal/persistence/ent/client/filter.go b/internal/persistence/ent/client/filter.go new file mode 100644 index 00000000..242ed3bf --- /dev/null +++ b/internal/persistence/ent/client/filter.go @@ -0,0 +1,338 @@ +package entclient + +import ( + "slices" + + "entgo.io/ent/dialect/sql" + "github.com/google/uuid" + "golang.org/x/exp/maps" + + "github.com/artefactual-sdps/enduro/internal/enums" + "github.com/artefactual-sdps/enduro/internal/persistence" + "github.com/artefactual-sdps/enduro/internal/timerange" +) + +const ( + DefaultPageSize int = 20 + MaxPageSize int = 50_000 +) + +// Predicate (P) is the constraint for all Ent predicates, e.g. predicate.Batch, +// predicate.Transfer and so on. +type Predicate interface { + ~func(s *sql.Selector) +} + +// OrderOption (O) is the constraint for all Ent ordering options, e.g. +// batch.OrderOption, transfer.OrderOption and so on. +type OrderOption interface { + ~func(s *sql.Selector) +} + +// Querier (Q) wraps queriers methods provided by Ent queries. 
+type Querier[P Predicate, O OrderOption, Q any] interface { + Where(ps ...P) Q + Limit(int) Q + Offset(int) Q + Order(...O) Q + Clone() Q +} + +type columnFilter[P Predicate] struct { + column string + predicate P +} + +type orderOption[O OrderOption] struct { + column string + option O +} + +// Filter provides a mechanism to filter, order and paginate using Ent queries. +// Invoke the Apply method last to apply the remaining filters. +type Filter[Q Querier[P, O, Q], O OrderOption, P Predicate] struct { + query Q + filters []columnFilter[P] + sortableFields SortableFields + orderBy []orderOption[O] + limit int + offset int +} + +// NewFilter returns a new Filter. It panics if orderingFields is empty. +func NewFilter[Q Querier[P, O, Q], O OrderOption, P Predicate](query Q, sf SortableFields) *Filter[Q, O, P] { + if len(sf) == 0 { + panic("sortableFields is empty") + } + + f := &Filter[Q, O, P]{ + query: query, + filters: []columnFilter[P]{}, + sortableFields: sf, + orderBy: []orderOption[O]{}, + limit: DefaultPageSize, + } + + return f +} + +// OrderBy sets the query sort order. +func (f *Filter[Q, O, P]) OrderBy(sort persistence.Sort) { + if len(sort) == 0 { + return + } + + for _, c := range sort { + f.addOrderOpt(c.Name, c.Desc) + } +} + +func (f *Filter[Q, O, P]) addOrderOpt(field string, dsc bool) { + // Check that field is an allowed sortableField. + if !slices.Contains(f.sortableFields.Columns(), field) { + return + } + + opt := orderOption[O]{ + column: field, + option: orderFunc(field, dsc), + } + + // Check if we've already sorted on this field. + i := slices.IndexFunc(f.orderBy, func(o orderOption[O]) bool { + return o.column == field + }) + + switch { + case i < 0: + f.orderBy = append(f.orderBy, opt) + default: + // Replace any previous sort on this field. 
+ f.orderBy[i] = opt + } +} + +func (f *Filter[Q, O, P]) setDefaultOrderBy(sf SortableFields) { + d := sf.Default() + f.addOrderOpt(d.Name, false) +} + +// orderFunc is called by the ent query builder to convert a selector +// OrderOption to a MySQL "order by" clause. +func orderFunc(field string, desc bool) func(sel *sql.Selector) { + return func(sel *sql.Selector) { + s := sel.C(field) + if desc { + s += " DESC" + } + sel.OrderBy(s) + } +} + +// Page sets the limit and offset criteria. +func (f *Filter[Q, O, P]) Page(limit, offset int) { + f.addLimit(limit) + + if offset > 0 { + f.offset = offset + } +} + +// addLimit adds the page limit l to a filter. +// +// If l < 0 the page limit is set to the maximum page size. +// If l == 0 the page limit is set to the default page size. +// If l > 0 but less than the max page size, the page limit is set to l. +// If l > max page size the page limit is set to the max page size. +func (f *Filter[Q, O, P]) addLimit(l int) { + switch { + case l == 0: + l = DefaultPageSize + case l < 0: + l = MaxPageSize + case l > MaxPageSize: + l = MaxPageSize + } + + f.limit = l +} + +// addFilter adds a new selector for column. Any existing filters on column will +// be retained to allow multiple criteria for the same column (e.g. name="foo" +// or name="bar"). +func (f *Filter[Q, O, P]) addFilter(column string, selector func(s *sql.Selector)) { + f.filters = append(f.filters, columnFilter[P]{column, selector}) +} + +// validPtrValue returns true if the given pointer ptr is not nil, and the +// underlying value is valid. +// +// Validating pointers is complicated because ptr has an interface{} type. The +// conditional `ptr == nil` doesn't evaluate true when ptr is a typed nil like +// (*enums.PackageStatus)(nil). A type switch case on the validator interface +// can then assign the nil *enums.PackageStatus to the validator interface and +// calling `t.IsValid()` causes a panic from trying to call `IsValid()` on a +// nil pointer. 
+func validPtrValue(ptr any) bool { + if ptr == nil { + return false + } + + switch t := ptr.(type) { + case *enums.PackageStatus: + return t != nil && t.IsValid() + case *enums.PreprocessingTaskOutcome: + return t != nil && t.IsValid() + case *int: + return t != nil + case *string: + return t != nil + case *uuid.UUID: + return t != nil && *t != uuid.Nil + default: + // Return false when v's type is unknown. + return false + } +} + +// Equals adds a filter on column being equal to value. If value implements the +// validator interface, value is validated before the filter is added. +func (f *Filter[Q, O, P]) Equals(column string, value any) { + // The current code always calls this function with a pointer value (e.g. + // *string, *enums.PackageStatus). If we need to pass value types (e.g. + // (string, enums.PackageStatus) in the future we'll have to combine the + // validPtrValue() & validValue() type switch cases. + if !validPtrValue(value) { + return + } + + f.addFilter(column, func(s *sql.Selector) { + s.Where(sql.EQ(s.C(column), value)) + }) +} + +// Validator is a simple validation interface. Validator is currently used for +// enums, but it could represent any type that implements validation. +type validator interface { + IsValid() bool +} + +func validValue(v any) bool { + switch t := v.(type) { + case validator: + return t.IsValid() + case uuid.UUID: + return t != uuid.Nil + default: + // Return true for all types that can't be validated. This allows + // filtering for empty values (e.g. the empty string ""). + return true + } +} + +// In adds a filter on column being equal to one of the given values. Each +// element in values that implements validator is validated before being added +// to the list of filter values. 
+func (f *Filter[Q, O, P]) In(column string, values []any) { + if len(values) == 0 { + return + } + + validated := make([]any, 0, len(values)) + for _, val := range values { + // I can't see any reason we'd want to pass pointers as elements in the + // values slice. We can and do pass ([]any)(nil) but doing so skips this + // loop altogether. + if !validValue(val) { + continue + } + validated = append(validated, val) + } + + if len(validated) == 0 { + return + } + + f.addFilter(column, func(s *sql.Selector) { + s.Where(sql.In(s.C(column), validated...)) + }) +} + +// dateRangeSelector returns a predicate matching rows within a date range +// (range.Start <= date < range.End). +func dateRangeSelector(column string, r *timerange.Range) func(*sql.Selector) { + return func(s *sql.Selector) { + var p *sql.Predicate + + switch { + case r.IsInstant(): + p = sql.EQ(column, r.Start) + default: + p = sql.And( + sql.GTE(column, r.Start), + sql.LT(column, r.End), + ) + } + + s.Where(p) + } +} + +func (f *Filter[Q, O, P]) AddDateRange(column string, r *timerange.Range) { + if r == nil || r.IsZero() { + return + } + + f.addFilter(column, dateRangeSelector(column, r)) +} + +// Apply filters, returning queriers of the filtered subset and the page. +func (f *Filter[Q, O, P]) Apply() (page, whole Q) { + whole = f.query.Clone() + + ps := []P{} + for _, cf := range f.filters { + ps = append(ps, cf.predicate) + } + whole.Where(ps...) + + if len(f.orderBy) == 0 { + f.setDefaultOrderBy(f.sortableFields) + } + + opts := []O{} + for _, ob := range f.orderBy { + opts = append(opts, ob.option) + } + whole.Order(opts...) + + page = whole.Clone() + page.Limit(f.limit) + page.Offset(f.offset) + + return page, whole +} + +type SortableField struct { + Name string + Default bool +} + +// SortableFields maps column names to Ent type field names. +// Usage examples: batchOrderingFields, transferOrderingFields... 
+type SortableFields map[string]SortableField + +// Default returns the default sort field. +func (sf SortableFields) Default() SortableField { + for _, f := range sf { + if f.Default { + return f + } + } + + panic("no default sort field specified") +} + +func (sf SortableFields) Columns() []string { + return maps.Keys(sf) +} diff --git a/internal/persistence/ent/client/filter_test.go b/internal/persistence/ent/client/filter_test.go new file mode 100644 index 00000000..f3df63e7 --- /dev/null +++ b/internal/persistence/ent/client/filter_test.go @@ -0,0 +1,525 @@ +package entclient_test + +import ( + "testing" + "time" + + "entgo.io/ent/dialect/sql" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "go.artefactual.dev/tools/ref" + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/enums" + "github.com/artefactual-sdps/enduro/internal/persistence" + entclient "github.com/artefactual-sdps/enduro/internal/persistence/ent/client" + "github.com/artefactual-sdps/enduro/internal/timerange" +) + +type pred func(*sql.Selector) + +type orderOpt func(*sql.Selector) + +// query is a querier like *db.PkgQuery for testing. 
+type query struct { + table string + limit int + offset int + order []string + where string + args []any +} + +func (q *query) Where(preds ...pred) *query { + sel := sql.Select().From(sql.Table(q.table)) + for _, pred := range preds { + pred(sel) + } + _, q.args = sel.Query() + if p := sel.P(); p != nil { + q.where = sel.P().String() + } + return q +} + +func (q *query) Limit(l int) *query { + q.limit = l + return q +} + +func (q *query) Offset(o int) *query { + q.offset = o + return q +} + +func (q *query) Order(fn ...orderOpt) *query { + sel := sql.Select().From(sql.Table(q.table)) + for _, f := range fn { + f(sel) + } + q.order = sel.OrderColumns() + return q +} + +func (q query) Clone() *query { + return &query{ + table: q.table, + limit: q.limit, + offset: q.offset, + order: append([]string{}, q.order...), + where: q.where, + args: append([]any{}, q.args...), + } +} + +func newSortableFields(fields ...string) entclient.SortableFields { + sf := map[string]entclient.SortableField{} + for i, name := range fields { + sf[name] = entclient.SortableField{Name: name, Default: i == 0} + } + + return sf +} + +func TestFilter(t *testing.T) { + t.Parallel() + + t.Run("Sorts allowed fields", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id", "name"), + ) + + f.OrderBy(persistence.NewSort(). + AddCol("id", false). 
+ AddCol("name", false), + ) + page, whole := f.Apply() + + assert.DeepEqual( + t, + page, + &query{ + table: "data", + limit: entclient.DefaultPageSize, + order: []string{"`data`.`id`", "`data`.`name`"}, + args: []any{}, + }, + cmp.AllowUnexported(query{}), + ) + assert.DeepEqual( + t, + whole, + &query{ + table: "data", + order: []string{"`data`.`id`", "`data`.`name`"}, + }, + cmp.AllowUnexported(query{}), + ) + }) + + t.Run("Sorts allowed fields in descending order", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id", "name"), + ) + f.OrderBy(persistence.NewSort().AddCol("name", true)) + page, whole := f.Apply() + + assert.DeepEqual( + t, + page, + &query{ + table: "data", + limit: entclient.DefaultPageSize, + order: []string{"`data`.`name` DESC"}, + args: []any{}, + }, + cmp.AllowUnexported(query{}), + ) + assert.DeepEqual( + t, + whole, + &query{ + table: "data", + order: []string{"`data`.`name` DESC"}, + }, + cmp.AllowUnexported(query{}), + ) + }) + + t.Run("Sorts by default sort column", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + map[string]entclient.SortableField{ + "id": {Name: "id"}, + "name": {Name: "name", Default: true}, + }, + ) + page, whole := f.Apply() + + assert.DeepEqual( + t, + page, + &query{ + table: "data", + limit: entclient.DefaultPageSize, + order: []string{"`data`.`name`"}, + args: []any{}, + }, + cmp.AllowUnexported(query{}), + ) + assert.DeepEqual( + t, + whole, + &query{ + table: "data", + order: []string{"`data`.`name`"}, + }, + cmp.AllowUnexported(query{}), + ) + }) + + t.Run("Panics when sortableFields is empty", func(t *testing.T) { + t.Parallel() + + q := &query{table: "data"} + + defer func() { + r := recover() + assert.Equal(t, r.(string), "sortableFields is empty") + }() + + entclient.NewFilter(q, nil) + }) + + t.Run("Panics when no default sortableField is set", func(t *testing.T) { + t.Parallel() + + f := 
entclient.NewFilter( + &query{table: "data"}, + map[string]entclient.SortableField{ + "id": {Name: "id"}, + }, + ) + + defer func() { + r := recover() + assert.Equal(t, r.(string), "no default sort field specified") + }() + + _, _ = f.Apply() + }) + + t.Run("Handles unknown sorting field, defaults to first known field", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id", "age"), + ) + f.OrderBy(persistence.NewSort().AddCol("count", false)) + page, whole := f.Apply() + assert.DeepEqual( + t, + page, + &query{ + table: "data", + limit: entclient.DefaultPageSize, + order: []string{"`data`.`id`"}, + args: []any{}, + }, + cmp.AllowUnexported(query{}), + ) + assert.DeepEqual( + t, + whole, + &query{ + table: "data", + order: []string{"`data`.`id`"}, + }, + cmp.AllowUnexported(query{}), + ) + }) + + t.Run("Sorts by the final sort on a field", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id", "name"), + ) + f.OrderBy(persistence.NewSort(). + AddCol("name", true). 
+ AddCol("name", false), + ) + page, whole := f.Apply() + + assert.DeepEqual(t, page.order, []string{"`data`.`name`"}) + assert.DeepEqual(t, whole.order, []string{"`data`.`name`"}) + }) + + t.Run("Default sort when given an empty order param", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id", "name"), + ) + f.OrderBy(persistence.NewSort()) + page, whole := f.Apply() + + assert.DeepEqual(t, page.order, []string{"`data`.`id`"}) + assert.DeepEqual(t, whole.order, []string{"`data`.`id`"}) + }) + + t.Run("Sets page limit and offset", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id", "name"), + ) + f.Page(50, 100) + f.OrderBy(persistence.NewSort().AddCol("name", false)) + page, whole := f.Apply() + + assert.Equal(t, page.limit, 50) + assert.Equal(t, page.offset, 100) + assert.DeepEqual(t, page.order, []string{"`data`.`name`"}) + + assert.Equal(t, whole.limit, 0) + assert.Equal(t, whole.offset, 0) + assert.DeepEqual(t, whole.order, []string{"`data`.`name`"}) + }) + + t.Run("Page size defaults to default value", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + page, whole := f.Apply() + + assert.Equal(t, page.limit, entclient.DefaultPageSize) + assert.Equal(t, whole.limit, 0) + }) + + t.Run("Passing a zero page limit uses the default page limit", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + f.Page(0, 0) + page, whole := f.Apply() + + assert.Equal(t, page.limit, entclient.DefaultPageSize) + assert.Equal(t, whole.limit, 0) + }) + + t.Run("Page size limited to max page size", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + f.Page(100_000, 0) + page, whole := f.Apply() + + assert.Equal(t, page.limit, 
entclient.MaxPageSize) + assert.Equal(t, whole.limit, 0) + }) + + t.Run("Page size is max page size when set to a negative number", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + f.Page(-100, 0) + page, whole := f.Apply() + + assert.Equal(t, page.limit, entclient.MaxPageSize) + assert.Equal(t, whole.limit, 0) + }) + + t.Run("Adds an equality filter", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + + id := 1234 + name := "Joe" + aipID := uuid.New() + + f.Equals("id", &id) // Filter on an *int value. + f.Equals("id2", id) // Ignore a non-pointer (int) value. + f.Equals("name", &name) // Filter on a *string value. + f.Equals("aip_id", &aipID) // Filter on a *uuid.UUID value. + f.Equals("address", (*string)(nil)) // Ignore a typed nil value. + f.Equals("address2", nil) // Ignore (interface{})(nil). + _, whole := f.Apply() + + assert.Equal(t, whole.where, "(`data`.`id` = ? AND `data`.`name` = ?) AND `data`.`aip_id` = ?") + assert.DeepEqual(t, whole.args, []any{&id, &name, &aipID}) + }) + + t.Run("Filters enums", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + + // Add a string enum filter. + taskOutcome := enums.PreprocessingTaskOutcomeSuccess + f.Equals("outcome", &taskOutcome) + + // Add an integer enum filter. + pkgStatus := enums.PackageStatusDone + f.Equals("status", &pkgStatus) + + // Omit invalid enum values. + f.Equals("outcome2", ref.New(enums.PreprocessingTaskOutcome("invalid"))) + + // Omit nil enum pointers. + f.Equals("status2", (*enums.PackageStatus)(nil)) + + _, whole := f.Apply() + + assert.Equal(t, whole.where, "`data`.`outcome` = ? 
AND `data`.`status` = ?") + assert.DeepEqual(t, whole.args, []any{&taskOutcome, &pkgStatus}) + }) + + t.Run("Filters on a list of strings", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + f.In("name", []any{"foo", "bar", ""}) + f.In("empty", []any{}) // Ignore an empty slice. + f.In("nil", ([]any)(nil)) // Ignore a nil slice. + _, whole := f.Apply() + + assert.Equal(t, whole.where, "`data`.`name` IN (?, ?, ?)") + assert.DeepEqual(t, whole.args, []any{"foo", "bar", ""}) + }) + + t.Run("Filters on a list of enums", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + f.In("status", []any{ + enums.PackageStatusInProgress, + enums.PackageStatusDone, + enums.PackageStatus(100), // Ignore an invalid enum. + }) + _, whole := f.Apply() + + assert.Equal(t, whole.where, "`data`.`status` IN (?, ?)") + assert.DeepEqual(t, whole.args, []any{ + enums.PackageStatusInProgress, + enums.PackageStatusDone, + }) + }) + + t.Run("Filters on a list of UUIDs", func(t *testing.T) { + t.Parallel() + + uuid0 := uuid.New() + uuid1 := uuid.New() + var uuid2 uuid.UUID + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + f.In("aip_id", []any{ + uuid0, + uuid1, + uuid2, // Ignore a nil UUID. + }) + _, whole := f.Apply() + + assert.Equal(t, whole.where, "`data`.`aip_id` IN (?, ?)") + assert.DeepEqual(t, whole.args, []any{uuid0, uuid1}) + }) + + t.Run("Filters on a date range", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + + r, err := timerange.New( + time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC), + time.Date(2024, 9, 1, 0, 0, 0, 0, time.UTC), + ) + assert.NilError(t, err) + + f.AddDateRange("created_at", &r) + _, whole := f.Apply() + + assert.Equal(t, whole.where, "`created_at` >= ? 
AND `created_at` < ?") + assert.DeepEqual(t, whole.args, []any{ + time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC), + time.Date(2024, 9, 1, 0, 0, 0, 0, time.UTC), + }) + }) + + t.Run("Filters on an exact time", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + + r := timerange.NewInstant(time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC)) + f.AddDateRange("created_at", &r) + _, whole := f.Apply() + + assert.Equal(t, whole.where, "`created_at` = ?") + assert.DeepEqual(t, whole.args, []any{ + time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC), + }) + }) + + t.Run("No filter added when date range is zero", func(t *testing.T) { + t.Parallel() + + f := entclient.NewFilter( + &query{table: "data"}, + newSortableFields("id"), + ) + + var r timerange.Range + f.AddDateRange("created_at", &r) + _, whole := f.Apply() + + assert.Equal(t, whole.where, "") + assert.Assert(t, whole.args == nil) + }) +} diff --git a/internal/persistence/filter.go b/internal/persistence/filter.go new file mode 100644 index 00000000..401daf7c --- /dev/null +++ b/internal/persistence/filter.go @@ -0,0 +1,60 @@ +package persistence + +import ( + "github.com/google/uuid" + + "github.com/artefactual-sdps/enduro/internal/enums" + "github.com/artefactual-sdps/enduro/internal/timerange" +) + +type ( + // Sort determines how the filtered results are sorted by specifying a + // slice of sort columns. The first SortColumn has the highest sort + // precedence, and the last SortColumn the lowest precedence. + Sort []SortColumn + + // SortColumn specifies a column name on which to sort results, and the + // direction of the sort (ascending or descending). + SortColumn struct { + // Name of the column on which to sort the results. + Name string + + // Desc is true if the sort order is descending. + Desc bool + } +) + +// NewSort returns a new sort instance. 
+func NewSort() Sort { + return Sort{} +} + +// AddCol adds a SortColumn to a Sort then returns the updated Sort. +func (s Sort) AddCol(name string, desc bool) Sort { + s = append(s, SortColumn{Name: name, Desc: desc}) + return s +} + +// Page represents a subset of results within a search result set. +type Page struct { + // Limit is the maximum number of results per page. + Limit int + + // Offset is the ordinal position, relative to the start of the unfiltered + // set, of the first result of the page. + Offset int + + // Total is the total number of search results before paging. + Total int +} + +type PackageFilter struct { + Name *string + AIPID *uuid.UUID + LocationID *uuid.UUID + Status *enums.PackageStatus + StartedAt *timerange.Range + + Sort + Page +} diff --git a/internal/persistence/filter_test.go b/internal/persistence/filter_test.go new file mode 100644 index 00000000..f27e914c --- /dev/null +++ b/internal/persistence/filter_test.go @@ -0,0 +1,18 @@ +package persistence_test + +import ( + "testing" + + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/persistence" +) + +func TestOrder(t *testing.T) { + got := persistence.NewSort().AddCol("id", false).AddCol("date", true) + + assert.DeepEqual(t, got, persistence.Sort{ + {Name: "id", Desc: false}, + {Name: "date", Desc: true}, + }) +} diff --git a/internal/timerange/timerange.go b/internal/timerange/timerange.go new file mode 100644 index 00000000..a106fbdc --- /dev/null +++ b/internal/timerange/timerange.go @@ -0,0 +1,37 @@ +package timerange + +import ( + "errors" + "time" +) + +type Range struct { + Start time.Time + End time.Time +} + +// New returns a new Range with the given Start and End times. New will return +// an error if the End time is before the Start time. 
+func New(start, end time.Time) (Range, error) { + if end.Before(start) { + return Range{}, errors.New("time range: end cannot be before start") + } + + return Range{Start: start, End: end}, nil +} + +// NewInstant returns a Range where the Start and End times are both set to the +// given time. +func NewInstant(t time.Time) Range { + return Range{Start: t, End: t} +} + +// IsZero returns true when both the Start and End times are zero. +func (r Range) IsZero() bool { + return r.Start.IsZero() && r.End.IsZero() +} + +// IsInstant returns true when the Start and End times represent the same +// instant. Comparison uses time.Time.Equal rather than "==" so that two times +// denoting the same instant compare equal even when their monotonic clock +// readings or Locations differ ("==" would report them unequal, which made +// dateRangeSelector build an unsatisfiable "x >= t AND x < t" predicate for +// an instant range). +func (r Range) IsInstant() bool { + return r.Start.Equal(r.End) +} diff --git a/internal/timerange/timerange_test.go b/internal/timerange/timerange_test.go new file mode 100644 index 00000000..137392df --- /dev/null +++ b/internal/timerange/timerange_test.go @@ -0,0 +1,93 @@ +package timerange_test + +import ( + "testing" + "time" + + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/timerange" +) + +func TestNew(t *testing.T) { + t.Parallel() + + t.Run("Returns a time range", func(t *testing.T) { + t.Parallel() + + r, err := timerange.New( + time.Date(2024, 9, 17, 0, 0, 0, 0, time.UTC), + time.Date(2024, 9, 18, 0, 0, 0, 0, time.UTC), + ) + assert.NilError(t, err) + assert.DeepEqual(t, r, timerange.Range{ + Start: time.Date(2024, 9, 17, 0, 0, 0, 0, time.UTC), + End: time.Date(2024, 9, 18, 0, 0, 0, 0, time.UTC), + }) + }) + + t.Run("Errors when end time is before start time", func(t *testing.T) { + t.Parallel() + + _, err := timerange.New( + time.Date(2024, 9, 17, 0, 0, 0, 0, time.UTC), + time.Date(2024, 9, 16, 0, 0, 0, 0, time.UTC), + ) + assert.Error(t, err, "time range: end cannot be before start") + }) +} + +func TestNewInstant(t *testing.T) { + t.Parallel() + + t.Run("Returns an instant time", func(t *testing.T) { + t.Parallel() + + r := timerange.NewInstant(time.Date(2024, 9, 17, 0, 0, 0, 0, time.UTC)) + assert.DeepEqual(t, r, timerange.Range{ + Start: 
time.Date(2024, 9, 17, 0, 0, 0, 0, time.UTC), + End: time.Date(2024, 9, 17, 0, 0, 0, 0, time.UTC), + }) + }) +} + +func TestIsZero(t *testing.T) { + t.Parallel() + + t.Run("Returns true when start and end times are zero", func(t *testing.T) { + t.Parallel() + + var r timerange.Range + assert.Assert(t, r.IsZero()) + }) + + t.Run("Returns false when the start or end time is not zero", func(t *testing.T) { + t.Parallel() + + var t0 time.Time + r, err := timerange.New(t0, time.Now()) + assert.NilError(t, err) + assert.Assert(t, !r.IsZero()) + }) +} + +func TestIsInstant(t *testing.T) { + t.Parallel() + + t.Run("Returns true when start and end time are equal", func(t *testing.T) { + t.Parallel() + + n := time.Now() + r, err := timerange.New(n, n) + assert.NilError(t, err) + assert.Assert(t, r.IsInstant()) + }) + + t.Run("Returns false when start and end time are not equal", func(t *testing.T) { + t.Parallel() + + r, err := timerange.New(time.Now(), time.Now()) + assert.NilError(t, err) + assert.Assert(t, !r.IsInstant()) + }) +}