Skip to content

Commit

Permalink
tree: improve type-checking for placeholders with ambiguous type
Browse files Browse the repository at this point in the history
The key fix is to change the typeCheckSplitExprs function so that it marks _all_
placeholder indexes. This then causes the existing type-checking logic
in typeCheckOverloadedExprs to check all placeholder expressions, rather
than just ones that don't have type hints.

Release note (bug fix): Prepared statements that use type hints can now
succeed type-checking in more cases when the placeholder type is
ambiguous.
  • Loading branch information
rafiss committed Dec 6, 2022
1 parent f8bf34c commit 33ead8e
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 20 deletions.
73 changes: 73 additions & 0 deletions pkg/sql/pgwire/testdata/pgtest/collated_string
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
send
Query {"String": "DROP TABLE IF EXISTS collated_string_table"}
Query {"String": "CREATE TABLE collated_string_table (id UUID PRIMARY KEY, email TEXT COLLATE \"en-US-u-ks-level2\" NOT NULL)"}
Query {"String": "CREATE UNIQUE INDEX ON collated_string_table(email)"}
----

until ignore=NoticeResponse
ReadyForQuery
ReadyForQuery
ReadyForQuery
----
{"Type":"CommandComplete","CommandTag":"DROP TABLE"}
{"Type":"ReadyForQuery","TxStatus":"I"}
{"Type":"CommandComplete","CommandTag":"CREATE TABLE"}
{"Type":"ReadyForQuery","TxStatus":"I"}
{"Type":"CommandComplete","CommandTag":"CREATE INDEX"}
{"Type":"ReadyForQuery","TxStatus":"I"}

send
Parse {"Query": "INSERT INTO collated_string_table (email,id) VALUES ($1,$2)", "Name": "insert_0"}
Describe {"Name": "insert_0", "ObjectType": "S"}
Bind {"ParameterFormatCodes": [1,1], "PreparedStatement": "insert_0", "Parameters": [{"binary":"757365722d31406578616d706c652e636f6d"}, {"binary":"00e8febeba494ee0ae269550e08cae0f"}]}
Execute
Sync
----

until
ReadyForQuery
----
{"Type":"ParseComplete"}
{"Type":"ParameterDescription","ParameterOIDs":[25,2950]}
{"Type":"NoData"}
{"Type":"BindComplete"}
{"Type":"CommandComplete","CommandTag":"INSERT 0 1"}
{"Type":"ReadyForQuery","TxStatus":"I"}

# Check without sending ParameterOIDs.
send
Parse {"Query": "SELECT u0.id, u0.email FROM collated_string_table AS u0 WHERE (u0.email = $1)", "Name": "select_0"}
Describe {"Name": "select_0", "ObjectType": "S"}
Bind {"ParameterFormatCodes": [1], "PreparedStatement": "select_0", "Parameters": [{"binary":"555345522d32406578616d706c652e636f6d"}]}
Execute
Sync
----

until ignore_table_oids
ReadyForQuery
----
{"Type":"ParseComplete"}
{"Type":"ParameterDescription","ParameterOIDs":[25]}
{"Type":"RowDescription","Fields":[{"Name":"id","TableOID":0,"TableAttributeNumber":1,"DataTypeOID":2950,"DataTypeSize":16,"TypeModifier":-1,"Format":0},{"Name":"email","TableOID":0,"TableAttributeNumber":2,"DataTypeOID":25,"DataTypeSize":-1,"TypeModifier":-1,"Format":0}]}
{"Type":"BindComplete"}
{"Type":"CommandComplete","CommandTag":"SELECT 0"}
{"Type":"ReadyForQuery","TxStatus":"I"}

# Check with sending ParameterOIDs.
send
Parse {"Query": "SELECT u0.id, u0.email FROM collated_string_table AS u0 WHERE (u0.email = $1)", "Name": "select_1", "ParameterOIDs": [25]}
Describe {"Name": "select_1", "ObjectType": "S"}
Bind {"ParameterFormatCodes": [1], "PreparedStatement": "select_1", "Parameters": [{"binary":"555345522d32406578616d706c652e636f6d"}]}
Execute
Sync
----

