Skip to content

Commit

Permalink
chore: add some helpers (#1017)
Browse files Browse the repository at this point in the history
This adds a few helpers that will be useful in implementing #972. I started on that issue, and while I didn't get far, I figured these would be useful. @acamadeo will be taking on that issue now, so I want them to have the helpers.
  • Loading branch information
noahdietz authored Oct 5, 2022
1 parent 2743c39 commit a800f77
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 0 deletions.
19 changes: 19 additions & 0 deletions rules/internal/utils/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,22 @@ func GetResourceReference(f *desc.FieldDescriptor) *apb.ResourceReference {
}
return nil
}

// FindResource returns first resource of type matching the reference param.
// resource Type name being referenced. It looks within a given file and its
// depenedncies, it cannot search within the entire protobuf package.
// This is especially useful for resolving google.api.resource_reference
// annotations.
func FindResource(reference string, file *desc.FileDescriptor) *apb.ResourceDescriptor {
files := append(file.GetDependencies(), file)
for _, f := range files {
for _, m := range f.GetMessageTypes() {
if r := GetResource(m); r != nil {
if r.GetType() == reference {
return r
}
}
}
}
return nil
}
61 changes: 61 additions & 0 deletions rules/internal/utils/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,64 @@ func TestGetResourceReference(t *testing.T) {
}
})
}

func TestFindResource(t *testing.T) {
files := testutils.ParseProtoStrings(t, map[string]string{
"book.proto": `
syntax = "proto3";
package test;
import "google/api/resource.proto";
message Book {
option (google.api.resource) = {
type: "library.googleapis.com/Book"
pattern: "publishers/{publisher}/books/{book}"
};
string name = 1;
}
`,
"shelf.proto": `
syntax = "proto3";
package test;
import "book.proto";
import "google/api/resource.proto";
message Shelf {
option (google.api.resource) = {
type: "library.googleapis.com/Shelf"
pattern: "shelves/{shelf}"
};
string name = 1;
repeated Book books = 2;
}
`,
})

for _, tst := range []struct {
name, reference string
notFound bool
}{
{"local_reference", "library.googleapis.com/Shelf", false},
{"imported_reference", "library.googleapis.com/Book", false},
{"unresolvable", "foo.googleapis.com/Bar", true},
} {
t.Run(tst.name, func(t *testing.T) {
got := FindResource(tst.reference, files["shelf.proto"])

if tst.notFound && got != nil {
t.Fatalf("Expected to not find the resource, but found %q", got.GetType())
}

if !tst.notFound && got == nil {
t.Errorf("Got nil, expected %q", tst.reference)
} else if !tst.notFound && got.GetType() != tst.reference {
t.Errorf("Got %q, expected %q", got.GetType(), tst.reference)
}
})
}
}
24 changes: 24 additions & 0 deletions rules/internal/utils/find.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,27 @@ func GetRepeatedMessageFields(m *desc.MessageDescriptor) []*desc.FieldDescriptor

return fields
}

// FindFieldDotNotation returns a field descriptor from a given message that
// corresponds to the dot separated path e.g. "book.name". If the path is
// unresolable the method returns nil. This is especially useful for resolving
// path variables in google.api.http and nested fields in
// google.api.method_signature annotations.
func FindFieldDotNotation(msg *desc.MessageDescriptor, ref string) *desc.FieldDescriptor {
path := strings.Split(ref, ".")
for _, seg := range path {
field := msg.FindFieldByName(seg)
if field == nil {
return nil
}

if m := field.GetMessageType(); m != nil {
msg = m
continue
}

return field
}

return nil
}
47 changes: 47 additions & 0 deletions rules/internal/utils/find_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package utils

import (
"strings"
"testing"

"github.com/googleapis/api-linter/rules/internal/testutils"
Expand Down Expand Up @@ -43,3 +44,49 @@ func TestFindMessage(t *testing.T) {
t.Errorf("Got Sctoll message, expected nil.")
}
}

func TestFindFieldDotNotation(t *testing.T) {
file := testutils.ParseProto3String(t, `
package test;
message CreateBookRequest {
string parent = 1;
Book book = 2;
}
message Book {
string name = 1;
message PublishingInfo {
string publisher = 1;
int32 edition = 2;
}
PublishingInfo publishing_info = 2;
}
`)
msg := file.GetMessageTypes()[0]

for _, tst := range []struct {
name, path string
}{
{"top_level", "parent"},
{"nested", "book.name"},
{"double_nested", "book.publishing_info.publisher"},
} {
t.Run(tst.name, func(t *testing.T) {
split := strings.Split(tst.path, ".")
want := split[len(split)-1]

f := FindFieldDotNotation(msg, tst.path)

if f == nil {
t.Errorf("Got nil, expected %q field", want)
} else if got := f.GetName(); got != want {
t.Errorf("Got %q, expected %q", got, want)
}
})
}

}
7 changes: 7 additions & 0 deletions rules/internal/utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ import (
"google.golang.org/protobuf/proto"
)

// HasHTTPRules returns true when the given method descriptor is annotated with
// a google.api.http option.
func HasHTTPRules(m *desc.MethodDescriptor) bool {
got := proto.GetExtension(m.GetMethodOptions(), apb.E_Http).(*apb.HttpRule)
return got != nil
}

// GetHTTPRules returns a slice of HTTP rules for a given method descriptor.
//
// Note: This returns a slice -- it takes the google.api.http annotation,
Expand Down
32 changes: 32 additions & 0 deletions rules/internal/utils/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,35 @@ func TestGetVariables(t *testing.T) {
})
}
}

func TestHasHTTPRules(t *testing.T) {
for _, tst := range []struct {
name string
Annotation string
}{
{"has_rule", `option (google.api.http) = {get: "/v1/foos"};`},
{"no_rule", ""},
} {
t.Run(tst.name, func(t *testing.T) {
file := testutils.ParseProto3Tmpl(t, `
import "google/api/annotations.proto";
service Foo {
rpc ListFoos (ListFoosRequest) returns (ListFoosResponse) {
{{ .Annotation }}
}
}
message ListFoosRequest {}
message ListFoosResponse {}
`, tst)
want := tst.Annotation != ""

got := HasHTTPRules(file.GetServices()[0].GetMethods()[0])

if got != want {
t.Errorf("Got %v, expected %v", got, want)
}
})
}
}

0 comments on commit a800f77

Please sign in to comment.