Skip to content

Commit

Permalink
ast: Adding rego.metadata.* built-in functions (open-policy-agent#4537)
Browse files Browse the repository at this point in the history
New functions:
* rego.metadata.chain(): returns the chain of metadata, starting from the active rule, going outward
* rego.metadata.rule(): returns the metadata for the active rule

Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling authored Apr 13, 2022
1 parent 05dad98 commit c622662
Show file tree
Hide file tree
Showing 15 changed files with 1,709 additions and 256 deletions.
253 changes: 227 additions & 26 deletions ast/annotations.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ type (
}

AnnotationsRef struct {
Location *Location `json:"location"`
Path Ref `json:"path"`
Location *Location `json:"location"` // The location of the node the annotations are applied to
Path Ref `json:"path"` // The path of the node the annotations are applied to
Annotations *Annotations `json:"annotations,omitempty"`
node Node
node Node // The node the annotations are applied to
}
)

Expand All @@ -90,10 +90,22 @@ func (a *Annotations) SetLoc(l *Location) {
a.Location = l
}

// Compare returns an integer indicating if s is less than, equal to, or greater
// Compare returns an integer indicating if a is less than, equal to, or greater
// than other.
func (a *Annotations) Compare(other *Annotations) int {

if a == nil && other == nil {
return 0
}

if a == nil {
return -1
}

if other == nil {
return 1
}

if cmp := scopeCompare(a.Scope, other.Scope); cmp != 0 {
return cmp
}
Expand Down Expand Up @@ -141,6 +153,15 @@ func (a *Annotations) GetTargetPath() Ref {
}
}

func NewAnnotationsRef(a *Annotations) *AnnotationsRef {
return &AnnotationsRef{
Location: a.node.Loc(),
Path: a.GetTargetPath(),
Annotations: a,
node: a.node,
}
}

