Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql,tree: improve function resolution efficiency #89317

Merged
merged 2 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/sql/catalog/resolver/resolver_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func BenchmarkResolveExistingObject(b *testing.B) {
require.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
desc, _, err := resolver.ResolveExistingObject(ctx, rs, uon, tc.flags)
desc, _, err := resolver.ResolveExistingObject(ctx, rs, &uon, tc.flags)
require.NoError(b, err)
require.NotNil(b, desc)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/parser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -13665,7 +13665,6 @@ typed_literal:
// types, return an unimplemented error message.
var typ tree.ResolvableTypeReference
var ok bool
var err error
var unimp int
typ, ok, unimp = types.TypeForNonKeywordTypeName(typName)
if !ok {
Expand All @@ -13674,8 +13673,9 @@ typed_literal:
// In this case, we don't think this type is one of our
// known unsupported types, so make a type reference for it.
aIdx := sqllex.(*lexer).NewAnnotation()
typ, err = name.ToUnresolvedObjectName(aIdx)
un, err := name.ToUnresolvedObjectName(aIdx)
if err != nil { return setErr(sqllex, err) }
typ = &un
case -1:
return unimplemented(sqllex, "type name " + typName)
default:
Expand All @@ -13688,7 +13688,7 @@ typed_literal:
aIdx := sqllex.(*lexer).NewAnnotation()
res, err := name.ToUnresolvedObjectName(aIdx)
if err != nil { return setErr(sqllex, err) }
$$.val = &tree.CastExpr{Expr: tree.NewStrVal($2), Type: res, SyntaxMode: tree.CastPrepend}
$$.val = &tree.CastExpr{Expr: tree.NewStrVal($2), Type: &res, SyntaxMode: tree.CastPrepend}
}
}
| const_typename SCONST
Expand Down
19 changes: 10 additions & 9 deletions pkg/sql/schema_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ func (sr *schemaResolver) canResolveDescUnderSchema(
case catalog.SchemaUserDefined:
return sr.authAccessor.CheckPrivilegeForUser(ctx, scDesc, privilege.USAGE, sr.sessionDataStack.Top().User())
default:
panic(errors.AssertionFailedf("unknown schema kind %d", kind))
forLog := kind // prevents kind from escaping
panic(errors.AssertionFailedf("unknown schema kind %d", forLog))
}
}

