Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format: Bracketing keyword ref elements in formatter output #7010

Merged
18 changes: 16 additions & 2 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ var Wildcard = &Term{Value: Var("_")}
var WildcardPrefix = "$"

// Keywords contains strings that map to language keywords.
var Keywords = KeywordsV0
var Keywords = KeywordsForRegoVersion(DefaultRegoVersion)

var KeywordsV0 = [...]string{
"not",
Expand Down Expand Up @@ -134,9 +134,23 @@ var KeywordsV1 = [...]string{
"every",
}

func KeywordsForRegoVersion(v RegoVersion) []string {
switch v {
case RegoV0:
return KeywordsV0[:]
case RegoV1, RegoV0CompatV1:
return KeywordsV1[:]
}
return nil
}

// IsKeyword returns true if s is a language keyword.
func IsKeyword(s string) bool {
for _, x := range Keywords {
return IsInKeywords(s, Keywords)
}

func IsInKeywords(s string, keywords []string) bool {
for _, x := range keywords {
if x == s {
return true
}
Expand Down
2 changes: 1 addition & 1 deletion compile/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1741,7 +1741,7 @@ contains := 2 {
import rego.v1
p if {
data.foo.contains = input.x
data.foo["contains"] = input.x
}
`,
`package foo
Expand Down
105 changes: 77 additions & 28 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,16 @@ type fmtOpts struct {
// than if they don't.
refHeads bool

regoV1 bool
regoV1 bool
futureKeywords []string
}

func (o fmtOpts) keywords() []string {
if o.regoV1 {
return ast.KeywordsV1[:]
}
kws := ast.KeywordsV0[:]
return append(kws, o.futureKeywords...)
}

func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
Expand Down Expand Up @@ -171,6 +180,10 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
}

case *ast.Import:
if kw, ok := future.WhichFutureKeyword(n); ok {
o.futureKeywords = append(o.futureKeywords, kw)
}

switch {
case isRegoV1Compatible(n):
o.contains = true
Expand Down Expand Up @@ -200,8 +213,9 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
})

w := &writer{
indent: "\t",
errs: make([]*ast.Error, 0),
indent: "\t",
errs: make([]*ast.Error, 0),
fmtOpts: o,
}

switch x := x.(type) {
Expand All @@ -219,18 +233,17 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
x.Imports = ensureFutureKeywordImport(x.Imports, kw)
}
}
w.writeModule(x, o)
w.writeModule(x)
case *ast.Package:
w.writePackage(x, nil)
case *ast.Import:
w.writeImports([]*ast.Import{x}, nil)
case *ast.Rule:
w.writeRule(x, false /* isElse */, o, nil)
w.writeRule(x, false /* isElse */, nil)
case *ast.Head:
w.writeHead(x,
false, // isDefault
false, // isExpandedConst
o,
nil)
case ast.Body:
w.writeBody(x, nil)
Expand Down Expand Up @@ -302,9 +315,10 @@ type writer struct {
beforeEnd *ast.Comment
delay bool
errs ast.Errors
fmtOpts fmtOpts
}

func (w *writer) writeModule(module *ast.Module, o fmtOpts) {
func (w *writer) writeModule(module *ast.Module) {
var pkg *ast.Package
var others []interface{}
var comments []*ast.Comment
Expand Down Expand Up @@ -342,7 +356,7 @@ func (w *writer) writeModule(module *ast.Module, o fmtOpts) {
imports, others = gatherImports(others)
comments = w.writeImports(imports, comments)
rules, others = gatherRules(others)
comments = w.writeRules(rules, o, comments)
comments = w.writeRules(rules, comments)
}

for i, c := range comments {
Expand All @@ -365,7 +379,15 @@ func (w *writer) writePackage(pkg *ast.Package, comments []*ast.Comment) []*ast.
comments = w.insertComments(comments, pkg.Location)

w.startLine()
w.write(pkg.String())

// Omit head as all packages have the DefaultRootDocument prepended at parse time.
path := make(ast.Ref, len(pkg.Path)-1)
path[0] = ast.VarTerm(string(pkg.Path[1].Value.(ast.String)))
copy(path[1:], pkg.Path[2:])

w.write("package ")
w.writeRef(path)

w.blankLine()

return comments
Expand All @@ -380,16 +402,16 @@ func (w *writer) writeComments(comments []*ast.Comment) {
}
}

func (w *writer) writeRules(rules []*ast.Rule, o fmtOpts, comments []*ast.Comment) []*ast.Comment {
func (w *writer) writeRules(rules []*ast.Rule, comments []*ast.Comment) []*ast.Comment {
for _, rule := range rules {
comments = w.insertComments(comments, rule.Location)
comments = w.writeRule(rule, false, o, comments)
comments = w.writeRule(rule, false, comments)
w.blankLine()
}
return comments
}

func (w *writer) writeRule(rule *ast.Rule, isElse bool, o fmtOpts, comments []*ast.Comment) []*ast.Comment {
func (w *writer) writeRule(rule *ast.Rule, isElse bool, comments []*ast.Comment) []*ast.Comment {
if rule == nil {
return comments
}
Expand All @@ -408,25 +430,25 @@ func (w *writer) writeRule(rule *ast.Rule, isElse bool, o fmtOpts, comments []*a
// pretend that the rule has no body in this case.
isExpandedConst := rule.Body.Equal(ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))) && rule.Else == nil

comments = w.writeHead(rule.Head, rule.Default, isExpandedConst, o, comments)
comments = w.writeHead(rule.Head, rule.Default, isExpandedConst, comments)

// this excludes partial sets UNLESS `contains` is used
partialSetException := o.contains || rule.Head.Value != nil
partialSetException := w.fmtOpts.contains || rule.Head.Value != nil

if len(rule.Body) == 0 || isExpandedConst {
w.endLine()
return comments
}

if (o.regoV1 || o.ifs) && partialSetException {
if (w.fmtOpts.regoV1 || w.fmtOpts.ifs) && partialSetException {
w.write(" if")
if len(rule.Body) == 1 {
if rule.Body[0].Location.Row == rule.Head.Location.Row {
w.write(" ")
comments = w.writeExpr(rule.Body[0], comments)
w.endLine()
if rule.Else != nil {
comments = w.writeElse(rule, o, comments)
comments = w.writeElse(rule, comments)
}
return comments
}
Expand Down Expand Up @@ -454,12 +476,12 @@ func (w *writer) writeRule(rule *ast.Rule, isElse bool, o fmtOpts, comments []*a
w.startLine()
w.write("}")
if rule.Else != nil {
comments = w.writeElse(rule, o, comments)
comments = w.writeElse(rule, comments)
}
return comments
}

func (w *writer) writeElse(rule *ast.Rule, o fmtOpts, comments []*ast.Comment) []*ast.Comment {
func (w *writer) writeElse(rule *ast.Rule, comments []*ast.Comment) []*ast.Comment {
// If there was nothing else on the line before the "else" starts
// then preserve this style of else block, otherwise it will be
// started as an "inline" else eg:
Expand Down Expand Up @@ -521,16 +543,16 @@ func (w *writer) writeElse(rule *ast.Rule, o fmtOpts, comments []*ast.Comment) [
rule.Else.Head.Value.Location = rule.Else.Head.Location
}

return w.writeRule(rule.Else, true, o, comments)
return w.writeRule(rule.Else, true, comments)
}

func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fmtOpts, comments []*ast.Comment) []*ast.Comment {
func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, comments []*ast.Comment) []*ast.Comment {
ref := head.Ref()
if head.Key != nil && head.Value == nil && !head.HasDynamicRef() {
ref = ref.GroundPrefix()
}
if o.refHeads || len(ref) == 1 {
w.write(ref.String())
if w.fmtOpts.refHeads || len(ref) == 1 {
w.writeRef(ref)
} else {
w.write(ref[0].String())
w.write("[")
Expand All @@ -548,7 +570,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fm
w.write(")")
}
if head.Key != nil {
if o.contains && head.Value == nil {
if w.fmtOpts.contains && head.Value == nil {
w.write(" contains ")
comments = w.writeTerm(head.Key, comments)
} else if head.Value == nil { // no `if` for p[x] notation
Expand All @@ -566,7 +588,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fm
// * a.b -> a contains "b"
// * a.b.c -> a.b.c := true
// * a.b.c.d -> a.b.c.d := true
isRegoV1RefConst := o.regoV1 && isExpandedConst && head.Key == nil && len(head.Args) == 0
isRegoV1RefConst := w.fmtOpts.regoV1 && isExpandedConst && head.Key == nil && len(head.Args) == 0

if head.Location == head.Value.Location &&
head.Name != "else" &&
Expand All @@ -578,7 +600,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fm
return comments
}

if head.Assign || o.regoV1 {
if head.Assign || w.fmtOpts.regoV1 {
// preserve assignment operator, and enforce it if formatting for Rego v1
w.write(" := ")
} else {
Expand Down Expand Up @@ -856,7 +878,7 @@ var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$")

func (w *writer) writeRefStringPath(s ast.String) {
str := string(s)
if varRegexp.MatchString(str) && !ast.IsKeyword(str) {
if varRegexp.MatchString(str) && !ast.IsInKeywords(str, w.fmtOpts.keywords()) {
w.write("." + str)
} else {
w.writeBracketed(s.String())
Expand Down Expand Up @@ -1067,7 +1089,7 @@ func (w *writer) writeImports(imports []*ast.Import, comments []*ast.Comment) []
})
for _, i := range group {
w.startLine()
w.write(i.String())
w.writeImport(i)
if c, ok := m[i]; ok {
w.write(" " + c.String())
}
Expand All @@ -1079,6 +1101,28 @@ func (w *writer) writeImports(imports []*ast.Import, comments []*ast.Comment) []
return comments
}

func (w *writer) writeImport(imp *ast.Import) {
path := imp.Path.Value.(ast.Ref)

buf := []string{"import"}

if _, ok := future.WhichFutureKeyword(imp); ok {
// We don't want to wrap future.keywords imports in parens, so we create a new writer that doesn't
w2 := writer{
buf: bytes.Buffer{},
}
w2.writeRef(path)
buf = append(buf, w2.buf.String())
} else {
buf = append(buf, path.String())
}

if len(imp.Alias) > 0 {
buf = append(buf, "as "+imp.Alias.String())
}
w.write(strings.Join(buf, " "))
}

type entryWriter func(interface{}, []*ast.Comment) []*ast.Comment

func (w *writer) writeIterable(elements []interface{}, last *ast.Location, close *ast.Location, comments []*ast.Comment, fn entryWriter) []*ast.Comment {
Expand Down Expand Up @@ -1505,7 +1549,12 @@ func ensureFutureKeywordImport(imps []*ast.Import, kw string) []*ast.Import {
}
}
imp := &ast.Import{
Path: ast.MustParseTerm("future.keywords." + kw),
// NOTE: This is a hack to not error on the ref containing a keyword already present in v1.
// A cleaner solution would be to instead allow refs to contain keyword terms.
// E.g. in v1, `import future.keywords["in"]` is valid, but `import future.keywords.in` is not
// as it contains a reserved keyword.
Path: ast.MustParseTerm("future.keywords[\"" + kw + "\"]"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding an example in the comment would be helpful.

//Path: ast.MustParseTerm("future.keywords." + kw),
}
imp.Location = defaultLocation(imp)
return append(imps, imp)
Expand Down
Loading