func (ar *AnnotationsRef) GetPackage() *Package {
switch n := ar.node.(type) {
case *Package:
Expand Down Expand Up @@ -287,6 +308,147 @@ func (a *Annotations) Copy(node Node) *Annotations {
return &cpy
}

// toObject constructs an AST Object from a.
func (a *Annotations) toObject() (*Object, *Error) {
obj := NewObject()

if a == nil {
return &obj, nil
}

if len(a.Scope) > 0 {
obj.Insert(StringTerm("scope"), StringTerm(a.Scope))
}

if len(a.Title) > 0 {
obj.Insert(StringTerm("title"), StringTerm(a.Title))
}

if len(a.Description) > 0 {
obj.Insert(StringTerm("description"), StringTerm(a.Description))
}

if len(a.Organizations) > 0 {
orgs := make([]*Term, 0, len(a.Organizations))
for _, org := range a.Organizations {
orgs = append(orgs, StringTerm(org))
}
obj.Insert(StringTerm("organizations"), ArrayTerm(orgs...))
}

if len(a.RelatedResources) > 0 {
rrs := make([]*Term, 0, len(a.RelatedResources))
for _, rr := range a.RelatedResources {
rrObj := NewObject(Item(StringTerm("ref"), StringTerm(rr.Ref.String())))
if len(rr.Description) > 0 {
rrObj.Insert(StringTerm("description"), StringTerm(rr.Description))
}
rrs = append(rrs, NewTerm(rrObj))
}
obj.Insert(StringTerm("related_resources"), ArrayTerm(rrs...))
}

if len(a.Authors) > 0 {
as := make([]*Term, 0, len(a.Authors))
for _, author := range a.Authors {
aObj := NewObject()
if len(author.Name) > 0 {
aObj.Insert(StringTerm("name"), StringTerm(author.Name))
}
if len(author.Email) > 0 {
aObj.Insert(StringTerm("email"), StringTerm(author.Email))
}
as = append(as, NewTerm(aObj))
}
obj.Insert(StringTerm("authors"), ArrayTerm(as...))
}

if len(a.Schemas) > 0 {
ss := make([]*Term, 0, len(a.Schemas))
for _, s := range a.Schemas {
sObj := NewObject()
if len(s.Path) > 0 {
sObj.Insert(StringTerm("path"), NewTerm(s.Path.toArray()))
}
if len(s.Schema) > 0 {
sObj.Insert(StringTerm("schema"), NewTerm(s.Schema.toArray()))
}
if s.Definition != nil {
def, err := InterfaceToValue(s.Definition)
if err != nil {
return nil, NewError(CompileErr, a.Location, "invalid definition in schema annotation: %s", err.Error())
}
sObj.Insert(StringTerm("definition"), NewTerm(def))
}
ss = append(ss, NewTerm(sObj))
}
obj.Insert(StringTerm("schemas"), ArrayTerm(ss...))
}

if len(a.Custom) > 0 {
c, err := InterfaceToValue(a.Custom)
if err != nil {
return nil, NewError(CompileErr, a.Location, "invalid custom annotation %s", err.Error())
}
obj.Insert(StringTerm("custom"), NewTerm(c))
}

return &obj, nil
}

func attachAnnotationsNodes(mod *Module) Errors {
var errs Errors

// Find first non-annotation statement following each annotation and attach
// the annotation to that statement.
for _, a := range mod.Annotations {
for _, stmt := range mod.stmts {
_, ok := stmt.(*Annotations)
if !ok {
if stmt.Loc().Row > a.Location.Row {
a.node = stmt
break
}
}
}

if a.Scope == "" {
switch a.node.(type) {
case *Rule:
a.Scope = annotationScopeRule
case *Package:
a.Scope = annotationScopePackage
case *Import:
a.Scope = annotationScopeImport
}
}

if err := validateAnnotationScopeAttachment(a); err != nil {
errs = append(errs, err)
}
}

return errs
}

func validateAnnotationScopeAttachment(a *Annotations) *Error {

switch a.Scope {
case annotationScopeRule, annotationScopeDocument:
if _, ok := a.node.(*Rule); ok {
return nil
}
return newScopeAttachmentErr(a, "rule")
case annotationScopePackage, annotationScopeSubpackages:
if _, ok := a.node.(*Package); ok {
return nil
}
return newScopeAttachmentErr(a, "package")
}

return NewError(ParseErr, a.Loc(), "invalid annotation scope '%v'", a.Scope)
}

// Copy returns a deep copy of a.
func (a *AuthorAnnotation) Copy() *AuthorAnnotation {
cpy := *a
Expand Down Expand Up @@ -481,32 +643,22 @@ func (as *AnnotationSet) Flatten() []*AnnotationsRef {

refs = as.byPath.flatten(refs)

for p, a := range as.byPackage {
refs = append(refs, &AnnotationsRef{
Location: p.Location,
Path: p.Path,
Annotations: a,
node: p,
})
for _, a := range as.byPackage {
refs = append(refs, NewAnnotationsRef(a))
}

for r, as := range as.byRule {
for _, as := range as.byRule {
for _, a := range as {
refs = append(refs, &AnnotationsRef{
Location: r.Location,
Path: r.Path(),
Annotations: a,
node: r,
})
refs = append(refs, NewAnnotationsRef(a))
}
}

// Sort by path, then location, for stable output
// Sort by path, then annotation location, for stable output
sort.SliceStable(refs, func(i, j int) bool {
if refs[i].Path.Compare(refs[j].Path) < 0 {
return true
}
if refs[i].Location.Compare(refs[j].Location) < 0 {
if refs[i].Annotations.Location.Compare(refs[j].Annotations.Location) < 0 {
return true
}
return false
Expand All @@ -515,6 +667,59 @@ func (as *AnnotationSet) Flatten() []*AnnotationsRef {
return refs
}

// Chain returns the chain of annotations leading up to the given rule.
// The returned slice is ordered as follows
// 0. Entries for the given rule, ordered from the METADATA block declared immediately above the rule, to the block declared farthest away (always at least one entry)
// 1. The 'document' scope entry, if any
// 2. The 'package' scope entry, if any
// 3. Entries for the 'subpackages' scope, if any; ordered from the closest package path to the fartest. E.g.: 'do.re.mi', 'do.re', 'do'
// The returned slice is guaranteed to always contain at least one entry, corresponding to the given rule.
func (as *AnnotationSet) Chain(rule *Rule) []*AnnotationsRef {
var refs []*AnnotationsRef

ruleAnnots := as.GetRuleScope(rule)

if len(ruleAnnots) >= 1 {
for _, a := range ruleAnnots {
refs = append(refs, NewAnnotationsRef(a))
}
} else {
// Make sure there is always a leading entry representing the passed rule, even if it has no annotations
refs = append(refs, &AnnotationsRef{
Location: rule.Location,
Path: rule.Path(),
node: rule,
})
}

if len(refs) > 1 {
// Sort by annotation location; chain must start with annotations declared closest to rule, then going outward
sort.SliceStable(refs, func(i, j int) bool {
return refs[i].Annotations.Location.Compare(refs[j].Annotations.Location) > 0
})
}

docAnnots := as.GetDocumentScope(rule.Path())
if docAnnots != nil {
refs = append(refs, NewAnnotationsRef(docAnnots))
}

pkg := rule.Module.Package
pkgAnnots := as.GetPackageScope(pkg)
if pkgAnnots != nil {
refs = append(refs, NewAnnotationsRef(pkgAnnots))
}

subPkgAnnots := as.GetSubpackagesScope(pkg.Path)
// We need to reverse the order, as subPkgAnnots ordering will start at the root,
// whereas we want to end at the root.
for i := len(subPkgAnnots) - 1; i >= 0; i-- {
refs = append(refs, NewAnnotationsRef(subPkgAnnots[i]))
}

return refs
}

func newAnnotationTree() *annotationTreeNode {
return &annotationTreeNode{
Value: nil,
Expand Down Expand Up @@ -550,6 +755,7 @@ func (t *annotationTreeNode) get(path Ref) *annotationTreeNode {
return node
}

// ancestors returns a slice of annotations in ascending order, starting with the root of ref; e.g.: 'root', 'root.foo', 'root.foo.bar'.
func (t *annotationTreeNode) ancestors(path Ref) (result []*Annotations) {
node := t
for _, k := range path {
Expand All @@ -570,12 +776,7 @@ func (t *annotationTreeNode) ancestors(path Ref) (result []*Annotations) {

func (t *annotationTreeNode) flatten(refs []*AnnotationsRef) []*AnnotationsRef {
if a := t.Value; a != nil {
refs = append(refs, &AnnotationsRef{
Location: a.Location,
Path: a.GetTargetPath(),
Annotations: a,
node: a.node,
})
refs = append(refs, NewAnnotationsRef(a))
}
for _, c := range t.Children {
refs = c.flatten(refs)
Expand Down
Loading

0 comments on commit c622662

Please sign in to comment.