diff --git a/compat.go b/compat.go new file mode 100644 index 0000000..ee421cd --- /dev/null +++ b/compat.go @@ -0,0 +1,33 @@ +package sqltest + +var ( + // DriversByProduct maps a Product to the Driver implementations that it is + // compatible with. + DriversByProduct = map[Product][]Driver{} + + // ProductsByDriver maps a Driver to the Product implementations that it is + // compatible with. + ProductsByDriver = map[Driver][]Product{} + + // CompatiblePairs contains all compatible driver/product pairs. + CompatiblePairs []Pair +) + +// Pair is a struct containing a driver and product that are compatible with +// each other. +type Pair struct { + Driver Driver + Product Product +} + +func init() { + for _, d := range Drivers { + for _, p := range Products { + if p.IsCompatibleWith(d) { + DriversByProduct[p] = append(DriversByProduct[p], d) + ProductsByDriver[d] = append(ProductsByDriver[d], p) + CompatiblePairs = append(CompatiblePairs, Pair{d, p}) + } + } + } +} diff --git a/database.go b/database.go index f47d8a6..3017132 100644 --- a/database.go +++ b/database.go @@ -141,6 +141,14 @@ func (db *Database) Close() error { // It first checks for an environment variable containing a DSN. If that is not // present it askes the product to generate a default DSN. func dataSource(d Driver, p Product) (DataSource, error) { + if !p.IsCompatibleWith(d) { + return nil, fmt.Errorf( + "%s is incompatible with the '%s' driver", + p.Name(), + d.Name(), + ) + } + key := strings.ToUpper(fmt.Sprintf("DOGMATIQ_TEST_DSN_%s_%s", p.Name(), d.Name())) dsn := os.Getenv(key) @@ -159,14 +167,6 @@ func dataSource(d Driver, p Product) (DataSource, error) { ds, err := p.DefaultDataSource(d) - if errors.Is(err, ErrIncompatibleDriver) { - return nil, fmt.Errorf( - "%s is incompatible with the '%s' driver", - p.Name(), - d.Name(), - ) - } - if err != nil { return nil, fmt.Errorf( "can not build a default %s DSN using the '%s' driver: %w", diff --git a/driver.go b/driver.go index d5ac55c..1b49036 100644 --- a/driver.go +++ b/driver.go @@ -49,4 +49,12 @@ var ( // SQLite3Driver is the "sqlite3" driver (github.com/mattn/go-sqlite3). SQLite3Driver Driver = sqlite3Driver{} + + // Drivers is a slice containing all known products. + Drivers = []Driver{ + MySQLDriver, + PGXDriver, + PostgresDriver, + SQLite3Driver, + } ) diff --git a/product.go b/product.go index e21bf62..555c929 100644 --- a/product.go +++ b/product.go @@ -3,13 +3,8 @@ package sqltest import ( "context" "database/sql" - "errors" ) -// ErrIncompatibleDriver indicates that a specific driver can not be used to -// connect to a specific product. -var ErrIncompatibleDriver = errors.New("the driver is not compatible with the product") - // Product is a specific database product such as MySQL or MariaDB. // // The product correlates with a running service that tests are run against. @@ -18,6 +13,9 @@ type Product interface { // Name returns the human-readable name of the product. Name() string + // IsCompatibleWith return true if the product is compatible with d. + IsCompatibleWith(d Driver) bool + // DefaultDataSource returns the default data source to use to connect to // the product. // @@ -64,4 +62,12 @@ var ( // SQLite is the Product for SQLite (https://www.sqlite.org). SQLite Product = sqliteProduct{} + + // Products is a slice containing all known products. + Products = []Product{ + MySQL, + MariaDB, + PostgreSQL, + SQLite, + } ) diff --git a/productmysql.go b/productmysql.go index 14c1c7d..6fcf590 100644 --- a/productmysql.go +++ b/productmysql.go @@ -25,15 +25,16 @@ func (p MySQLCompatibleProduct) Name() string { return p.ProductName } +// IsCompatibleWith return true if the product is compatible with d. +func (p MySQLCompatibleProduct) IsCompatibleWith(d Driver) bool { + _, ok := d.(MySQLProtocol) + return ok +} + // DefaultDataSource returns the default data source to use to connect to the // product. func (p MySQLCompatibleProduct) DefaultDataSource(d Driver) (DataSource, error) { - proto, ok := d.(MySQLProtocol) - if !ok { - return nil, ErrIncompatibleDriver - } - - return proto.DataSourceForMySQL( + return d.(MySQLProtocol).DataSourceForMySQL( "root", "rootpass", "127.0.0.1", p.DefaultPort, "mysql", diff --git a/productpostgres.go b/productpostgres.go index ddef438..b962d0f 100644 --- a/productpostgres.go +++ b/productpostgres.go @@ -26,15 +26,16 @@ func (p PostgresCompatibleProduct) Name() string { return p.ProductName } +// IsCompatibleWith return true if the product is compatible with d. +func (p PostgresCompatibleProduct) IsCompatibleWith(d Driver) bool { + _, ok := d.(PostgresProtocol) + return ok +} + // DefaultDataSource returns the default data source to use to connect to the // product. func (p PostgresCompatibleProduct) DefaultDataSource(d Driver) (DataSource, error) { - proto, ok := d.(PostgresProtocol) - if !ok { - return nil, ErrIncompatibleDriver - } - - return proto.DataSourceForPostgres( + return d.(PostgresProtocol).DataSourceForPostgres( "postgres", "rootpass", "127.0.0.1", p.DefaultPort, "", // default database diff --git a/productsqlite.go b/productsqlite.go index 88954dc..ebd85db 100644 --- a/productsqlite.go +++ b/productsqlite.go @@ -17,13 +17,13 @@ func (sqliteProduct) Name() string { return "SQLite" } -func (sqliteProduct) DefaultDataSource(d Driver) (DataSource, error) { - proto, ok := d.(SQLiteProtocol) - if !ok { - return nil, ErrIncompatibleDriver - } +func (sqliteProduct) IsCompatibleWith(d Driver) bool { + _, ok := d.(SQLiteProtocol) + return ok +} - return proto.DataSourceForSQLite( +func (sqliteProduct) DefaultDataSource(d Driver) (DataSource, error) { + return d.(SQLiteProtocol).DataSourceForSQLite( filepath.Join(os.TempDir(), "dogmatiq.sqlite3"), ) }