Skip to content

Commit

Permalink
schemahcl: support passing decode-level options
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Oct 27, 2024
1 parent f490381 commit c51643e
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 60 deletions.
66 changes: 48 additions & 18 deletions schemahcl/schemahcl.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,20 @@ func (s *State) MarshalSpec(v any) ([]byte, error) {
return s.encode(r)
}

// EvalOptions configures the evaluation of HCL documents.
type EvalOptions struct {
// Variables is a map of input variables to be used during evaluation.
Variables map[string]cty.Value

// RecordPos indicates whether to record the source code positions of the
// evaluated resources and attributes. It defaults to the State (Driver) config.
RecordPos bool

// Validator is the schema validator to be used during evaluation.
// It defaults to the State (Driver) config.
Validator SchemaValidator
}

// EvalFiles evaluates the files in the provided paths using the input variables and
// populates v with the result.
func (s *State) EvalFiles(paths []string, v any, input map[string]cty.Value) error {
Expand All @@ -318,6 +332,13 @@ func (s *State) EvalFiles(paths []string, v any, input map[string]cty.Value) err
// Eval evaluates the parsed HCL documents using the input variables and populates v
// using the result.
func (s *State) Eval(parsed *hclparse.Parser, v any, input map[string]cty.Value) error {
return s.EvalOptions(parsed, v, &EvalOptions{
Variables: input,
})
}

// EvalOptions evaluates the parsed HCL documents and populates v using the result.
func (s *State) EvalOptions(parsed *hclparse.Parser, v any, opts *EvalOptions) error {
var (
hasVars bool
ctx = s.newCtx()
Expand All @@ -334,7 +355,7 @@ func (s *State) Eval(parsed *hclparse.Parser, v any, input map[string]cty.Value)
}
for name, file := range files {
fileNames = append(fileNames, name)
if err := s.setInputVals(ctx, file.Body, input); err != nil {
if err := s.setInputVals(ctx, file.Body, opts.Variables); err != nil {
return err
}
body := file.Body.(*hclsyntax.Body)
Expand Down Expand Up @@ -400,13 +421,20 @@ func (s *State) Eval(parsed *hclparse.Parser, v any, input map[string]cty.Value)
sort.Slice(fileNames, func(i, j int) bool {
return fileNames[i] < fileNames[j]
})
vr := SchemaValidator(&nopValidator{})
if s.config.validator != nil {
var vr SchemaValidator
switch {
case opts.Validator != nil:
vr = opts.Validator
case s.config.validator != nil:
vr = s.config.validator()
opts.Validator = vr
default:
vr = &nopValidator{}
opts.Validator = vr
}
for _, name := range fileNames {
file := files[name]
r, err := s.resource(ctx, vr, file, reg)
r, err := s.resource(ctx, opts, file, reg)
if err != nil {
return err
}
Expand Down Expand Up @@ -492,16 +520,16 @@ func (r addrRef) load(res *Resource, track string) addrRef {
}

// resource converts the hcl file to a schemahcl.Resource.
func (s *State) resource(ctx *hcl.EvalContext, vr SchemaValidator, file *hcl.File, dec *blockDef) (*Resource, error) {
func (s *State) resource(ctx *hcl.EvalContext, opts *EvalOptions, file *hcl.File, dec *blockDef) (*Resource, error) {
body, ok := file.Body.(*hclsyntax.Body)
if !ok {
return nil, fmt.Errorf("schemahcl: expected remainder to be of type *hclsyntax.Body")
}
closeScope, err := vr.ValidateBody(ctx, body)
closeScope, err := opts.Validator.ValidateBody(ctx, body)
if err != nil {
return nil, err
}
attrs, err := s.toAttrs(ctx, vr, body.Attributes, nil)
attrs, err := s.toAttrs(ctx, opts, body.Attributes, nil)
if err != nil {
return nil, err
}
Expand All @@ -519,7 +547,7 @@ func (s *State) resource(ctx *hcl.EvalContext, vr SchemaValidator, file *hcl.Fil
if err != nil {
return nil, err
}
resource, err := s.toResource(ctx, vr, blk, []string{blk.Type}, dec.children[blk.Type])
resource, err := s.toResource(ctx, opts, blk, []string{blk.Type}, dec.children[blk.Type])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -576,7 +604,7 @@ func (s *State) mayScopeContext(ctx *hcl.EvalContext, scope []string) *hcl.EvalC
return nctx.NewChild()
}

func (s *State) toAttrs(ctx *hcl.EvalContext, vr SchemaValidator, hclAttrs hclsyntax.Attributes, scope []string) ([]*Attr, error) {
func (s *State) toAttrs(ctx *hcl.EvalContext, opts *EvalOptions, hclAttrs hclsyntax.Attributes, scope []string) ([]*Attr, error) {
attrs := make([]*Attr, 0, len(hclAttrs))
for _, hclAttr := range hclAttrs {
var (
Expand All @@ -591,11 +619,11 @@ func (s *State) toAttrs(ctx *hcl.EvalContext, vr SchemaValidator, hclAttrs hclsy
if value.IsNull() {
continue
}
if err := vr.ValidateAttribute(ctx, hclAttr, value); err != nil {
if err := opts.Validator.ValidateAttribute(ctx, hclAttr, value); err != nil {
return nil, err
}
at := &Attr{K: hclAttr.Name}
if s.config.withPos {
if s.config.withPos || opts.RecordPos {
at.SetRange(&hclAttr.SrcRange)
}
switch t := value.Type(); {
Expand Down Expand Up @@ -703,13 +731,13 @@ func isOneRef(v cty.Value) bool {
return t.IsObjectType() && t.HasAttribute("__ref")
}

func (s *State) toResource(ctx *hcl.EvalContext, vr SchemaValidator, block *hclsyntax.Block, scope []string, dec *blockDef) (spec *Resource, err error) {
closeScope, err := vr.ValidateBlock(ctx, block)
func (s *State) toResource(ctx *hcl.EvalContext, opts *EvalOptions, block *hclsyntax.Block, scope []string, dec *blockDef) (spec *Resource, err error) {
closeScope, err := opts.Validator.ValidateBlock(ctx, block)
if err != nil {
return nil, err
}
spec = &Resource{Type: block.Type}
if s.config.withPos {
if s.config.withPos || opts.RecordPos {
spec.SetRange(&block.TypeRange)
}
switch len(block.Labels) {
Expand All @@ -723,7 +751,7 @@ func (s *State) toResource(ctx *hcl.EvalContext, vr SchemaValidator, block *hcls
return nil, fmt.Errorf("too many labels for block: %s", block.Labels)
}
ctx = s.mayScopeContext(ctx, scope)
attrs, err := s.toAttrs(ctx, vr, block.Body.Attributes, scope)
attrs, err := s.toAttrs(ctx, opts, block.Body.Attributes, scope)
if err != nil {
return nil, err
}
Expand All @@ -734,7 +762,7 @@ func (s *State) toResource(ctx *hcl.EvalContext, vr SchemaValidator, block *hcls
if err != nil {
return nil, err
}
r, err := s.toResource(ctx, vr, blk, append(scope, blk.Type), cdec)
r, err := s.toResource(ctx, opts, blk, append(scope, blk.Type), cdec)
if err != nil {
return nil, err
}
Expand All @@ -749,8 +777,10 @@ func (s *State) toResource(ctx *hcl.EvalContext, vr SchemaValidator, block *hcls
// encode the given *schemahcl.Resource into a byte slice containing an Atlas HCL
// document representing it.
func (s *State) encode(r *Resource) ([]byte, error) {
f := hclwrite.NewFile()
body := f.Body()
var (
f = hclwrite.NewFile()
body = f.Body()
)
// If the resource has a Type then it is rendered as an HCL block.
if r.Type != "" {
blk := body.AppendNewBlock(r.Type, labels(r))
Expand Down
2 changes: 1 addition & 1 deletion sql/postgres/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func init() {
sqlclient.OpenerFunc(opener),
sqlclient.RegisterDriverOpener(Open),
sqlclient.RegisterFlavours("postgresql"),
sqlclient.RegisterCodec(MarshalHCL, EvalHCL),
sqlclient.RegisterCodec(codec, codec),
sqlclient.RegisterURLParser(parser{}),
)
}
Expand Down
69 changes: 39 additions & 30 deletions sql/postgres/sqlspec_oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,22 @@ func init() {
schemahcl.Register("event_trigger", &eventTrigger{})
}

// evalSpec evaluates an Atlas DDL document into v using the input.
func evalSpec(p *hclparse.Parser, v any, input map[string]cty.Value) error {
// Codec for schemahcl.
type Codec struct {
State *schemahcl.State
}

// Eval evaluates an Atlas DDL document into v using the input.
func (c *Codec) Eval(p *hclparse.Parser, v any, input map[string]cty.Value) error {
return c.EvalOptions(p, v, &schemahcl.EvalOptions{Variables: input})
}

// EvalOptions decodes the HCL with the given options.
func (c *Codec) EvalOptions(p *hclparse.Parser, v any, opts *schemahcl.EvalOptions) error {
switch v := v.(type) {
case *schema.Realm:
var d doc
if err := hclState.Eval(p, &d, input); err != nil {
if err := c.State.EvalOptions(p, &d, opts); err != nil {
return err
}
if err := specutil.Scan(v, d.ScanDoc(), scanFuncs); err != nil {
Expand Down Expand Up @@ -242,7 +252,7 @@ func evalSpec(p *hclparse.Parser, v any, input map[string]cty.Value) error {
}
case *schema.Schema:
var d doc
if err := hclState.Eval(p, &d, input); err != nil {
if err := c.State.EvalOptions(p, &d, opts); err != nil {
return err
}
if len(d.Schemas) != 1 {
Expand Down Expand Up @@ -278,7 +288,7 @@ func evalSpec(p *hclparse.Parser, v any, input map[string]cty.Value) error {
}

// MarshalSpec marshals v into an Atlas DDL document using a schemahcl.Marshaler.
func MarshalSpec(v any, marshaler schemahcl.Marshaler) ([]byte, error) {
func (c *Codec) MarshalSpec(v any) ([]byte, error) {
var (
d doc
ts []*schema.Trigger
Expand Down Expand Up @@ -345,38 +355,37 @@ func MarshalSpec(v any, marshaler schemahcl.Marshaler) ([]byte, error) {
if err := triggersSpec(ts, &d); err != nil {
return nil, err
}
return marshaler.MarshalSpec(&d)
return c.State.MarshalSpec(&d)
}

var (
hclState = schemahcl.New(append(specOptions,
schemahcl.WithTypes("table.column.type", TypeRegistry.Specs()),
schemahcl.WithTypes("view.column.type", TypeRegistry.Specs()),
schemahcl.WithTypes("materialized.column.type", TypeRegistry.Specs()),
schemahcl.WithScopedEnums("view.check_option", schema.ViewCheckOptionLocal, schema.ViewCheckOptionCascaded),
schemahcl.WithScopedEnums("table.index.type", IndexTypeBTree, IndexTypeBRIN, IndexTypeHash, IndexTypeGIN, IndexTypeGiST, "GiST", IndexTypeSPGiST, "SPGiST"),
schemahcl.WithScopedEnums("table.partition.type", PartitionTypeRange, PartitionTypeList, PartitionTypeHash),
schemahcl.WithScopedEnums("table.column.identity.generated", GeneratedTypeAlways, GeneratedTypeByDefault),
schemahcl.WithScopedEnums("table.column.as.type", "STORED"),
schemahcl.WithScopedEnums("table.foreign_key.on_update", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.foreign_key.on_delete", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.index.on.ops", func() (ops []string) {
for _, op := range postgresop.Classes {
ops = append(ops, op.Name)
}
return ops
}()...))...,
)
codec = &Codec{
State: schemahcl.New(append(specOptions,
schemahcl.WithTypes("table.column.type", TypeRegistry.Specs()),
schemahcl.WithTypes("view.column.type", TypeRegistry.Specs()),
schemahcl.WithTypes("materialized.column.type", TypeRegistry.Specs()),
schemahcl.WithScopedEnums("view.check_option", schema.ViewCheckOptionLocal, schema.ViewCheckOptionCascaded),
schemahcl.WithScopedEnums("table.index.type", IndexTypeBTree, IndexTypeBRIN, IndexTypeHash, IndexTypeGIN, IndexTypeGiST, "GiST", IndexTypeSPGiST, "SPGiST"),
schemahcl.WithScopedEnums("table.partition.type", PartitionTypeRange, PartitionTypeList, PartitionTypeHash),
schemahcl.WithScopedEnums("table.column.identity.generated", GeneratedTypeAlways, GeneratedTypeByDefault),
schemahcl.WithScopedEnums("table.column.as.type", "STORED"),
schemahcl.WithScopedEnums("table.foreign_key.on_update", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.foreign_key.on_delete", specutil.ReferenceVars...),
schemahcl.WithScopedEnums("table.index.on.ops", func() (ops []string) {
for _, op := range postgresop.Classes {
ops = append(ops, op.Name)
}
return ops
}()...))...,
),
}
// MarshalHCL marshals v into an Atlas HCL DDL document.
MarshalHCL = schemahcl.MarshalerFunc(func(v any) ([]byte, error) {
return MarshalSpec(v, hclState)
})
MarshalHCL = schemahcl.MarshalerFunc(codec.MarshalSpec)
// EvalHCL implements the schemahcl.Evaluator interface.
EvalHCL = schemahcl.EvalFunc(evalSpec)

EvalHCL = schemahcl.EvalFunc(codec.Eval)
// EvalHCLBytes is a helper that evaluates an HCL document from a byte slice instead
// of from an hclparse.Parser instance.
EvalHCLBytes = specutil.HCLBytesFunc(EvalHCL)
EvalHCLBytes = specutil.HCLBytesFunc(codec)
)

// convertTable converts a sqlspec.Table to a schema.Table. Table conversion is done without converting
Expand Down
22 changes: 11 additions & 11 deletions sql/postgres/sqlspec_oss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ func TestMarshalSpec_IndexPredicate(t *testing.T) {
},
},
}
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -952,7 +952,7 @@ func TestMarshalSpec_IndexNullsDistinct(t *testing.T) {
AddAttrs(&IndexNullsDistinct{V: false}),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.public
Expand Down Expand Up @@ -1000,7 +1000,7 @@ func TestMarshalSpec_IndexNullsLastFirst(t *testing.T) {
AddParts(schema.NewColumnPart(schema.NewColumn("c")).SetDesc(false).AddAttrs(&IndexColumnProperty{NullsFirst: true})),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.public
Expand Down Expand Up @@ -1040,7 +1040,7 @@ schema "public" {
err = EvalHCLBytes([]byte(expected), &got, nil)
require.NoError(t, err)

buf, err = MarshalSpec(s, hclState)
buf, err = MarshalHCL(s)
require.NoError(t, err)
require.EqualValues(t, expected, string(buf))
}
Expand Down Expand Up @@ -1076,7 +1076,7 @@ func TestMarshalSpec_BRINIndex(t *testing.T) {
},
},
}
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -1185,7 +1185,7 @@ func TestMarshalSpec_IndexOpClass(t *testing.T) {
},
},
}
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -1346,7 +1346,7 @@ func TestMarshalSpec_IndexInclude(t *testing.T) {
},
},
}
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -1382,7 +1382,7 @@ func TestMarshalSpec_PrimaryKey(t *testing.T) {
schema.NewPrimaryKey(s.Tables[0].Columns[:1]...).
AddAttrs(&IndexInclude{Columns: s.Tables[0].Columns[1:]}),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -1452,7 +1452,7 @@ func TestMarshalSpec_GeneratedColumn(t *testing.T) {
SetGeneratedExpr(&schema.GeneratedExpr{Expr: "c3 * c4", Type: "STORED"}),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "users" {
schema = schema.test
Expand Down Expand Up @@ -1553,7 +1553,7 @@ func TestMarshalSpec_Enum(t *testing.T) {
SetType(typeE),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "account" {
schema = schema.test
Expand Down Expand Up @@ -1599,7 +1599,7 @@ func TestMarshalSpec_TimePrecision(t *testing.T) {
schema.NewTimeColumn("t_timestamptz", TypeTimestampTZ, schema.TimePrecision(2)),
),
)
buf, err := MarshalSpec(s, hclState)
buf, err := MarshalHCL(s)
require.NoError(t, err)
const expected = `table "times" {
schema = schema.test
Expand Down

0 comments on commit c51643e

Please sign in to comment.