Skip to content

Commit

Permalink
sql/{mysql,sqlite}: support passing decode-level options (#3204)
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m authored Oct 28, 2024
1 parent fbb18a4 commit 7c6e629
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 64 deletions.
4 changes: 2 additions & 2 deletions sql/mysql/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ func init() {
DriverName,
sqlclient.OpenerFunc(opener),
sqlclient.RegisterDriverOpener(Open),
sqlclient.RegisterCodec(MarshalHCL, EvalHCL),
sqlclient.RegisterCodec(codec, codec),
sqlclient.RegisterFlavours("mysql+unix"),
sqlclient.RegisterURLParser(parser{}),
)
sqlclient.Register(
"mariadb",
sqlclient.OpenerFunc(opener),
sqlclient.RegisterDriverOpener(Open),
sqlclient.RegisterCodec(MarshalHCL, EvalMariaHCL),
sqlclient.RegisterCodec(mariaCodec, mariaCodec),
sqlclient.RegisterFlavours("mariadb+unix", "maria", "maria+unix"),
sqlclient.RegisterURLParser(parser{}),
)
Expand Down
64 changes: 36 additions & 28 deletions sql/mysql/sqlspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,22 @@ import (
"github.com/zclconf/go-cty/cty"
)

// evalSpec evaluates an Atlas DDL document into v using the input.
func evalSpec(state *schemahcl.State, p *hclparse.Parser, v any, input map[string]cty.Value) error {
// Codec for schemahcl.
type Codec struct {
State *schemahcl.State
}

// Eval evaluates an Atlas DDL document into v using the input.
func (c *Codec) Eval(p *hclparse.Parser, v any, input map[string]cty.Value) error {
return c.EvalOptions(p, v, &schemahcl.EvalOptions{Variables: input})
}

// EvalOptions decodes the HCL with the given options.
func (c *Codec) EvalOptions(p *hclparse.Parser, v any, opts *schemahcl.EvalOptions) error {
switch v := v.(type) {
case *schema.Realm:
var d specutil.Doc
if err := state.Eval(p, &d, input); err != nil {
if err := c.State.EvalOptions(p, &d, opts); err != nil {
return err
}
if err := specutil.Scan(v,
Expand All @@ -45,7 +55,7 @@ func evalSpec(state *schemahcl.State, p *hclparse.Parser, v any, input map[strin
}
case *schema.Schema:
var d specutil.Doc
if err := state.Eval(p, &d, input); err != nil {
if err := c.State.EvalOptions(p, &d, opts); err != nil {
return err
}
if len(d.Schemas) != 1 {
Expand All @@ -65,14 +75,14 @@ func evalSpec(state *schemahcl.State, p *hclparse.Parser, v any, input map[strin
case schema.Schema, schema.Realm:
return fmt.Errorf("mysql: Eval expects a pointer: received %[1]T, expected *%[1]T", v)
default:
return state.Eval(p, v, input)
return fmt.Errorf("mysql: unexpected type %T", v)
}
return nil
}

// MarshalSpec marshals v into an Atlas DDL document using a schemahcl.Marshaler.
func MarshalSpec(v any, marshaler schemahcl.Marshaler) ([]byte, error) {
return specutil.Marshal(v, marshaler, specutil.RealmFuncs{
func (c *Codec) MarshalSpec(v any) ([]byte, error) {
return specutil.Marshal(v, c.State, specutil.RealmFuncs{
Schema: schemaSpec,
Triggers: triggersSpec,
})
Expand All @@ -92,32 +102,30 @@ var (
schemahcl.WithScopedEnums("table.foreign_key.on_update", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.foreign_key.on_delete", specutil.ReferenceVars...),
}
hclState = schemahcl.New(
append(
specOptions,
sharedSpecOptions...,
)...,
)
mariaHCLState = schemahcl.New(
append(
mariaSpecOptions,
sharedSpecOptions...,
)...,
)
codec = &Codec{
State: schemahcl.New(
append(
specOptions,
sharedSpecOptions...,
)...,
),
}
mariaCodec = &Codec{
State: schemahcl.New(
append(
mariaSpecOptions,
sharedSpecOptions...,
)...,
),
}
// MarshalHCL marshals v into an Atlas HCL DDL document.
MarshalHCL = schemahcl.MarshalerFunc(func(v any) ([]byte, error) {
return MarshalSpec(v, hclState)
})
MarshalHCL = schemahcl.MarshalerFunc(codec.MarshalSpec)
// EvalHCL implements the schemahcl.Evaluator interface.
EvalHCL = schemahcl.EvalFunc(func(h *hclparse.Parser, v any, m map[string]cty.Value) error {
return evalSpec(hclState, h, v, m)
})
EvalHCL = schemahcl.EvalFunc(codec.Eval)
// EvalHCLBytes is a helper that evaluates an HCL document from a byte slice.
EvalHCLBytes = specutil.HCLBytesFunc(EvalHCL)
// EvalMariaHCL implements the schemahcl.Evaluator interface for MariaDB flavor.
EvalMariaHCL = schemahcl.EvalFunc(func(h *hclparse.Parser, v any, m map[string]cty.Value) error {
return evalSpec(mariaHCLState, h, v, m)
})
EvalMariaHCL = schemahcl.EvalFunc(mariaCodec.Eval)
// EvalMariaHCLBytes is a helper that evaluates a MariaDB HCL document from a byte slice.
EvalMariaHCLBytes = specutil.HCLBytesFunc(EvalMariaHCL)
)
Expand Down
16 changes: 8 additions & 8 deletions sql/mysql/sqlspec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func TestMarshalSpec_Charset(t *testing.T) {
}
s.Tables[0].Schema = s
s.Tables[1].Schema = s
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
// Charset and collate that are identical to their parent elements
// should not be printed as they are inherited by default from it.
Expand Down Expand Up @@ -458,7 +458,7 @@ func TestMarshalSpec_Comment(t *testing.T) {
},
},
}
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
// We expect a zero value comment to not be present in the marshaled HCL.
const expected = `table "users" {
Expand Down Expand Up @@ -507,7 +507,7 @@ func TestMarshalSpec_AutoIncrement(t *testing.T) {
},
}
s.Tables[0].Schema = s
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -536,7 +536,7 @@ func TestMarshalSpec_Check(t *testing.T) {
schema.NewCheck().SetExpr("price1 <> price2").AddAttrs(&Enforced{}),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "products" {
schema = schema.test
Expand Down Expand Up @@ -572,7 +572,7 @@ func TestMarshalSpec_TableEngine(t *testing.T) {
schema.NewTable("issues").AddAttrs(&Engine{V: "MYISAM"}).AddColumns(schema.NewIntColumn("id", TypeBigInt)),
schema.NewTable("commits").AddAttrs(&Engine{V: "MyRocks"}).AddColumns(schema.NewIntColumn("id", TypeBigInt)),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "repos" {
schema = schema.a8m
Expand Down Expand Up @@ -985,7 +985,7 @@ func TestMarshalSpec_TimePrecision(t *testing.T) {
schema.NewTimeColumn("tYear", TypeYear, schema.TimePrecision(2)),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "times" {
schema = schema.test
Expand Down Expand Up @@ -1036,7 +1036,7 @@ func TestMarshalSpec_GeneratedColumn(t *testing.T) {
SetGeneratedExpr(&schema.GeneratedExpr{Expr: "c3 * c4", Type: "STORED"}),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -1141,7 +1141,7 @@ func TestMarshalSpec_FloatUnsigned(t *testing.T) {
),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "test" {
schema = schema.test
Expand Down
53 changes: 31 additions & 22 deletions sql/sqlite/sqlspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,37 @@ type doc struct {
Schemas []*sqlspec.Schema `spec:"schema"`
}

// evalSpec evaluates an Atlas DDL document using an unmarshaler into v by using the input.
func evalSpec(p *hclparse.Parser, v any, input map[string]cty.Value) error {
// Codec for schemahcl.
type Codec struct {
State *schemahcl.State
}

// Eval evaluates an Atlas DDL document into v using the input.
func (c *Codec) Eval(p *hclparse.Parser, v any, input map[string]cty.Value) error {
return c.EvalOptions(p, v, &schemahcl.EvalOptions{Variables: input})
}

// EvalOptions decodes the HCL with the given options.
func (c *Codec) EvalOptions(p *hclparse.Parser, v any, opts *schemahcl.EvalOptions) error {
switch v := v.(type) {
case *schema.Realm:
var d doc
if err := hclState.Eval(p, &d, input); err != nil {
if err := c.State.EvalOptions(p, &d, opts); err != nil {
return err
}
if err := specutil.Scan(v,
&specutil.ScanDoc{Schemas: d.Schemas, Tables: d.Tables, Views: d.Views, Triggers: d.Triggers},
scanFuncs,
); err != nil {
return fmt.Errorf("specutil: failed converting to *schema.Realm: %w", err)
return fmt.Errorf("sqlite: failed converting to *schema.Realm: %w", err)
}
case *schema.Schema:
var d doc
if err := hclState.Eval(p, &d, input); err != nil {
if err := c.State.EvalOptions(p, &d, opts); err != nil {
return err
}
if len(d.Schemas) != 1 {
return fmt.Errorf("specutil: expecting document to contain a single schema, got %d", len(d.Schemas))
return fmt.Errorf("sqlite: expecting document to contain a single schema, got %d", len(d.Schemas))
}
r := &schema.Realm{}
if err := specutil.Scan(r,
Expand All @@ -59,14 +69,14 @@ func evalSpec(p *hclparse.Parser, v any, input map[string]cty.Value) error {
case schema.Schema, schema.Realm:
return fmt.Errorf("sqlite: Eval expects a pointer: received %[1]T, expected *%[1]T", v)
default:
return hclState.Eval(p, v, input)
return fmt.Errorf("sqlite: unexpected type %T", v)
}
return nil
}

// MarshalSpec marshals v into an Atlas DDL document using a schemahcl.Marshaler.
func MarshalSpec(v any, marshaler schemahcl.Marshaler) ([]byte, error) {
return specutil.Marshal(v, marshaler, specutil.RealmFuncs{
func (c *Codec) MarshalSpec(v any) ([]byte, error) {
return specutil.Marshal(v, c.State, specutil.RealmFuncs{
Schema: schemaSpec,
Triggers: triggersSpec,
})
Expand Down Expand Up @@ -285,21 +295,20 @@ var TypeRegistry = schemahcl.NewRegistry(
)

var (
hclState = schemahcl.New(append(
specOptions,
schemahcl.WithTypes("table.column.type", TypeRegistry.Specs()),
schemahcl.WithTypes("view.column.type", TypeRegistry.Specs()),
schemahcl.WithScopedEnums("table.column.as.type", stored, virtual),
schemahcl.WithScopedEnums("table.foreign_key.on_update", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.foreign_key.on_delete", specutil.ReferenceVars...),
)...)
codec = &Codec{
State: schemahcl.New(append(
specOptions,
schemahcl.WithTypes("table.column.type", TypeRegistry.Specs()),
schemahcl.WithTypes("view.column.type", TypeRegistry.Specs()),
schemahcl.WithScopedEnums("table.column.as.type", stored, virtual),
schemahcl.WithScopedEnums("table.foreign_key.on_update", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.foreign_key.on_delete", specutil.ReferenceVars...),
)...),
}
// MarshalHCL marshals v into an Atlas HCL DDL document.
MarshalHCL = schemahcl.MarshalerFunc(func(v any) ([]byte, error) {
return MarshalSpec(v, hclState)
})
MarshalHCL = schemahcl.MarshalerFunc(codec.MarshalSpec)
// EvalHCL implements the schemahcl.Evaluator interface.
EvalHCL = schemahcl.EvalFunc(evalSpec)

EvalHCL = schemahcl.EvalFunc(codec.Eval)
// EvalHCLBytes is a helper that evaluates an HCL document from a byte slice instead
// of from an hclparse.Parser instance.
EvalHCLBytes = specutil.HCLBytesFunc(EvalHCL)
Expand Down
8 changes: 4 additions & 4 deletions sql/sqlite/sqlspec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestMarshalSpec_AutoIncrement(t *testing.T) {
},
}
s.Tables[0].Schema = s
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -308,7 +308,7 @@ func TestMarshalSpec_IndexPredicate(t *testing.T) {
},
},
}
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -479,7 +479,7 @@ schema "test" {
require.NoError(t, err)
colspec := test.Tables[0].Columns[0]
require.EqualValues(t, tt.expected, colspec.Type.Type)
spec, err := MarshalSpec(&test, hclState)
spec, err := MarshalHCL(&test)
require.NoError(t, err)
var after schema.Schema
err = EvalHCLBytes(spec, &after, nil)
Expand All @@ -502,7 +502,7 @@ func TestMarshalSpec_TableOptions(t *testing.T) {
),
)
s.Tables[0].SetSchema(s)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down

0 comments on commit 7c6e629

Please sign in to comment.