Expand Down Expand Up @@ -400,23 +401,23 @@ func (sr *schemaResolver) ResolveFunction(
sc := prefix.Schema
udfDef, _ = sc.GetResolvedFuncDefinition(fn.Object())
} else {
if err := path.IterateSearchPath(func(schema string) error {
for i, n := 0, path.NumElements(); i < n; i++ {
schema := path.GetSchema(i)
found, prefix, err := sr.LookupSchema(ctx, sr.CurrentDatabase(), schema)
if err != nil {
return err
return nil, err
}
if !found {
return nil
continue
}
curUdfDef, found := prefix.Schema.GetResolvedFuncDefinition(fn.Object())
if !found {
return nil
continue
}

udfDef, err = udfDef.MergeWith(curUdfDef)
return err
}); err != nil {
return nil, err
if err != nil {
return nil, err
}
}
}

Expand Down
1 change: 0 additions & 1 deletion pkg/sql/sem/tree/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ go_library(
"//pkg/util/encoding",
"//pkg/util/errorutil/unimplemented",
"//pkg/util/ipaddr",
"//pkg/util/iterutil",
"//pkg/util/json",
"//pkg/util/pretty",
"//pkg/util/stringencoding",
Expand Down
27 changes: 10 additions & 17 deletions pkg/sql/sem/tree/function_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/iterutil"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)
Expand Down Expand Up @@ -292,15 +291,10 @@ func (fd *ResolvedFunctionDefinition) MatchOverload(
if explicitSchema != "" {
findMatches(explicitSchema)
} else {
err := searchPath.IterateSearchPath(func(schema string) error {
findMatches(schema)
if found {
return iterutil.StopIteration()
for i, n := 0, searchPath.NumElements(); i < n; i++ {
if findMatches(searchPath.GetSchema(i)); found {
break
}
return nil
})
if err != nil {
return QualifiedOverload{}, err
}
}

Expand Down Expand Up @@ -386,14 +380,15 @@ func QualifyBuiltinFunctionDefinition(
// GetBuiltinFuncDefinitionOrFail is similar to GetBuiltinFuncDefinition but
// returns an error if function is not found.
func GetBuiltinFuncDefinitionOrFail(
fName *FunctionName, searchPath SearchPath,
fName FunctionName, searchPath SearchPath,
) (*ResolvedFunctionDefinition, error) {
def, err := GetBuiltinFuncDefinition(fName, searchPath)
if err != nil {
return nil, err
}
if def == nil {
return nil, errors.Wrapf(ErrFunctionUndefined, "unknown function: %s()", ErrString(fName))
forError := fName // prevent fName from escaping
return nil, errors.Wrapf(ErrFunctionUndefined, "unknown function: %s()", ErrString(&forError))
}
return def, nil
}
Expand All @@ -409,7 +404,7 @@ func GetBuiltinFuncDefinitionOrFail(
// error is still checked and return from the function signature just in case
// we change the iterating function in the future.
func GetBuiltinFuncDefinition(
fName *FunctionName, searchPath SearchPath,
fName FunctionName, searchPath SearchPath,
) (*ResolvedFunctionDefinition, error) {
if fName.ExplicitSchema {
return ResolvedBuiltinFuncDefs[fName.Schema()+"."+fName.Object()], nil
Expand All @@ -429,15 +424,13 @@ func GetBuiltinFuncDefinition(

// If not in pg_catalog, go through search path.
var resolvedDef *ResolvedFunctionDefinition
if err := searchPath.IterateSearchPath(func(schema string) error {
for i, n := 0, searchPath.NumElements(); i < n; i++ {
schema := searchPath.GetSchema(i)
fullName := schema + "." + fName.Object()
if def, ok := ResolvedBuiltinFuncDefs[fullName]; ok {
resolvedDef = def
return iterutil.StopIteration()
break
}
return nil
}); err != nil {
return nil, err
}

return resolvedDef, nil
Expand Down
13 changes: 6 additions & 7 deletions pkg/sql/sem/tree/name_part.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,23 +218,22 @@ func MakeUnresolvedName(args ...string) UnresolvedName {
}

// ToUnresolvedObjectName converts an UnresolvedName to an UnresolvedObjectName.
func (u *UnresolvedName) ToUnresolvedObjectName(idx AnnotationIdx) (*UnresolvedObjectName, error) {
func (u *UnresolvedName) ToUnresolvedObjectName(idx AnnotationIdx) (UnresolvedObjectName, error) {
if u.NumParts == 4 {
return nil, pgerror.Newf(pgcode.Syntax, "improper qualified name (too many dotted names): %s", u)
return UnresolvedObjectName{}, pgerror.Newf(pgcode.Syntax, "improper qualified name (too many dotted names): %s", u)
}
return NewUnresolvedObjectName(
return MakeUnresolvedObjectName(
u.NumParts,
[3]string{u.Parts[0], u.Parts[1], u.Parts[2]},
idx,
)
}

// ToFunctionName converts an UnresolvedName to a FunctionName.
func (u *UnresolvedName) ToFunctionName() (*FunctionName, error) {
func (u *UnresolvedName) ToFunctionName() (FunctionName, error) {
un, err := u.ToUnresolvedObjectName(NoAnnotation)
if err != nil {
return nil, errors.Newf("invalid function name: %s", u.String())
return FunctionName{}, errors.Newf("invalid function name: %s", u.String())
}
fn := un.ToFunctionName()
return &fn, nil
return un.ToFunctionName(), nil
}
15 changes: 7 additions & 8 deletions pkg/sql/sem/tree/name_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,21 @@ type QualifiedNameResolver interface {
// SearchPath encapsulates the ordered list of schemas in the current database
// to search during name resolution.
type SearchPath interface {
// NumElements returns the number of elements in the SearchPath.
NumElements() int

// IterateSearchPath calls the passed function for every element of the
// SearchPath in order. If an error is returned, iteration stops. If the
// error is iterutil.StopIteration, no error will be returned from the
// method.
IterateSearchPath(func(schema string) error) error
// GetSchema returns the schema at the ord offset in the SearchPath.
// Note that it will return the empty string if the ordinal is out of range.
GetSchema(ord int) string
}

// EmptySearchPath is a SearchPath with no members.
var EmptySearchPath SearchPath = emptySearchPath{}

type emptySearchPath struct{}

func (emptySearchPath) IterateSearchPath(func(string) error) error {
return nil
}
func (emptySearchPath) NumElements() int { return 0 }
func (emptySearchPath) GetSchema(i int) string { return "" }

func newInvColRef(n *UnresolvedName) error {
return pgerror.NewWithDepthf(1, pgcode.InvalidColumnReference,
Expand Down
20 changes: 17 additions & 3 deletions pkg/sql/sem/tree/object_name.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,26 @@ func (*UnresolvedObjectName) tableExpr() {}
func NewUnresolvedObjectName(
numParts int, parts [3]string, annotationIdx AnnotationIdx,
) (*UnresolvedObjectName, error) {
u := &UnresolvedObjectName{
n, err := MakeUnresolvedObjectName(numParts, parts, annotationIdx)
if err != nil {
return nil, err
}
return &n, nil
}

// MakeUnresolvedObjectName creates an unresolved object name, verifying that it
// is well-formed.
func MakeUnresolvedObjectName(
numParts int, parts [3]string, annotationIdx AnnotationIdx,
) (UnresolvedObjectName, error) {
u := UnresolvedObjectName{
NumParts: numParts,
Parts: parts,
AnnotatedNode: AnnotatedNode{AnnIdx: annotationIdx},
}
if u.NumParts < 1 {
return nil, newInvTableNameError(u)
forErr := u // prevents u from escaping
return UnresolvedObjectName{}, newInvTableNameError(&forErr)
}

// Check that all the parts specified are not empty.
Expand All @@ -154,7 +167,8 @@ func NewUnresolvedObjectName(
}
for i := 0; i < lastCheck; i++ {
if len(u.Parts[i]) == 0 {
return nil, newInvTableNameError(u)
forErr := u // prevents u from escaping
return UnresolvedObjectName{}, newInvTableNameError(&forErr)
}
}
return u, nil
Expand Down
12 changes: 4 additions & 8 deletions pkg/sql/sem/tree/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/duration"
"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
"github.com/cockroachdb/cockroach/pkg/util/iterutil"
"github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
Expand Down Expand Up @@ -2989,23 +2988,20 @@ func getMostSignificantOverload(

found := false
var ret QualifiedOverload
err := searchPath.IterateSearchPath(func(schema string) error {
for i, n := 0, searchPath.NumElements(); i < n; i++ {
schema := searchPath.GetSchema(i)
for i := range overloads {
if overloads[i].(QualifiedOverload).Schema == schema {
if found {
return ambiguousError()
return QualifiedOverload{}, ambiguousError()
}
found = true
ret = overloads[i].(QualifiedOverload)
}
}
if found {
return iterutil.StopIteration()
break
}
return nil
})
if err != nil {
return QualifiedOverload{}, err
}
if !found {
// This should never happen. Otherwise, it means we get function from a
Expand Down
1 change: 0 additions & 1 deletion pkg/sql/sessiondata/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ go_library(
"//pkg/sql/sem/catconstants",
"//pkg/sql/sessiondatapb",
"//pkg/util/duration",
"//pkg/util/iterutil",
"//pkg/util/syncutil",
"//pkg/util/timeutil",
"//pkg/util/timeutil/pgdate",
Expand Down
31 changes: 23 additions & 8 deletions pkg/sql/sessiondata/search_path.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants"
"github.com/cockroachdb/cockroach/pkg/util/iterutil"
)

// DefaultSearchPath is the search path used by virgin sessions.
Expand Down Expand Up @@ -292,15 +291,31 @@ func (iter *SearchPathIter) Next() (path string, ok bool) {
return "", false
}

// IterateSearchPath iterates the search path. If a non-nil error is
// returned, iteration is stopped. If iterutils.StopIteration() is returned
// from the iteration function, a nil error is returned to the caller.
func (s *SearchPath) IterateSearchPath(f func(schema string) error) error {
// NumElements returns the number of elements in the search path.
func (s *SearchPath) NumElements() int {
// TODO(ajwerner): Refactor this so that we don't need to do an O(N)
// operation to find the number of elements. In practice it doesn't matter
// much because search paths tend to be short.
iter := s.Iter()
var i int
for _, ok := iter.Next(); ok; _, ok = iter.Next() {
i++
}
return i
}

// GetSchema returns the ith schema element if it is in range.
func (s *SearchPath) GetSchema(ord int) string {
// TODO(ajwerner): Refactor this so that we don't need to do an O(n)
// operation to find the nth element. In practice it doesn't matter
// much because search paths tend to be short.
iter := s.Iter()
var i int
for schema, ok := iter.Next(); ok; schema, ok = iter.Next() {
if err := f(schema); err != nil {
return iterutil.Map(err)
if ord == i {
return schema
}
i++
}
return nil
return ""
}