until ignore_table_oids
ReadyForQuery
----
{"Type":"ParseComplete"}
{"Type":"ParameterDescription","ParameterOIDs":[25]}
{"Type":"RowDescription","Fields":[{"Name":"id","TableOID":0,"TableAttributeNumber":1,"DataTypeOID":2950,"DataTypeSize":16,"TypeModifier":-1,"Format":0},{"Name":"email","TableOID":0,"TableAttributeNumber":2,"DataTypeOID":25,"DataTypeSize":-1,"TypeModifier":-1,"Format":0}]}
{"Type":"BindComplete"}
{"Type":"CommandComplete","CommandTag":"SELECT 0"}
{"Type":"ReadyForQuery","TxStatus":"I"}
5 changes: 5 additions & 0 deletions pkg/sql/sem/tree/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ func isConstant(expr Expr) bool {
return ok
}

// isPlaceholder reports whether the result of StripParens(expr) is a
// *Placeholder.
func isPlaceholder(expr Expr) bool {
	switch StripParens(expr).(type) {
	case *Placeholder:
		return true
	default:
		return false
	}
}

func typeCheckConstant(
ctx context.Context, semaCtx *SemaContext, c Constant, desired *types.T,
) (ret TypedExpr, err error) {
Expand Down
30 changes: 21 additions & 9 deletions pkg/sql/sem/tree/overload.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,9 +679,7 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs(
} else {
s.typedExprs = make([]TypedExpr, len(s.exprs))
}
s.constIdxs, s.placeholderIdxs, s.resolvableIdxs = typeCheckSplitExprs(
semaCtx, s.exprs,
)
s.constIdxs, s.placeholderIdxs, s.resolvableIdxs = typeCheckSplitExprs(s.exprs)

// If no overloads are provided, just type check parameters and return.
if numOverloads == 0 {
Expand Down Expand Up @@ -726,8 +724,18 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs(
// out impossible candidates based on identical parameters. For instance,
// f(int, float) is not a possible candidate for the expression f($1, $1).

// Filter out overloads on resolved types.
// Filter out overloads on resolved types. This includes resolved placeholders
// and any other resolvable exprs.
var typeableIdxs = util.FastIntSet{}
for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) {
typeableIdxs.Add(i)
}
for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) {
if !semaCtx.isUnresolvedPlaceholder(s.exprs[i]) {
typeableIdxs.Add(i)
}
}
for i, ok := typeableIdxs.Next(0); ok; i, ok = typeableIdxs.Next(i + 1) {
paramDesired := types.Any

// If all remaining candidates require the same type for this parameter,
Expand Down Expand Up @@ -789,10 +797,10 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs(
}

var homogeneousTyp *types.T
if !s.resolvableIdxs.Empty() {
idx, _ := s.resolvableIdxs.Next(0)
if !typeableIdxs.Empty() {
idx, _ := typeableIdxs.Next(0)
homogeneousTyp = s.typedExprs[idx].ResolvedType()
for i, ok := s.resolvableIdxs.Next(idx); ok; i, ok = s.resolvableIdxs.Next(i + 1) {
for i, ok := typeableIdxs.Next(idx); ok; i, ok = typeableIdxs.Next(i + 1) {
if !homogeneousTyp.Equivalent(s.typedExprs[i].ResolvedType()) {
homogeneousTyp = nil
break
Expand Down Expand Up @@ -1196,8 +1204,9 @@ func defaultTypeCheck(
}
for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) {
if errorOnPlaceholders {
_, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any)
return err
if _, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any); err != nil {
return err
}
}
// If we don't want to error on args, avoid type checking them without a desired type.
s.typedExprs[i] = StripParens(s.exprs[i]).(*Placeholder)
Expand Down Expand Up @@ -1268,6 +1277,9 @@ func checkReturnPlaceholdersAtIdx(
}
return false, err
}
if typ.ResolvedType().IsAmbiguous() {
return false, nil
}
s.typedExprs[i] = typ
}
s.overloadIdxs = append(s.overloadIdxs[:0], idx)
Expand Down
7 changes: 6 additions & 1 deletion pkg/sql/sem/tree/placeholders.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ func (p *PlaceholderTypesInfo) SetType(idx PlaceholderIdx, typ *types.T) error {
pgcode.DatatypeMismatch,
"placeholder %s already has type %s, cannot assign %s", idx, t, typ)
}
return nil
// If `t` is not ambiguous or if `typ` is ambiguous, then we shouldn't
// change the type that's already set. Otherwise, we can use `typ` since
// it is more specific.
if !t.IsAmbiguous() || typ.IsAmbiguous() {
return nil
}
}
p.Types[idx] = typ
return nil
Expand Down
41 changes: 32 additions & 9 deletions pkg/sql/sem/tree/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -1696,10 +1696,19 @@ func (expr *Placeholder) TypeCheck(
return expr, err
} else if ok {
typ = typ.WithoutTypeModifiers()
if !desired.Equivalent(typ) {
// This indicates there's a conflict between what the type system thinks
// the type for this position should be, and the actual type of the
// placeholder. This actual placeholder type could be either a type hint
if !desired.Equivalent(typ) || (typ.IsAmbiguous() && !desired.IsAmbiguous()) {
// This indicates either:
// - There's a conflict between what the type system thinks
// the type for this position should be, and the actual type of the
// placeholder.
// - A type was already set for the placeholder, but it was ambiguous. If
// the desired type is not ambiguous then it can be used as the
// placeholder type. This can happen during overload type checking: an
// overload that operates on collated strings might cause the type
// checker to assign AnyCollatedString to a placeholder, but a later
// stage of type checking can further refine the desired type.
//
// This actual placeholder type could be either a type hint
// (from pgwire or from a SQL PREPARE), or the actual value type.
//
// To resolve this situation, we *override* the placeholder type with what
Expand Down Expand Up @@ -2372,7 +2381,7 @@ func typeCheckSameTypedExprs(
// TODO(nvanbenschoten): Look into reducing allocations here.
typedExprs := make([]TypedExpr, len(exprs))

constIdxs, placeholderIdxs, resolvableIdxs := typeCheckSplitExprs(semaCtx, exprs)
constIdxs, placeholderIdxs, resolvableIdxs := typeCheckSplitExprs(exprs)

s := typeCheckExprsState{
ctx: ctx,
Expand Down Expand Up @@ -2530,11 +2539,25 @@ func typeCheckSameTypedConsts(
return nil, errors.AssertionFailedf("should throw error above")
}

// Used to type check all constants with the optional desired type. The
// type that is chosen here will then be set to any placeholders.
// Used to type check all constants with the optional desired type. First,
// placeholders with type hints are checked, then constants are checked to
// match the resulting type. The type that is chosen here will then be set
// to any unresolved placeholders.
func typeCheckConstsAndPlaceholdersWithDesired(
s typeCheckExprsState, desired *types.T,
) ([]TypedExpr, *types.T, error) {
if !s.placeholderIdxs.Empty() {
for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) {
if !s.semaCtx.isUnresolvedPlaceholder(s.exprs[i]) {
typedExpr, err := typeCheckAndRequire(s.ctx, s.semaCtx, s.exprs[i], desired, "placeholder")
if err != nil {
return nil, nil, err
}
s.typedExprs[i] = typedExpr
desired = typedExpr.ResolvedType()
}
}
}
typ, err := typeCheckSameTypedConsts(s, desired, false)
if err != nil {
return nil, nil, err
Expand All @@ -2552,13 +2575,13 @@ func typeCheckConstsAndPlaceholdersWithDesired(
// - Placeholders
// - All other Exprs
func typeCheckSplitExprs(
semaCtx *SemaContext, exprs []Expr,
exprs []Expr,
) (constIdxs util.FastIntSet, placeholderIdxs util.FastIntSet, resolvableIdxs util.FastIntSet) {
for i, expr := range exprs {
switch {
case isConstant(expr):
constIdxs.Add(i)
case semaCtx.isUnresolvedPlaceholder(expr):
case isPlaceholder(expr):
placeholderIdxs.Add(i)
default:
resolvableIdxs.Add(i)
Expand Down
3 changes: 2 additions & 1 deletion pkg/sql/sem/tree/type_check_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func TestTypeCheckSameTypedExprsError(t *testing.T) {
tupleIntMismatchErr := `expected .* to be of type (tuple|int), found type (tuple|int)`
tupleLenErr := `expected tuple .* to have a length of .*`
placeholderErr := `could not determine data type of placeholder .*`
placeholderAlreadyAssignedErr := `placeholder .* already has type (decimal|int), cannot assign (decimal|int)`

testData := []struct {
ptypes tree.PlaceholderTypes
Expand All @@ -325,7 +326,7 @@ func TestTypeCheckSameTypedExprsError(t *testing.T) {
// Single type mismatches.
{nil, nil, exprs(dint(1), decConst("1.1")), decimalIntMismatchErr},
{nil, nil, exprs(dint(1), ddecimal(1)), decimalIntMismatchErr},
{ptypesInt, nil, exprs(decConst("1.1"), placeholder(0)), decimalIntMismatchErr},
{ptypesInt, nil, exprs(decConst("1.1"), placeholder(0)), placeholderAlreadyAssignedErr},
// Tuple type mismatches.
{nil, nil, exprs(tuple(dint(1)), tuple(ddecimal(1))), tupleFloatIntMismatchErr},
{nil, nil, exprs(tuple(dint(1)), dint(1), dint(1)), tupleIntMismatchErr},
Expand Down
28 changes: 28 additions & 0 deletions pkg/sql/sem/tree/type_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/testutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/stretchr/testify/require"
)

// The following tests need both the type checking infrastructure and also
Expand Down Expand Up @@ -437,3 +438,30 @@ func TestTypeCheckVolatility(t *testing.T) {
}
}
}

// TestTypeCheckCollatedString verifies that a placeholder hinted as a plain
// STRING ends up with the collated string type (and locale) of the other
// operand when it is compared against a collated string expression.
func TestTypeCheckCollatedString(t *testing.T) {
	defer leaktest.AfterTest(t)()
	defer log.Scope(t).Close(t)

	ctx := context.Background()

	// Typecheck without any restrictions.
	semaCtx := tree.MakeSemaContext()
	semaCtx.Properties.Require("", 0 /* flags */)

	// Hint a normal string type for $1.
	placeholderTypes := []*types.T{types.String}
	err := semaCtx.Placeholders.Init(len(placeholderTypes), placeholderTypes)
	require.NoError(t, err)

	// The collated string constant must be on the LHS for this test, so that
	// the type-checker chooses the collated string overload first.
	expr, err := parser.ParseExpr("'cat'::STRING COLLATE \"en-US-u-ks-level2\" = ($1)")
	require.NoError(t, err)
	typed, err := tree.TypeCheck(ctx, expr, &semaCtx, types.Any)
	require.NoError(t, err)

	// Use checked type assertions so that an unexpected expression shape fails
	// the test with a descriptive message instead of panicking.
	cmpExpr, ok := typed.(*tree.ComparisonExpr)
	require.Truef(t, ok, "expected *tree.ComparisonExpr, got %T", typed)
	right, ok := cmpExpr.Right.(tree.TypedExpr)
	require.Truef(t, ok, "expected right operand to be a tree.TypedExpr, got %T", cmpExpr.Right)

	// require.Equal takes (t, expected, actual); keeping that order makes the
	// failure output label the two values correctly.
	rightTyp := right.ResolvedType()
	require.Equal(t, types.CollatedStringFamily, rightTyp.Family())
	require.Equal(t, "en-US-u-ks-level2", rightTyp.Locale())
}

0 comments on commit 33ead8e

Please sign in to comment.