diff --git a/error.go b/error.go index b8749c4..c2052ce 100644 --- a/error.go +++ b/error.go @@ -6,5 +6,7 @@ var ( ErrBadConfig = errors.New("bad config") ErrMissingData = errors.New("missing data") ErrNotExist = errors.New("not exist") + ErrNotFound = errors.New("not found") ErrNotValid = errors.New("invalid") + ErrUnexpected = errors.New("unexpected") ) diff --git a/postgres/service.go b/postgres/service.go index 4c77e92..d6f3938 100644 --- a/postgres/service.go +++ b/postgres/service.go @@ -1,27 +1,151 @@ package postgres import ( + "database/sql" + "errors" + "fmt" "math" + "github.com/xy-planning-network/trails" "gorm.io/gorm" + "gorm.io/gorm/clause" ) -// DatabaseService sets up the interface to be used at the handler/middelware level. These should be straightforward -// calls that allow us to skip creating a procedure method for the most basic database interactions. At the procedural -// layer, the *gorm.DB struct is available directly for more complex composition. This has the intended functionality -// that we are not testing the database in handlers, while it is tested directly at the procedural layer. -type DatabaseService interface { - CountByQuery(model any, query map[string]any) (int64, error) - FetchByQuery(models any, query string, params []any) error - FindByID(model any, ID any) error - FindByQuery(model any, query map[string]any) error - PagedByQuery(models any, query string, params []any, order string, page int, perPage int, preloads ...string) (PagedData, error) - PagedByQueryFromSession(models any, session *gorm.DB, page int, perPage int) (PagedData, error) -} - -// PagedData is returned from the PagedByQuery method. It contains paged database records and pagination metadata. -type PagedData struct { - Items any `json:"items"` +type Simple[T any] struct { + db *gorm.DB +} + +func NewSimple[T any](db *gorm.DB) *Simple[T] { + return &Simple[T]{db} +} + +func (db *Simple[T]) Debug() *Simple[T] { return db } + +func (db *Simple[T]) Distinct(args ...interface{}) *Simple[T] { return db } + +func (db *Simple[T]) Find() (T, error) { return nil } + +func (db *Simple[T]) First() (T, error) { + var dest T + err := db.db.Model(&dest).First(&dest).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("%w", trails.ErrNotFound) + } + + if err != nil { + return fmt.Errorf("%w: %s", trails.ErrUnexpected, err) + } + + return nil +} + +func (db *Simple[T]) Group(name string) *Simple[T] { return db } + +func (db *Simple[T]) Joins(query string, args ...interface{}) *Simple[T] { return db } + +func (db *Simple[T]) Limit(limit int) *Simple[T] { return db } + +func (db *Simple[T]) Offset(offset int) *Simple[T] { return db } + +func (db *Simple[T]) Or(query interface{}, args ...interface{}) *Simple[T] { return db } + +func (db *Simple[T]) Order(value interface{}) *Simple[T] { return db } + +func (db *Simple[T]) Paged(page, perPage int) (PagedData[T], error) { + var pd PagedData[T] + var items T + + // Make sure page/perPage are sane + page = max(1, page) + perPage = max(1, perPage) + + // Conduct unlimited count query to calculate totals + var totalRecords int64 + if err := db.db.Model(&items).Session(new(gorm.Session)).Count(&totalRecords).Error; err != nil { + err = fmt.Errorf("%w: %s", trails.ErrUnexpected, err) + return pd, err + } + + // Calculate offset and conduct limited query + offset := (page - 1) * perPage + if err := db.db.Model(&items).Limit(perPage).Offset(offset).Find(&items).Error; err != nil { + err = fmt.Errorf("%w: %s", trails.ErrUnexpected, err) + return pd, err + } + + pd.Items = items + pd.Page = page + pd.PerPage = perPage + pd.TotalItems = totalRecords + totalPagesFloat := float64(totalRecords) / float64(perPage) + pd.TotalPages = int(math.Ceil(totalPagesFloat)) + + return pd, nil +} + +func (db *Simple[T]) Preload(query string, args ...interface{}) *Simple[T] { return db } + +func (db *Simple[T]) Scopes(funcs ...func(*Simple[T]) *Simple[T]) *Simple[T] { return db } + +func (db *Simple[T]) Select(columns ...string) *Simple[T] { return db } + +// Table specifies the name of the table to query when T is not a struct matching a database table. +// +// e.g.: +// +// ids, err := NewSimple[[]uint].Table("users").Select("id").Find() +func (db *Simple[T]) Table(name string) *Simple[T] {} + +func (db *Simple[T]) Unscoped() *Simple[T] { return db } + +func (db *Simple[T]) Where(query interface{}, args ...interface{}) *Simple[T] { + return &Simple[T]{db.db.Where(query, args...)} +} + +type Robust[T any] struct { + *Simple[T] +} + +func (db *Robust[T]) Begin(opts ...*sql.TxOptions) *Robust[T] { return db } + +func (db *Robust[T]) Commit() error { return nil } + +func (db *Robust[T]) Create(value T) error { return nil } + +// escape hatch +func (db *Robust[T]) DB() *gorm.DB { return db.db } + +func (db *Robust[T]) Delete(value T) error { return nil } + +func (db *Robust[T]) Exec(sql string, values ...interface{}) (int64, error) { + var err error + res := db.db.Exec(sql, values...) + if res.Error != nil { + err = fmt.Errorf("%w: %s", trails.ErrUnexpected, err) + } + return res.RowsAffected, err +} + +func (db *Robust[T]) Raw(dest interface{}, sql string, values ...interface{}) error { return nil } + +func (db *Robust[T]) Rollback() *Robust[T] { return db } + +func (db *Robust[T]) Update(column string, value interface{}) (T, error) { return db } + +func (db *Robust[T]) Updates(values interface{}) (T, error) { + var dest T + err := db.db.Model(&dest).Clauses(clause.Returning{} /* GORM's RETURNING * syntax */).Updates(values).Error + if err != nil { + err = fmt.Errorf("%w: %s", trails.ErrUnexpected, err) + return dest, err + } + + return dest, nil +} + +// PagedData is returned from the Paged method. It contains paged database records and pagination metadata. +type PagedData[T any] struct { + Items T `json:"items"` Page int `json:"page"` PerPage int `json:"perPage"` TotalItems int64 `json:"totalItems"`