From c91855bf277d548c5ee1bb902f72c1542b3af7f6 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 19 Jun 2024 15:56:30 +0200 Subject: [PATCH 01/19] feat: make the arguments print themselves with type info Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_format.go | 33 +++++++++++++++++++ go/vt/sqlparser/ast_format_fast.go | 51 ++++++++++++++++++++++++++++++ go/vt/sqlparser/normalizer_test.go | 14 ++++---- 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 49de08381d2..b48bbd00926 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1361,6 +1361,39 @@ func (node *Literal) Format(buf *TrackedBuffer) { // Format formats the node. func (node *Argument) Format(buf *TrackedBuffer) { + // We need to make sure that any value used still returns + // the right type when interpolated. For example, if we have a + // decimal type with 0 scale, we don't want it to be interpreted + // as an integer after interpolation as that would the default + // literal interpretation in MySQL. + switch { + case node.Type == sqltypes.Unknown: + // Ensure we handle unknown first as we don't want to treat + // the type as a bitmask for the further tests. + // do nothing, the default literal will be correct. + case sqltypes.IsDecimal(node.Type): + buf.astPrintf(node, "CAST(:%#s AS DECIMAL(%d, %d))", node.Name, node.Size, node.Scale) + return + case sqltypes.IsUnsigned(node.Type): + buf.astPrintf(node, "CAST(:%#s AS UNSIGNED)", node.Name) + return + case node.Type == sqltypes.Float64: + buf.astPrintf(node, "CAST(:%#s AS DOUBLE)", node.Name) + return + case node.Type == sqltypes.Float32: + buf.astPrintf(node, "CAST(:%#s AS FLOAT)", node.Name) + return + case sqltypes.IsDate(node.Type): + buf.astPrintf(node, "CAST(:%#s AS DATE)", node.Name) + return + case node.Type == sqltypes.Time: + buf.astPrintf(node, "CAST(:%#s AS TIME)", node.Name) + return + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.astPrintf(node, "CAST(:%#s AS TIMESTAMP)", node.Name) + return + } + // Nothing special to do, the default literal will be correct. buf.WriteArg(":", node.Name) if node.Type >= 0 { // For bind variables that are statically typed, emit their type as an adjacent comment. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 87626f0b799..038ea1b6f6d 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1780,6 +1780,57 @@ func (node *Literal) FormatFast(buf *TrackedBuffer) { // FormatFast formats the node. func (node *Argument) FormatFast(buf *TrackedBuffer) { + // We need to make sure that any value used still returns + // the right type when interpolated. For example, if we have a + // decimal type with 0 scale, we don't want it to be interpreted + // as an integer after interpolation as that would the default + // literal interpretation in MySQL. + switch { + case node.Type == sqltypes.Unknown: + // Ensure we handle unknown first as we don't want to treat + // the type as a bitmask for the further tests. + // do nothing, the default literal will be correct. + case sqltypes.IsDecimal(node.Type): + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS DECIMAL(") + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString(", ") + buf.WriteString(fmt.Sprintf("%d", node.Scale)) + buf.WriteString("))") + return + case sqltypes.IsUnsigned(node.Type): + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS UNSIGNED)") + return + case node.Type == sqltypes.Float64: + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS DOUBLE)") + return + case node.Type == sqltypes.Float32: + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS FLOAT)") + return + case sqltypes.IsDate(node.Type): + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS DATE)") + return + case node.Type == sqltypes.Time: + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS TIME)") + return + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS TIMESTAMP)") + return + } + // Nothing special to do, the default literal will be correct. buf.WriteArg(":", node.Name) if node.Type >= 0 { // For bind variables that are statically typed, emit their type as an adjacent comment. diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 19b0cfbcac6..b1df0c3a386 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -75,28 +75,28 @@ func TestNormalize(t *testing.T) { }, { // float val in: "select * from t where foobar = 1.2", - outstmt: "select * from t where foobar = :foobar /* DECIMAL(2,1) */", + outstmt: "select * from t where foobar = CAST(:foobar AS DECIMAL(2, 1))", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.DecimalBindVariable("1.2"), }, }, { // datetime val in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'", - outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */", + outstmt: "select * from t where foobar = CAST(:foobar AS DATE)", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")), }, }, { // time val in: "select * from t where foobar = time'12:34:56.123456'", - outstmt: "select * from t where foobar = :foobar /* TIME(6) */", + outstmt: "select * from t where foobar = CAST(:foobar AS TIME)", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")), }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", - outstmt: "select * from t where foo = :foo /* DECIMAL(2,1) */ and bar = :bar /* INT64 */", + outstmt: "select * from t where foo = CAST(:foo AS DECIMAL(2, 1)) and bar = :bar /* INT64 */", outbv: map[string]*querypb.BindVariable{ "foo": sqltypes.DecimalBindVariable("1.2"), "bar": sqltypes.Int64BindVariable(2), @@ -334,21 +334,21 @@ func TestNormalize(t *testing.T) { }, { // DateVal should also be normalized in: `select date'2022-08-06'`, - outstmt: `select :bv1 /* DATE */ from dual`, + outstmt: `select CAST(:bv1 AS DATE) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Date, []byte("2022-08-06"))), }, }, { // TimeVal should also be normalized in: `select time'17:05:12'`, - outstmt: `select :bv1 /* TIME */ from dual`, + outstmt: `select CAST(:bv1 AS TIME) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Time, []byte("17:05:12"))), }, }, { // TimestampVal should also be normalized in: `select timestamp'2022-08-06 17:05:12'`, - outstmt: `select :bv1 /* DATETIME */ from dual`, + outstmt: `select CAST(:bv1 AS DATE) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))), }, From 8dd65d590a713da91f652c9e13b925c61db99b08 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 25 Jun 2024 15:15:41 +0200 Subject: [PATCH 02/19] feat: add fsp to date/time/timestamp Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_format.go | 20 ++++++++++++++++--- go/vt/sqlparser/ast_format_fast.go | 32 +++++++++++++++++++++++++++--- go/vt/sqlparser/normalizer_test.go | 11 ++++++++-- 3 files changed, 55 insertions(+), 8 deletions(-) diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index b48bbd00926..7e386bbe5d5 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1384,13 +1384,27 @@ func (node *Argument) Format(buf *TrackedBuffer) { buf.astPrintf(node, "CAST(:%#s AS FLOAT)", node.Name) return case sqltypes.IsDate(node.Type): - buf.astPrintf(node, "CAST(:%#s AS DATE)", node.Name) + if node.Size == 0 { + buf.astPrintf(node, "CAST(:%#s AS DATE)", node.Name) + return + } + buf.astPrintf(node, "CAST(:%#s AS DATE(%d))", node.Name, node.Size) return case node.Type == sqltypes.Time: - buf.astPrintf(node, "CAST(:%#s AS TIME)", node.Name) + if node.Size == 0 { + buf.astPrintf(node, "CAST(:%#s AS TIME)", node.Name) + return + } + + buf.astPrintf(node, "CAST(:%#s AS TIME(%d))", node.Name, node.Size) return case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: - buf.astPrintf(node, "CAST(:%#s AS TIMESTAMP)", node.Name) + if node.Size == 0 { + buf.astPrintf(node, "CAST(:%#s AS TIMESTAMP)", node.Name) + return + } + + buf.astPrintf(node, "CAST(:%#s AS TIMESTAMP(%d))", node.Name, node.Size) return } // Nothing special to do, the default literal will be correct. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 038ea1b6f6d..05ea43119ad 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1815,19 +1815,45 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { buf.WriteString(" AS FLOAT)") return case sqltypes.IsDate(node.Type): + if node.Size == 0 { + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS DATE)") + return + } buf.WriteString("CAST(:") buf.WriteString(node.Name) - buf.WriteString(" AS DATE)") + buf.WriteString(" AS DATE(") + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString("))") return case node.Type == sqltypes.Time: + if node.Size == 0 { + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS TIME)") + return + } + buf.WriteString("CAST(:") buf.WriteString(node.Name) - buf.WriteString(" AS TIME)") + buf.WriteString(" AS TIME(") + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString("))") return case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + if node.Size == 0 { + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS TIMESTAMP)") + return + } + buf.WriteString("CAST(:") buf.WriteString(node.Name) - buf.WriteString(" AS TIMESTAMP)") + buf.WriteString(" AS TIMESTAMP(") + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString("))") return } // Nothing special to do, the default literal will be correct. diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index b1df0c3a386..9cab1f03b0c 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -82,17 +82,24 @@ func TestNormalize(t *testing.T) { }, { // datetime val in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'", - outstmt: "select * from t where foobar = CAST(:foobar AS DATE)", + outstmt: "select * from t where foobar = CAST(:foobar AS DATE(6))", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")), }, }, { // time val in: "select * from t where foobar = time'12:34:56.123456'", - outstmt: "select * from t where foobar = CAST(:foobar AS TIME)", + outstmt: "select * from t where foobar = CAST(:foobar AS TIME(6))", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")), }, + }, { + // time val + in: "select * from t where foobar = time'12:34:56'", + outstmt: "select * from t where foobar = CAST(:foobar AS TIME)", + outbv: map[string]*querypb.BindVariable{ + "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56")), + }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", From 35bdac5544ec5c1034ca45ac93d1de88436b6d30 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Tue, 25 Jun 2024 15:21:37 -0600 Subject: [PATCH 03/19] Add TestCastBindVars and apply review suggestions Signed-off-by: Florent Poinsard --- go/vt/sqlparser/ast_format.go | 47 ++++++++------ go/vt/sqlparser/ast_format_fast.go | 61 ++++++++---------- go/vt/sqlparser/parsed_query_test.go | 94 +++++++++++++++++++++++++++- 3 files changed, 146 insertions(+), 56 deletions(-) diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 7e386bbe5d5..35564eccdcd 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1372,39 +1372,50 @@ func (node *Argument) Format(buf *TrackedBuffer) { // the type as a bitmask for the further tests. // do nothing, the default literal will be correct. case sqltypes.IsDecimal(node.Type): - buf.astPrintf(node, "CAST(:%#s AS DECIMAL(%d, %d))", node.Name, node.Size, node.Scale) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.astPrintf(node, " AS DECIMAL(%d, %d))", node.Size, node.Scale) return case sqltypes.IsUnsigned(node.Type): - buf.astPrintf(node, "CAST(:%#s AS UNSIGNED)", node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS UNSIGNED)") return case node.Type == sqltypes.Float64: - buf.astPrintf(node, "CAST(:%#s AS DOUBLE)", node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DOUBLE)") return case node.Type == sqltypes.Float32: - buf.astPrintf(node, "CAST(:%#s AS FLOAT)", node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS FLOAT)") return - case sqltypes.IsDate(node.Type): + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATETIME") if node.Size == 0 { - buf.astPrintf(node, "CAST(:%#s AS DATE)", node.Name) + buf.WriteString(")") return } - buf.astPrintf(node, "CAST(:%#s AS DATE(%d))", node.Name, node.Size) + buf.astPrintf(node, "(%d))", node.Size) return - case node.Type == sqltypes.Time: - if node.Size == 0 { - buf.astPrintf(node, "CAST(:%#s AS TIME)", node.Name) - return - } - - buf.astPrintf(node, "CAST(:%#s AS TIME(%d))", node.Name, node.Size) + case sqltypes.IsDate(node.Type): + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATE") + buf.WriteString(")") return - case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + case node.Type == sqltypes.Time: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS TIME") if node.Size == 0 { - buf.astPrintf(node, "CAST(:%#s AS TIMESTAMP)", node.Name) + buf.WriteString(")") return } - - buf.astPrintf(node, "CAST(:%#s AS TIMESTAMP(%d))", node.Name, node.Size) + buf.astPrintf(node, "(%d))", node.Size) return } // Nothing special to do, the default literal will be correct. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 05ea43119ad..e2fff95804a 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1791,8 +1791,8 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { // the type as a bitmask for the further tests. // do nothing, the default literal will be correct. case sqltypes.IsDecimal(node.Type): - buf.WriteString("CAST(:") - buf.WriteString(node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) buf.WriteString(" AS DECIMAL(") buf.WriteString(fmt.Sprintf("%d", node.Size)) buf.WriteString(", ") @@ -1800,58 +1800,47 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { buf.WriteString("))") return case sqltypes.IsUnsigned(node.Type): - buf.WriteString("CAST(:") - buf.WriteString(node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) buf.WriteString(" AS UNSIGNED)") return case node.Type == sqltypes.Float64: - buf.WriteString("CAST(:") - buf.WriteString(node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) buf.WriteString(" AS DOUBLE)") return case node.Type == sqltypes.Float32: - buf.WriteString("CAST(:") - buf.WriteString(node.Name) + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) buf.WriteString(" AS FLOAT)") return - case sqltypes.IsDate(node.Type): + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATETIME") if node.Size == 0 { - buf.WriteString("CAST(:") - buf.WriteString(node.Name) - buf.WriteString(" AS DATE)") + buf.WriteString(")") return } - buf.WriteString("CAST(:") - buf.WriteString(node.Name) - buf.WriteString(" AS DATE(") + buf.WriteByte('(') buf.WriteString(fmt.Sprintf("%d", node.Size)) buf.WriteString("))") return - case node.Type == sqltypes.Time: - if node.Size == 0 { - buf.WriteString("CAST(:") - buf.WriteString(node.Name) - buf.WriteString(" AS TIME)") - return - } - - buf.WriteString("CAST(:") - buf.WriteString(node.Name) - buf.WriteString(" AS TIME(") - buf.WriteString(fmt.Sprintf("%d", node.Size)) - buf.WriteString("))") + case sqltypes.IsDate(node.Type): + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATE") + buf.WriteString(")") return - case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + case node.Type == sqltypes.Time: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS TIME") if node.Size == 0 { - buf.WriteString("CAST(:") - buf.WriteString(node.Name) - buf.WriteString(" AS TIMESTAMP)") + buf.WriteString(")") return } - - buf.WriteString("CAST(:") - buf.WriteString(node.Name) - buf.WriteString(" AS TIMESTAMP(") + buf.WriteByte('(') buf.WriteString(fmt.Sprintf("%d", node.Size)) buf.WriteString("))") return diff --git a/go/vt/sqlparser/parsed_query_test.go b/go/vt/sqlparser/parsed_query_test.go index ef59676883f..8ade9d4d31c 100644 --- a/go/vt/sqlparser/parsed_query_test.go +++ b/go/vt/sqlparser/parsed_query_test.go @@ -20,10 +20,11 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - - "github.com/stretchr/testify/assert" ) func TestNewParsedQuery(t *testing.T) { @@ -205,3 +206,92 @@ func TestParseAndBind(t *testing.T) { }) } } + +func TestCastBindVars(t *testing.T) { + testcases := []struct { + typ sqltypes.Type + size int + binds map[string]*querypb.BindVariable + out string + }{ + { + typ: sqltypes.Decimal, + binds: map[string]*querypb.BindVariable{"arg": sqltypes.DecimalBindVariable("50")}, + out: "select CAST(50 AS DECIMAL(0, 0)) from ", + }, + { + typ: sqltypes.Uint32, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Uint32, Value: sqltypes.NewUint32(42).Raw()}}, + out: "select CAST(42 AS UNSIGNED) from ", + }, + { + typ: sqltypes.Float64, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float64, Value: sqltypes.NewFloat64(42.42).Raw()}}, + out: "select CAST(42.42 AS DOUBLE) from ", + }, + { + typ: sqltypes.Float32, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float32, Value: sqltypes.NewFloat32(42).Raw()}}, + out: "select CAST(42 AS FLOAT) from ", + }, + { + typ: sqltypes.Date, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Date, Value: sqltypes.NewDate("2021-10-30").Raw()}}, + out: "select CAST('2021-10-30' AS DATE) from ", + }, + { + typ: sqltypes.Time, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}}, + out: "select CAST('12:00:00' AS TIME) from ", + }, + { + typ: sqltypes.Time, + size: 6, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}}, + out: "select CAST('12:00:00' AS TIME(6)) from ", + }, + { + typ: sqltypes.Timestamp, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ", + }, + { + typ: sqltypes.Timestamp, + size: 6, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ", + }, + { + typ: sqltypes.Datetime, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ", + }, + { + typ: sqltypes.Datetime, + size: 6, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ", + }, + } + + for _, testcase := range testcases { + t.Run(testcase.out, func(t *testing.T) { + argument := NewTypedArgument("arg", testcase.typ) + if testcase.size > 0 { + argument.Size = int32(testcase.size) + } + + s := &Select{ + SelectExprs: SelectExprs{ + NewAliasedExpr(argument, ""), + }, + } + + pq := NewParsedQuery(s) + out, err := pq.GenerateQuery(testcase.binds, nil) + + require.NoError(t, err) + require.Equal(t, testcase.out, out) + }) + } +} From 77a41babb7f447de6a90a7e232969bfaa0cdef7d Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Tue, 25 Jun 2024 20:03:26 -0600 Subject: [PATCH 04/19] Fix review suggestions Signed-off-by: Florent Poinsard --- .../vtgate/queries/normalize/normalize_test.go | 6 +++--- .../vtgate/queries/subquery/subquery_test.go | 17 +++++++++++++++++ go/vt/sqlparser/normalizer_test.go | 4 ++-- .../multi-output/selectsharded-output.txt | 8 ++++---- .../testdata/multi-output/unsharded-output.txt | 6 +++--- .../testdata/twopc-output/unsharded-output.txt | 4 ++-- 6 files changed, 31 insertions(+), 14 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go index 51d9f9f24bf..283c325bc1d 100644 --- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go +++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go @@ -40,11 +40,11 @@ func TestNormalizeAllFields(t *testing.T) { defer conn.Close() insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)` - normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` + normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, CAST(:vtg6 AS DECIMAL(3, 2)), CAST(:vtg7 AS DECIMAL(3, 2)), :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` vtgateVersion, err := cluster.GetMajorVersion("vtgate") require.NoError(t, err) - if vtgateVersion < 20 { - normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` + if vtgateVersion < 21 { + normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` } selectQuery := "select * from t1" utils.Exec(t, conn, insertQuery) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index abbf5ff15e8..576644f5c5c 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -23,6 +23,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" ) @@ -232,3 +234,18 @@ func TestSubqueries(t *testing.T) { }) } } + +func TestToto(t *testing.T) { + t.Skip("WIP - does not work") + + query := "select (select sum(id) from user) from user_extra" + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')") + mcmp.Exec("INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')") + + r := mcmp.Exec(query) + require.True(t, r.Fields[0].Type == sqltypes.Decimal) +} diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 9cab1f03b0c..3ab1b8b6998 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -82,7 +82,7 @@ func TestNormalize(t *testing.T) { }, { // datetime val in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'", - outstmt: "select * from t where foobar = CAST(:foobar AS DATE(6))", + outstmt: "select * from t where foobar = CAST(:foobar AS DATETIME(6))", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")), }, @@ -355,7 +355,7 @@ func TestNormalize(t *testing.T) { }, { // TimestampVal should also be normalized in: `select timestamp'2022-08-06 17:05:12'`, - outstmt: `select CAST(:bv1 AS DATE) from dual`, + outstmt: `select CAST(:bv1 AS DATETIME) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))), }, diff --git a/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt index 6fb355b4a05..3109aef341c 100644 --- a/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt @@ -91,10 +91,10 @@ select name, count(*) from user group by name /* scatter aggregate */ ---------------------------------------------------------------------- select 1, "hello", 3.14, null from user limit 10 /* select constant sql values */ -1 ks_sharded/-40: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ -1 ks_sharded/40-80: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ -1 ks_sharded/80-c0: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ -1 ks_sharded/c0-: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/-40: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/40-80: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/80-c0: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/c0-: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ ---------------------------------------------------------------------- select * from (select id from user) s /* scatter paren select */ diff --git a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt index b63683ca274..40a602ae37f 100644 --- a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt @@ -11,7 +11,7 @@ select * from t1 ---------------------------------------------------------------------- insert into t1 (id,intval,floatval) values (1,2,3.14) -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) ---------------------------------------------------------------------- update t1 set intval = 10 @@ -24,7 +24,7 @@ update t1 set intval = 10 update t1 set floatval = 9.99 1 ks_unsharded/-: begin -1 ks_unsharded/-: update t1 set floatval = 9.99 limit 10001 /* DECIMAL(3,2) */ +1 ks_unsharded/-: update t1 set floatval = cast(9.99 as DECIMAL(3, 2)) limit 10001 1 ks_unsharded/-: commit ---------------------------------------------------------------------- @@ -37,7 +37,7 @@ delete from t1 where id = 100 ---------------------------------------------------------------------- insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14 -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 /* DECIMAL(3,2) */ +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) on duplicate key update intval = 3, floatval = cast(3.14 as DECIMAL(3, 2)) ---------------------------------------------------------------------- select ID from t1 diff --git a/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt b/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt index 0db31f10110..a1cf83a5726 100644 --- a/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt @@ -12,7 +12,7 @@ select * from t1 insert into t1 (id,intval,floatval) values (1,2,3.14) 1 ks_unsharded/-: begin -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) 2 ks_unsharded/-: commit ---------------------------------------------------------------------- @@ -42,7 +42,7 @@ delete from t1 where id = 100 insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14 1 ks_unsharded/-: begin -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) on duplicate key update intval = 3, floatval = cast(3.14 as DECIMAL(3, 2)) 2 ks_unsharded/-: commit ---------------------------------------------------------------------- From b671bb683dfc67a36321b3e6a77865c45b46c8de Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 26 Jun 2024 11:26:36 +0530 Subject: [PATCH 05/19] feat: fix the sqltype used for arguments for subqueries Signed-off-by: Manan Gupta --- .../vtgate/queries/subquery/subquery_test.go | 2 - .../operators/subquery_planning.go | 11 ++- .../planbuilder/testdata/aggr_cases.json | 8 +-- .../planbuilder/testdata/select_cases.json | 68 ++++++++++++++++--- .../planbuilder/testdata/wireup_cases.json | 4 +- 5 files changed, 76 insertions(+), 17 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index 576644f5c5c..a32ed6fe8cc 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -236,8 +236,6 @@ func TestSubqueries(t *testing.T) { } func TestToto(t *testing.T) { - t.Skip("WIP - does not work") - query := "select (select sum(id) from user) from user_extra" mcmp, closer := start(t) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index cdc0b8b191a..06ff0300482 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -23,6 +23,7 @@ import ( "golang.org/x/exp/slices" "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" @@ -388,7 +389,15 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp } return sqlparser.NewArgument(sq1.HasValuesName) default: - return sqlparser.NewArgument(s) + argType := sqltypes.Unknown + ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) + if isAe { + evalType, found := ctx.TypeForExpr(ae.Expr) + if found { + argType = evalType.Type() + } + } + return sqlparser.NewTypedArgument(s, argType) } } } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index a272954725d..c1d7e6a15c8 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6114,8 +6114,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select :__sq1 + :__sq2 as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual where 1 != 1", - "Query": "select :__sq1 + :__sq2 as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual", + "FieldQuery": "select :__sq1 /* INT64 */ + :__sq2 /* INT64 */ as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual where 1 != 1", + "Query": "select :__sq1 /* INT64 */ + :__sq2 /* INT64 */ as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual", "Table": "dual" } ] @@ -6764,8 +6764,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", - "Query": "select max(:__sq1), weight_string(:__sq1) from `user` where id = 2 group by weight_string(:__sq1)", + "FieldQuery": "select max(:__sq1 /* INT16 */), weight_string(:__sq1 /* INT16 */) from `user` where 1 != 1 group by weight_string(:__sq1 /* INT16 */)", + "Query": "select max(:__sq1 /* INT16 */), weight_string(:__sq1 /* INT16 */) from `user` where id = 2 group by weight_string(:__sq1 /* INT16 */)", "Table": "`user`", "Values": [ "2" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index fdb189d067b..b312c9dd484 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1418,8 +1418,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select a, :__sq1 as `(select col from ``user``)` from unsharded where 1 != 1", - "Query": "select a, :__sq1 as `(select col from ``user``)` from unsharded", + "FieldQuery": "select a, :__sq1 /* INT16 */ as `(select col from ``user``)` from unsharded where 1 != 1", + "Query": "select a, :__sq1 /* INT16 */ as `(select col from ``user``)` from unsharded", "Table": "unsharded" } ] @@ -1463,8 +1463,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select a, 1 + :__sq1 as `1 + (select col from ``user``)` from unsharded where 1 != 1", - "Query": "select a, 1 + :__sq1 as `1 + (select col from ``user``)` from unsharded", + "FieldQuery": "select a, 1 + :__sq1 /* INT16 */ as `1 + (select col from ``user``)` from unsharded where 1 != 1", + "Query": "select a, 1 + :__sq1 /* INT16 */ as `1 + (select col from ``user``)` from unsharded", "Table": "unsharded" } ] @@ -2510,8 +2510,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select t.a from (select :__sq1 as a from `user` where 1 != 1) as t where 1 != 1", - "Query": "select t.a from (select :__sq1 as a from `user`) as t", + "FieldQuery": "select t.a from (select :__sq1 /* INT16 */ as a from `user` where 1 != 1) as t where 1 != 1", + "Query": "select t.a from (select :__sq1 /* INT16 */ as a from `user`) as t", "Table": "`user`" } ] @@ -2580,8 +2580,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :__sq1 as a from `user` where 1 != 1", - "Query": "select :__sq1 as a from `user`", + "FieldQuery": "select :__sq1 /* INT16 */ as a from `user` where 1 != 1", + "Query": "select :__sq1 /* INT16 */ as a from `user`", "Table": "`user`" } ] @@ -2717,6 +2717,58 @@ ] } }, + { + "comment": "PullOut subquery with an aggregation that should be typed in the final output", + "query": "select (select sum(col) from user) from user_extra", + "plan": { + "QueryType": "SELECT", + "Original": "select (select sum(col) from user) from user_extra", + "Instructions": { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS sum(col)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(col) from `user` where 1 != 1", + "Query": "select sum(col) from `user`", + "Table": "`user`" + } + ] + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select CAST(:__sq1 AS DECIMAL(0, 0)) as `(select sum(col) from ``user``)` from user_extra where 1 != 1", + "Query": "select CAST(:__sq1 AS DECIMAL(0, 0)) as `(select sum(col) from ``user``)` from user_extra", + "Table": "user_extra" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, { "comment": "Straight Join preserved in MySQL query", "query": "select user.id, user_extra.user_id from user straight_join user_extra where user.id = user_extra.user_id", diff --git a/go/vt/vtgate/planbuilder/testdata/wireup_cases.json b/go/vt/vtgate/planbuilder/testdata/wireup_cases.json index 3aca1f1dc66..f34c51f009a 100644 --- a/go/vt/vtgate/planbuilder/testdata/wireup_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/wireup_cases.json @@ -737,8 +737,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select u.id, :__sq1 as `(select col from ``user``)`, u.col from `user` as u where 1 != 1", - "Query": "select u.id, :__sq1 as `(select col from ``user``)`, u.col from `user` as u", + "FieldQuery": "select u.id, :__sq1 /* INT16 */ as `(select col from ``user``)`, u.col from `user` as u where 1 != 1", + "Query": "select u.id, :__sq1 /* INT16 */ as `(select col from ``user``)`, u.col from `user` as u", "Table": "`user`" } ] From f95feb89c5a245c0e29bf94a13d8ec9248fbb75a Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 26 Jun 2024 11:29:42 +0530 Subject: [PATCH 06/19] refactor: make the for loops unnested Signed-off-by: Manan Gupta --- .../operators/subquery_planning.go | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 06ff0300482..0df8a1c446e 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -372,37 +372,44 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Expr, se SubQueryExpression, subqueries ...*SubQuery) sqlparser.Expr { rewriteIt := func(s string) sqlparser.SQLNode { - for _, sq1 := range se { - if sq1.ArgName != s && sq1.HasValuesName != s { - continue + var sq1, sq2 *SubQuery + for _, sq := range se { + if sq.ArgName == s || sq.HasValuesName == s { + sq1 = sq + break } + } + for _, sq := range subqueries { + if s == sq.ArgName { + sq2 = sq + break + } + } - for _, sq2 := range subqueries { - if s == sq2.ArgName { - switch { - case sq1.FilterType.NeedsListArg(): - return sqlparser.NewListArg(s) - case sq1.FilterType == opcode.PulloutExists: - if sq1.HasValuesName == "" { - sq1.HasValuesName = ctx.ReservedVars.ReserveHasValuesSubQuery() - sq2.HasValuesName = sq1.HasValuesName - } - return sqlparser.NewArgument(sq1.HasValuesName) - default: - argType := sqltypes.Unknown - ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) - if isAe { - evalType, found := ctx.TypeForExpr(ae.Expr) - if found { - argType = evalType.Type() - } - } - return sqlparser.NewTypedArgument(s, argType) - } + if sq1 == nil || sq2 == nil { + return nil + } + + switch { + case sq1.FilterType.NeedsListArg(): + return sqlparser.NewListArg(s) + case sq1.FilterType == opcode.PulloutExists: + if sq1.HasValuesName == "" { + sq1.HasValuesName = ctx.ReservedVars.ReserveHasValuesSubQuery() + sq2.HasValuesName = sq1.HasValuesName + } + return sqlparser.NewArgument(sq1.HasValuesName) + default: + argType := sqltypes.Unknown + ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) + if isAe { + evalType, found := ctx.TypeForExpr(ae.Expr) + if found { + argType = evalType.Type() } } + return sqlparser.NewTypedArgument(s, argType) } - return nil } // replace the ColNames with Argument inside the subquery From 23a06268e81e1e86876faf978014323d3e97f5f9 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 26 Jun 2024 11:54:39 +0530 Subject: [PATCH 07/19] feat: use typed arguments wherever we can, and remove unused code Signed-off-by: Manan Gupta --- .../planbuilder/operators/expressions.go | 2 +- .../planbuilder/operators/query_planning.go | 16 ------------ .../vtgate/planbuilder/operators/subquery.go | 15 ----------- .../operators/subquery_planning.go | 7 ++--- .../plancontext/planning_context.go | 10 +++++++ .../planbuilder/testdata/aggr_cases.json | 12 ++++----- .../planbuilder/testdata/cte_cases.json | 2 +- .../planbuilder/testdata/dml_cases.json | 14 +++++----- .../planbuilder/testdata/filter_cases.json | 18 ++++++------- .../planbuilder/testdata/from_cases.json | 26 +++++++++---------- .../testdata/info_schema57_cases.json | 4 +-- .../testdata/info_schema80_cases.json | 6 ++--- .../testdata/postprocess_cases.json | 6 ++--- .../planbuilder/testdata/rails_cases.json | 4 +-- .../planbuilder/testdata/select_cases.json | 12 ++++----- .../testdata/vindex_func_cases.json | 8 +++--- .../planbuilder/testdata/wireup_cases.json | 22 ++++++++-------- 17 files changed, 80 insertions(+), 104 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index a39ae96fa88..d7c3a620289 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -44,7 +44,7 @@ func breakExpressionInLHSandRHS( Name: bvName, Expr: nodeExpr, }) - arg := sqlparser.NewArgument(bvName) + arg := sqlparser.NewTypedArgument(bvName, ctx.SQLTypeForExpr(nodeExpr)) // we are replacing one of the sides of the comparison with an argument, // but we don't want to lose the type information we have, so we copy it over ctx.SemTable.CopyExprInfo(nodeExpr, arg) diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 2f3259394e2..37b26f9181c 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -532,26 +532,10 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator, return pushOrderingUnderAggr(ctx, in, src) case *SubQueryContainer: return pushOrderingToOuterOfSubqueryContainer(ctx, in, src) - case *SubQuery: - return pushOrderingToOuterOfSubquery(ctx, in, src) } return in, NoRewrite } -func pushOrderingToOuterOfSubquery(ctx *plancontext.PlanningContext, in *Ordering, sq *SubQuery) (Operator, *ApplyResult) { - outerTableID := TableID(sq.Outer) - for idx, order := range in.Order { - deps := ctx.SemTable.RecursiveDeps(order.Inner.Expr) - if !deps.IsSolvedBy(outerTableID) { - return in, NoRewrite - } - in.Order[idx].SimplifiedExpr = sq.rewriteColNameToArgument(order.SimplifiedExpr) - in.Order[idx].Inner.Expr = sq.rewriteColNameToArgument(order.Inner.Expr) - } - sq.Outer, in.Source = in, sq.Outer - return sq, Rewrote("push ordering into outer side of subquery") -} - func pushOrderingToOuterOfSubqueryContainer(ctx *plancontext.PlanningContext, in *Ordering, subq *SubQueryContainer) (Operator, *ApplyResult) { outerTableID := TableID(subq.Outer) for _, order := range in.Order { diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index a950c3720c2..5ae0fb52e7f 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -309,18 +309,3 @@ func (sq *SubQuery) mapExpr(f func(expr sqlparser.Expr) sqlparser.Expr) { sq.Original = f(sq.Original) sq.originalSubquery = f(sq.originalSubquery).(*sqlparser.Subquery) } - -func (sq *SubQuery) rewriteColNameToArgument(expr sqlparser.Expr) sqlparser.Expr { - pre := func(cursor *sqlparser.Cursor) bool { - colName, ok := cursor.Node().(*sqlparser.ColName) - if !ok || colName.Qualifier.NonEmpty() || !colName.Name.EqualString(sq.ArgName) { - // we only want to rewrite the column name to an argument if it's the right column - return true - } - - cursor.Replace(sqlparser.NewArgument(sq.ArgName)) - return true - } - - return sqlparser.Rewrite(expr, pre, nil).(sqlparser.Expr) -} diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 0df8a1c446e..df563105f43 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -365,7 +365,7 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql // need to find the argument name for it and use that instead // we can't use the column name directly, because we're in the RHS of the join name := outer.findOrAddColNameBindVarName(ctx, col) - cursor.Replace(sqlparser.NewArgument(name)) + cursor.Replace(sqlparser.NewTypedArgument(name, ctx.SQLTypeForExpr(col))) }, nil) return result.(sqlparser.Expr) } @@ -403,10 +403,7 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp argType := sqltypes.Unknown ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) if isAe { - evalType, found := ctx.TypeForExpr(ae.Expr) - if found { - argType = evalType.Type() - } + argType = ctx.SQLTypeForExpr(ae.Expr) } return sqlparser.NewTypedArgument(s, argType) } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 90a6bdac6f8..2f33539f858 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -17,6 +17,7 @@ limitations under the License. package plancontext import ( + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -222,3 +223,12 @@ func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool } return t, true } + +// SQLTypeForExpr returns the sql type of the given expression, with nullable set if the expression is from an outer table. +func (ctx *PlanningContext) SQLTypeForExpr(e sqlparser.Expr) sqltypes.Type { + t, found := ctx.TypeForExpr(e) + if !found { + return sqltypes.Unknown + } + return t.Type() +} diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index c1d7e6a15c8..6942464665c 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -3372,7 +3372,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1 group by .0", - "Query": "select 1 from user_extra where user_extra.col = :user_col group by .0", + "Query": "select 1 from user_extra where user_extra.col = :user_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -3862,7 +3862,7 @@ "Sharded": true }, "FieldQuery": "select count(*) from user_extra as ue where 1 != 1 group by .0", - "Query": "select count(*) from user_extra as ue where ue.col = :u_col group by .0", + "Query": "select count(*) from user_extra as ue where ue.col = :u_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -3922,7 +3922,7 @@ "Sharded": true }, "FieldQuery": "select count(ue.id) from user_extra as ue where 1 != 1 group by .0", - "Query": "select count(ue.id) from user_extra as ue where ue.col = :u_col group by .0", + "Query": "select count(ue.id) from user_extra as ue where ue.col = :u_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -5153,7 +5153,7 @@ "Sharded": true }, "FieldQuery": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where 1 != 1 group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", - "Query": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where user_extra.bar = :user_col group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", + "Query": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where user_extra.bar = :user_col /* INT16 */ group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", "Table": "user_extra" } ] @@ -5431,7 +5431,7 @@ "Sharded": true }, "FieldQuery": "select count(*), sum(user_extra.bar) from user_extra where 1 != 1 group by .0", - "Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col group by .0", + "Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -5781,7 +5781,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.col = :u_col", + "Query": "select 1 from music as m where m.col = :u_col /* INT16 */", "Table": "music" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 09d155b19f6..4a69fd85fad 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -1132,7 +1132,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.col = :user_col", + "Query": "select 1 from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 9c2ed1920ee..b5d0fa8951f 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -5165,7 +5165,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.col = :u_col", + "Query": "select 1 from music as m where m.col = :u_col /* INT16 */", "Table": "music" } ] @@ -5335,7 +5335,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m, user_extra as ue where 1 != 1", - "Query": "select 1 from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col and ue.foo = 20 and m.user_id = ue.user_id", + "Query": "select 1 from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col /* INT16 */ and ue.foo = 20 and m.user_id = ue.user_id", "Table": "music, user_extra" } ] @@ -5408,7 +5408,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m, user_extra as ue where 1 != 1", - "Query": "select m.id from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col and ue.foo = 20 and m.user_id = ue.user_id", + "Query": "select m.id from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col /* INT16 */ and ue.foo = 20 and m.user_id = ue.user_id", "Table": "music, user_extra" } ] @@ -5880,7 +5880,7 @@ "Sharded": true }, "TargetTabletType": "PRIMARY", - "Query": "update `user` as u set u.col = :ue_col where u.id in ::dml_vals", + "Query": "update `user` as u set u.col = :ue_col /* INT16 */ where u.id in ::dml_vals", "Table": "user", "Values": [ "::dml_vals" @@ -6483,7 +6483,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m where 1 != 1", - "Query": "select m.id from music as m where m.baz = 21 and m.bar = :u_foo and m.col = :u_col for update", + "Query": "select m.id from music as m where m.baz = 21 and m.bar = :u_foo and m.col = :u_col /* INT16 */ for update", "Table": "music" } ] @@ -6738,7 +6738,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m where 1 != 1", - "Query": "select m.id from music as m where m.col = :u_col for update", + "Query": "select m.id from music as m where m.col = :u_col /* INT16 */ for update", "Table": "music" } ] @@ -6961,7 +6961,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m where 1 != 1", - "Query": "select m.id from music as m where m.col = :u_col for update", + "Query": "select m.id from music as m where m.col = :u_col /* INT16 */ for update", "Table": "music" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index d36c060ed6d..b60e8812dda 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -1141,7 +1141,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -1213,7 +1213,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.user_id = :user_col and user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.user_id = :user_col /* INT16 */ and user_extra.col = :user_col /* INT16 */", "Table": "user_extra", "Values": [ ":user_col" @@ -1262,7 +1262,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col and 1 = 1", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */ and 1 = 1", "Table": "user_extra" } ] @@ -1614,7 +1614,7 @@ "Sharded": true }, "FieldQuery": "select u.m from `user` as u where 1 != 1", - "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col)", + "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col /* INT16 */)", "Table": "`user`", "Values": [ "(:user_extra_col, 1)" @@ -1758,7 +1758,7 @@ "Sharded": true }, "FieldQuery": "select u.m from `user` as u where 1 != 1", - "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col and `user`.id in (select m3 from user_extra where user_extra.user_id = `user`.id))", + "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col /* INT16 */ and `user`.id in (select m3 from user_extra where user_extra.user_id = `user`.id))", "Table": "`user`", "Values": [ "(:user_extra_col, 1)" @@ -3100,7 +3100,7 @@ "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where `user`.id = :user_extra_col", + "Query": "select id from `user` where `user`.id = :user_extra_col /* INT16 */", "Table": "`user`", "Values": [ ":user_extra_col" @@ -3171,7 +3171,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.foobar = 5 and user_extra.col = :user_col", + "Query": "select 1 from user_extra where user_extra.foobar = 5 and user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3226,7 +3226,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -4228,7 +4228,7 @@ "Sharded": true }, "FieldQuery": "select count(*) from `user` as b where 1 != 1 group by .0", - "Query": "select count(*) from `user` as b where b.textcol2 = :a_textcol1 group by .0", + "Query": "select count(*) from `user` as b where b.textcol2 = :a_textcol1 /* VARCHAR */ group by .0", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 81381f3d7d7..6db17511a2a 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -585,7 +585,7 @@ "Sharded": false }, "FieldQuery": "select m1.col from unsharded as m1 where 1 != 1", - "Query": "select m1.col from unsharded as m1 where m1.col = :user_col", + "Query": "select m1.col from unsharded as m1 where m1.col = :user_col /* INT16 */", "Table": "unsharded" } ] @@ -651,7 +651,7 @@ "Sharded": true }, "FieldQuery": "select e.col from user_extra as e where 1 != 1", - "Query": "select e.col from user_extra as e where e.col = :user_col", + "Query": "select e.col from user_extra as e where e.col = :user_col /* INT16 */", "Table": "user_extra" }, { @@ -662,7 +662,7 @@ "Sharded": false }, "FieldQuery": "select 1 from unsharded as m1 where 1 != 1", - "Query": "select 1 from unsharded as m1 where m1.col = :e_col", + "Query": "select 1 from unsharded as m1 where m1.col = :e_col /* INT16 */", "Table": "unsharded" } ] @@ -1221,7 +1221,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user` where `user`.id = :user_extra_col", + "Query": "select `user`.col from `user` where `user`.id = :user_extra_col /* INT16 */", "Table": "`user`", "Values": [ ":user_extra_col" @@ -1924,7 +1924,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.col = :user_col", + "Query": "select 1 from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3244,7 +3244,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as uu where 1 != 1", - "Query": "select 1 from `user` as uu where uu.intcol = :u_intcol", + "Query": "select 1 from `user` as uu where uu.intcol = :u_intcol /* INT16 */", "Table": "`user`" } ] @@ -3357,7 +3357,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3594,7 +3594,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3654,7 +3654,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3720,7 +3720,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3799,7 +3799,7 @@ "Sharded": false }, "FieldQuery": "select unsharded_authoritative.col2 from unsharded_authoritative where 1 != 1", - "Query": "select unsharded_authoritative.col2 from unsharded_authoritative where unsharded_authoritative.col1 = :authoritative_col1", + "Query": "select unsharded_authoritative.col2 from unsharded_authoritative where unsharded_authoritative.col1 = :authoritative_col1 /* VARCHAR */", "Table": "unsharded_authoritative" } ] @@ -3921,7 +3921,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -4430,7 +4430,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.user_id = 5 and m.id = 20 and m.col = :u_col", + "Query": "select 1 from music as m where m.user_id = 5 and m.id = 20 and m.col = :u_col /* INT16 */", "Table": "music", "Values": [ "20" diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json index 12ddfa6e049..230982b2870 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json @@ -319,7 +319,7 @@ "Sharded": false }, "FieldQuery": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where 1 != 1", - "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name", + "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR */", "SysTableTableSchema": "[:v2]", "Table": "information_schema.referential_constraints" } @@ -723,7 +723,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME", + "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR */", "Table": "`user`", "Values": [ ":x_COLUMN_NAME" diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json index 3eec3685fd2..63f3bf2373f 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json @@ -319,7 +319,7 @@ "Sharded": false }, "FieldQuery": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where 1 != 1", - "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name", + "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR */", "SysTableTableSchema": "[:v2]", "Table": "information_schema.referential_constraints" } @@ -445,7 +445,7 @@ "Sharded": false }, "FieldQuery": "select 1 from information_schema.table_constraints as tc where 1 != 1", - "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name and tc.constraint_schema = :__vtschemaname /* VARCHAR */", + "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name /* VARCHAR */ and tc.constraint_schema = :__vtschemaname /* VARCHAR */", "SysTableTableName": "[tc_table_name:'table_name']", "SysTableTableSchema": "['table_schema', :cc_constraint_schema]", "Table": "information_schema.table_constraints" @@ -788,7 +788,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME", + "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR */", "Table": "`user`", "Values": [ ":x_COLUMN_NAME" diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 454740f0498..010e22c2108 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -949,7 +949,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.col = :u_col", + "Query": "select e.id from user_extra as e where e.col = :u_col /* INT16 */", "Table": "user_extra" } ] @@ -2060,8 +2060,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra where 1 != 1", - "Query": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra", + "FieldQuery": "select coalesce(:user_col /* INT16 */, user_extra.col), weight_string(coalesce(:user_col /* INT16 */, user_extra.col)) from user_extra where 1 != 1", + "Query": "select coalesce(:user_col /* INT16 */, user_extra.col), weight_string(coalesce(:user_col /* INT16 */, user_extra.col)) from user_extra", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/rails_cases.json b/go/vt/vtgate/planbuilder/testdata/rails_cases.json index c8ab8b7b9d8..3887547e628 100644 --- a/go/vt/vtgate/planbuilder/testdata/rails_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/rails_cases.json @@ -62,7 +62,7 @@ "Sharded": true }, "FieldQuery": "select 1 from book6s_order2s where 1 != 1", - "Query": "select 1 from book6s_order2s where book6s_order2s.order2_id = :order2s_id and book6s_order2s.book6_id = :book6s_id", + "Query": "select 1 from book6s_order2s where book6s_order2s.order2_id = :order2s_id /* INT64 */ and book6s_order2s.book6_id = :book6s_id /* INT64 */", "Table": "book6s_order2s", "Values": [ ":book6s_id" @@ -79,7 +79,7 @@ "Sharded": true }, "FieldQuery": "select 1 from supplier5s where 1 != 1", - "Query": "select 1 from supplier5s where supplier5s.id = :book6s_supplier5_id", + "Query": "select 1 from supplier5s where supplier5s.id = :book6s_supplier5_id /* INT64 */", "Table": "supplier5s", "Values": [ ":book6s_supplier5_id" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index b312c9dd484..51aae618daf 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2233,7 +2233,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user` where `user`.col = :t_title and `user`.id <= 4", + "Query": "select `user`.col from `user` where `user`.col = :t_title /* VARCHAR */ and `user`.id <= 4", "Table": "`user`" } ] @@ -2967,7 +2967,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra as ue where 1 != 1", - "Query": "select 1 from user_extra as ue where ue.col = :u1_col and ue.col = :u2_col limit 1", + "Query": "select 1 from user_extra as ue where ue.col = :u1_col /* INT16 */ and ue.col = :u2_col /* INT16 */ limit 1", "Table": "user_extra" } ] @@ -3024,7 +3024,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra as ue where 1 != 1", - "Query": "select 1 from user_extra as ue where ue.col = :u_col and ue.col2 = :u_col limit 1", + "Query": "select 1 from user_extra as ue where ue.col = :u_col /* INT16 */ and ue.col2 = :u_col /* INT16 */ limit 1", "Table": "user_extra" } ] @@ -3163,8 +3163,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :user_extra_col + `user`.col as `user_extra.col + ``user``.col` from `user` where 1 != 1", - "Query": "select :user_extra_col + `user`.col as `user_extra.col + ``user``.col` from `user` where `user`.id = :user_extra_id", + "FieldQuery": "select :user_extra_col /* INT16 */ + `user`.col as `user_extra.col + ``user``.col` from `user` where 1 != 1", + "Query": "select :user_extra_col /* INT16 */ + `user`.col as `user_extra.col + ``user``.col` from `user` where `user`.id = :user_extra_id", "Table": "`user`", "Values": [ ":user_extra_id" @@ -3741,7 +3741,7 @@ "Sharded": true }, "FieldQuery": "select user_metadata.user_id from user_extra, user_metadata where 1 != 1", - "Query": "select user_metadata.user_id from user_extra, user_metadata where user_extra.col = :user_col and user_extra.user_id = user_metadata.user_id", + "Query": "select user_metadata.user_id from user_extra, user_metadata where user_extra.col = :user_col /* INT16 */ and user_extra.user_id = user_metadata.user_id", "Table": "user_extra, user_metadata" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json index de5356346b2..3ac35761051 100644 --- a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json @@ -265,7 +265,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id /* VARBINARY */", "Table": "unsharded" } ] @@ -313,7 +313,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id /* VARBINARY */", "Table": "unsharded" } ] @@ -361,7 +361,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id /* VARBINARY */", "Table": "unsharded" } ] @@ -409,7 +409,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :ui_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :ui_id /* VARBINARY */", "Table": "unsharded" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/wireup_cases.json b/go/vt/vtgate/planbuilder/testdata/wireup_cases.json index f34c51f009a..62a3e65a35f 100644 --- a/go/vt/vtgate/planbuilder/testdata/wireup_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/wireup_cases.json @@ -148,7 +148,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.col = :u1_col", + "Query": "select 1 from `user` as u3 where u3.col = :u1_col /* INT16 */", "Table": "`user`" } ] @@ -210,7 +210,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.col = :u2_col", + "Query": "select 1 from `user` as u3 where u3.col = :u2_col /* INT16 */", "Table": "`user`" } ] @@ -265,7 +265,7 @@ "Sharded": true }, "FieldQuery": "select u1.id, u1.col from `user` as u1 where 1 != 1", - "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u3_col", + "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u3_col /* INT16 */", "Table": "`user`" }, { @@ -276,7 +276,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u2 where 1 != 1", - "Query": "select 1 from `user` as u2 where u2.col = :u1_col", + "Query": "select 1 from `user` as u2 where u2.col = :u1_col /* INT16 */", "Table": "`user`" } ] @@ -348,7 +348,7 @@ "Sharded": true }, "FieldQuery": "select u1.id, u1.col from `user` as u1 where 1 != 1", - "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u4_col", + "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u4_col /* INT16 */", "Table": "`user`" }, { @@ -359,7 +359,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.id = :u1_col", + "Query": "select 1 from `user` as u3 where u3.id = :u1_col /* INT16 */", "Table": "`user`", "Values": [ ":u1_col" @@ -420,7 +420,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u2 where 1 != 1", - "Query": "select 1 from `user` as u2 where u2.id = :u1_col", + "Query": "select 1 from `user` as u2 where u2.id = :u1_col /* INT16 */", "Table": "`user`", "Values": [ ":u1_col" @@ -437,7 +437,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.id = :u1_col", + "Query": "select 1 from `user` as u3 where u3.id = :u1_col /* INT16 */", "Table": "`user`", "Values": [ ":u1_col" @@ -591,7 +591,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.id = :u_col limit 10", + "Query": "select e.id from user_extra as e where e.id = :u_col /* INT16 */ limit 10", "Table": "user_extra" } ] @@ -658,7 +658,7 @@ "Sharded": true }, "FieldQuery": "select :u_id + e.id as `u.id + e.id` from user_extra as e where 1 != 1", - "Query": "select :u_id + e.id as `u.id + e.id` from user_extra as e where e.id = :u_col limit 10", + "Query": "select :u_id + e.id as `u.id + e.id` from user_extra as e where e.id = :u_col /* INT16 */ limit 10", "Table": "user_extra" } ] @@ -751,7 +751,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.id = :u_col", + "Query": "select e.id from user_extra as e where e.id = :u_col /* INT16 */", "Table": "user_extra" } ] From 67c5414974ac57eef27d443e5ec62a3beb027bfa Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 19 Jun 2024 14:45:04 +0200 Subject: [PATCH 08/19] Fix more types Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/planbuilder/operators/queryprojection.go | 4 +++- go/vt/vtgate/semantics/binder.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 5729dbd0c2e..086bafa8782 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -24,6 +24,8 @@ import ( "sort" "strings" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" @@ -97,7 +99,7 @@ func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type { if aggr.Func == nil { - return evalengine.Type{} + return evalengine.NewType(sqltypes.Unknown, collations.Unknown) } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 90a36b1f0d7..3520979efb3 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -19,6 +19,8 @@ package semantics import ( "strings" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -169,7 +171,7 @@ func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableNam continue } ts := b.org.tableSetFor(table.GetAliasedTableExpr()) - c := createCertain(ts, ts, evalengine.Type{}) + c := createCertain(ts, ts, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) deps = deps.merge(c, false) } finalDep, err := deps.get(nil) From 64d76c2d75a3db9709976285064109807ac48d44 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 26 Jun 2024 11:16:21 +0200 Subject: [PATCH 09/19] refactor: add NewUnknownType() Signed-off-by: Andres Taylor --- go/vt/vtgate/evalengine/compiler.go | 4 ++++ go/vt/vtgate/planbuilder/operators/queryprojection.go | 4 +--- go/vt/vtgate/semantics/binder.go | 4 +--- go/vt/vtgate/semantics/table_collector.go | 4 +--- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index bcb2281f1a6..b0a7edd285d 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -81,6 +81,10 @@ func (v *EnumSetValues) Equal(other *EnumSetValues) bool { return slices.Equal(*v, *other) } +func NewUnknownType() Type { + return NewType(sqltypes.Unknown, collations.Unknown) +} + func NewType(t sqltypes.Type, collation collations.ID) Type { // New types default to being nullable return NewTypeEx(t, collation, true, 0, 0, nil) diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 086bafa8782..3beec57e413 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -24,8 +24,6 @@ import ( "sort" "strings" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" @@ -99,7 +97,7 @@ func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type { if aggr.Func == nil { - return evalengine.NewType(sqltypes.Unknown, collations.Unknown) + return evalengine.NewUnknownType() } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 3520979efb3..78148f4bb1f 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -19,8 +19,6 @@ package semantics import ( "strings" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -171,7 +169,7 @@ func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableNam continue } ts := b.org.tableSetFor(table.GetAliasedTableExpr()) - c := createCertain(ts, ts, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) + c := createCertain(ts, ts, evalengine.NewUnknownType()) deps = deps.merge(c, false) } finalDep, err := deps.get(nil) diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index ae107cc070c..948edb37d47 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -19,8 +19,6 @@ package semantics import ( "fmt" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" querypb "vitess.io/vitess/go/vt/proto/query" @@ -234,7 +232,7 @@ for2: continue for2 } } - types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) + types = append(types, evalengine.NewUnknownType()) } return colNames, types } From 1118e3f5fb94afc23c317ec2ba9e3b4121e48bba Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 26 Jun 2024 11:59:37 +0200 Subject: [PATCH 10/19] feat: only use CAST when we don't have decimals Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_format.go | 2 +- go/vt/sqlparser/ast_format_fast.go | 2 +- go/vt/sqlparser/normalizer_test.go | 4 ++-- .../testdata/multi-output/selectsharded-output.txt | 8 ++++---- .../vtexplain/testdata/multi-output/unsharded-output.txt | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 35564eccdcd..da88129ee63 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1371,7 +1371,7 @@ func (node *Argument) Format(buf *TrackedBuffer) { // Ensure we handle unknown first as we don't want to treat // the type as a bitmask for the further tests. // do nothing, the default literal will be correct. - case sqltypes.IsDecimal(node.Type): + case sqltypes.IsDecimal(node.Type) && node.Scale == 0: buf.WriteString("CAST(") buf.WriteArg(":", node.Name) buf.astPrintf(node, " AS DECIMAL(%d, %d))", node.Size, node.Scale) diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index e2fff95804a..b1dd010f5ed 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1790,7 +1790,7 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { // Ensure we handle unknown first as we don't want to treat // the type as a bitmask for the further tests. // do nothing, the default literal will be correct. - case sqltypes.IsDecimal(node.Type): + case sqltypes.IsDecimal(node.Type) && node.Scale == 0: buf.WriteString("CAST(") buf.WriteArg(":", node.Name) buf.WriteString(" AS DECIMAL(") diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 3ab1b8b6998..c574b00832d 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -75,7 +75,7 @@ func TestNormalize(t *testing.T) { }, { // float val in: "select * from t where foobar = 1.2", - outstmt: "select * from t where foobar = CAST(:foobar AS DECIMAL(2, 1))", + outstmt: "select * from t where foobar = :foobar /* DECIMAL(2,1) */", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.DecimalBindVariable("1.2"), }, @@ -103,7 +103,7 @@ func TestNormalize(t *testing.T) { }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", - outstmt: "select * from t where foo = CAST(:foo AS DECIMAL(2, 1)) and bar = :bar /* INT64 */", + outstmt: "select * from t where foo = :foo /* DECIMAL(2,1) */ and bar = :bar /* INT64 */", outbv: map[string]*querypb.BindVariable{ "foo": sqltypes.DecimalBindVariable("1.2"), "bar": sqltypes.Int64BindVariable(2), diff --git a/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt index 3109aef341c..6fb355b4a05 100644 --- a/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt @@ -91,10 +91,10 @@ select name, count(*) from user group by name /* scatter aggregate */ ---------------------------------------------------------------------- select 1, "hello", 3.14, null from user limit 10 /* select constant sql values */ -1 ks_sharded/-40: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ -1 ks_sharded/40-80: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ -1 ks_sharded/80-c0: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ -1 ks_sharded/c0-: select 1, 'hello', cast(3.14 as DECIMAL(3, 2)), null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/-40: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/40-80: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/80-c0: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ +1 ks_sharded/c0-: select 1, 'hello', 3.14, null from `user` limit 10 /* INT64 */ /* select constant sql values */ ---------------------------------------------------------------------- select * from (select id from user) s /* scatter paren select */ diff --git a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt index 40a602ae37f..b63683ca274 100644 --- a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt @@ -11,7 +11,7 @@ select * from t1 ---------------------------------------------------------------------- insert into t1 (id,intval,floatval) values (1,2,3.14) -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) ---------------------------------------------------------------------- update t1 set intval = 10 @@ -24,7 +24,7 @@ update t1 set intval = 10 update t1 set floatval = 9.99 1 ks_unsharded/-: begin -1 ks_unsharded/-: update t1 set floatval = cast(9.99 as DECIMAL(3, 2)) limit 10001 +1 ks_unsharded/-: update t1 set floatval = 9.99 limit 10001 /* DECIMAL(3,2) */ 1 ks_unsharded/-: commit ---------------------------------------------------------------------- @@ -37,7 +37,7 @@ delete from t1 where id = 100 ---------------------------------------------------------------------- insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14 -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) on duplicate key update intval = 3, floatval = cast(3.14 as DECIMAL(3, 2)) +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 /* DECIMAL(3,2) */ ---------------------------------------------------------------------- select ID from t1 From 3425c8cc54184293d6c8ea8913088efc4765cebd Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 26 Jun 2024 16:07:27 -0600 Subject: [PATCH 11/19] Naive fix for bind vars types Signed-off-by: Florent Poinsard --- go/vt/vtgate/evalengine/expr_bvar.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 0fffe3140a2..daf64296e98 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -83,9 +83,6 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "query argument '%s' cannot be a tuple", bv.Key) } typ := bvar.Type - if bv.typed() { - typ = bv.Type - } return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)), nil) } } From 812ba9c7adf3d4a97483e0f4d45d926b5278f35a Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 26 Jun 2024 16:14:12 -0600 Subject: [PATCH 12/19] Fix TestNormalizeAllFields Signed-off-by: Florent Poinsard --- .../endtoend/vtgate/queries/normalize/normalize_test.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go index 283c325bc1d..f7d6f45a784 100644 --- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go +++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go @@ -28,7 +28,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" "vitess.io/vitess/go/mysql" @@ -40,12 +39,8 @@ func TestNormalizeAllFields(t *testing.T) { defer conn.Close() insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)` - normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, CAST(:vtg6 AS DECIMAL(3, 2)), CAST(:vtg7 AS DECIMAL(3, 2)), :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` - vtgateVersion, err := cluster.GetMajorVersion("vtgate") - require.NoError(t, err) - if vtgateVersion < 21 { - normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` - } + normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` + selectQuery := "select * from t1" utils.Exec(t, conn, insertQuery) qr := utils.Exec(t, conn, selectQuery) From 0ca45a239381e729eac657a888fded20de752477 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 26 Jun 2024 16:14:47 -0600 Subject: [PATCH 13/19] Rename TestToto to TestProperTypesOfPullOutValue Signed-off-by: Florent Poinsard --- go/test/endtoend/vtgate/queries/subquery/subquery_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index a32ed6fe8cc..fef088d7c1d 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -235,7 +235,7 @@ func TestSubqueries(t *testing.T) { } } -func TestToto(t *testing.T) { +func TestProperTypesOfPullOutValue(t *testing.T) { query := "select (select sum(id) from user) from user_extra" mcmp, closer := start(t) From 30743bf77b6631e032df3bbeda067e28154560f5 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 26 Jun 2024 16:18:14 -0600 Subject: [PATCH 14/19] Fix clean up func of subquery E2E test Signed-off-by: Florent Poinsard --- go/test/endtoend/vtgate/queries/subquery/subquery_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index fef088d7c1d..1a76d328bfc 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -36,7 +36,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { deleteAll := func() { _, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp") - tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx"} + tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx", "user", "user_extra"} for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -242,7 +242,7 @@ func TestProperTypesOfPullOutValue(t *testing.T) { defer closer() mcmp.Exec("INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')") - mcmp.Exec("INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')") + mcmp.Exec("INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')") r := mcmp.Exec(query) require.True(t, r.Fields[0].Type == sqltypes.Decimal) From 2d985edd526d3731b767151d456daaa08e82b0ee Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 26 Jun 2024 16:21:25 -0600 Subject: [PATCH 15/19] Skip TestProperTypesOfPullOutValue if below vtgate v21 Signed-off-by: Florent Poinsard --- go/test/endtoend/vtgate/queries/subquery/subquery_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index 1a76d328bfc..b8fcca34f1c 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -236,6 +236,8 @@ func TestSubqueries(t *testing.T) { } func TestProperTypesOfPullOutValue(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate") + query := "select (select sum(id) from user) from user_extra" mcmp, closer := start(t) From bb0046905265d89d2bb5459540797647c143062a Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 26 Jun 2024 19:04:04 -0600 Subject: [PATCH 16/19] Fix argument size and scale issues Signed-off-by: Florent Poinsard --- .../planbuilder/operators/expressions.go | 6 +++++- .../planbuilder/operators/subquery_planning.go | 18 +++++++++++++----- .../testdata/info_schema57_cases.json | 4 ++-- .../testdata/info_schema80_cases.json | 6 +++--- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index d7c3a620289..17b4bc7c3f1 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -44,7 +44,11 @@ func breakExpressionInLHSandRHS( Name: bvName, Expr: nodeExpr, }) - arg := sqlparser.NewTypedArgument(bvName, ctx.SQLTypeForExpr(nodeExpr)) + typeForExpr, _ := ctx.TypeForExpr(nodeExpr) + arg := sqlparser.NewTypedArgument(bvName, typeForExpr.Type()) + arg.Scale = typeForExpr.Scale() + arg.Size = typeForExpr.Size() + // we are replacing one of the sides of the comparison with an argument, // but we don't want to lose the type information we have, so we copy it over ctx.SemTable.CopyExprInfo(nodeExpr, arg) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index df563105f43..0c43feed6ae 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -22,8 +22,9 @@ import ( "golang.org/x/exp/slices" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/slice" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" @@ -365,7 +366,11 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql // need to find the argument name for it and use that instead // we can't use the column name directly, because we're in the RHS of the join name := outer.findOrAddColNameBindVarName(ctx, col) - cursor.Replace(sqlparser.NewTypedArgument(name, ctx.SQLTypeForExpr(col))) + typ, _ := ctx.TypeForExpr(col) + arg := sqlparser.NewTypedArgument(name, typ.Type()) + arg.Scale = typ.Scale() + arg.Size = typ.Size() + cursor.Replace(arg) }, nil) return result.(sqlparser.Expr) } @@ -400,12 +405,15 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp } return sqlparser.NewArgument(sq1.HasValuesName) default: - argType := sqltypes.Unknown + argType := evalengine.NewUnknownType() ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) if isAe { - argType = ctx.SQLTypeForExpr(ae.Expr) + argType, _ = ctx.TypeForExpr(ae.Expr) } - return sqlparser.NewTypedArgument(s, argType) + arg := sqlparser.NewTypedArgument(s, argType.Type()) + arg.Scale = argType.Scale() + arg.Size = argType.Size() + return arg } } diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json index 230982b2870..31246a2f40f 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json @@ -319,7 +319,7 @@ "Sharded": false }, "FieldQuery": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where 1 != 1", - "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR */", + "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR(64) */", "SysTableTableSchema": "[:v2]", "Table": "information_schema.referential_constraints" } @@ -723,7 +723,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR */", + "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR(64) */", "Table": "`user`", "Values": [ ":x_COLUMN_NAME" diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json index 63f3bf2373f..9553210174c 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json @@ -319,7 +319,7 @@ "Sharded": false }, "FieldQuery": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where 1 != 1", - "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR */", + "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR(64) */", "SysTableTableSchema": "[:v2]", "Table": "information_schema.referential_constraints" } @@ -445,7 +445,7 @@ "Sharded": false }, "FieldQuery": "select 1 from information_schema.table_constraints as tc where 1 != 1", - "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name /* VARCHAR */ and tc.constraint_schema = :__vtschemaname /* VARCHAR */", + "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name /* VARCHAR(64) */ and tc.constraint_schema = :__vtschemaname /* VARCHAR */", "SysTableTableName": "[tc_table_name:'table_name']", "SysTableTableSchema": "['table_schema', :cc_constraint_schema]", "Table": "information_schema.table_constraints" @@ -788,7 +788,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR */", + "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR(64) */", "Table": "`user`", "Values": [ ":x_COLUMN_NAME" From 9aa70a5b225401898f51c150af57363801b977a3 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 27 Jun 2024 10:34:05 +0200 Subject: [PATCH 17/19] undo test changes Signed-off-by: Andres Taylor --- go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt b/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt index a1cf83a5726..b7299002d01 100644 --- a/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt @@ -12,7 +12,7 @@ select * from t1 insert into t1 (id,intval,floatval) values (1,2,3.14) 1 ks_unsharded/-: begin -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) 2 ks_unsharded/-: commit ---------------------------------------------------------------------- @@ -42,7 +42,7 @@ delete from t1 where id = 100 insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14 1 ks_unsharded/-: begin -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, cast(3.14 as DECIMAL(3, 2))) on duplicate key update intval = 3, floatval = cast(3.14 as DECIMAL(3, 2)) +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 2 ks_unsharded/-: commit ----------------------------------------------------------------------- +---------------------------------------------------------------------- \ No newline at end of file From 5f4c9d9832fbb02064f1f63a43e0822490f427bd Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 27 Jun 2024 11:14:25 +0200 Subject: [PATCH 18/19] imports Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/operators/subquery_planning.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 0c43feed6ae..2fee79b6b9f 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -22,12 +22,11 @@ import ( "golang.org/x/exp/slices" - "vitess.io/vitess/go/vt/vtgate/evalengine" - "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" ) From 45e581ad4588613c85f5d457176f36968959119b Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 27 Jun 2024 15:24:38 +0200 Subject: [PATCH 19/19] refactor: remove unused code and add comments Signed-off-by: Andres Taylor --- .../planbuilder/operators/queryprojection.go | 35 +++++++------------ .../operators/subquery_planning.go | 26 +++++++++++--- go/vt/vtgate/semantics/semantic_state.go | 2 +- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 3beec57e413..352a5ffc7a7 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -68,26 +68,22 @@ type ( // Aggr encodes all information needed for aggregation functions Aggr struct { - Original *sqlparser.AliasedExpr - Func sqlparser.AggrFunc // if we are missing a Func, it means this is a AggregateAnyValue - OpCode opcode.AggregateOpcode + Original *sqlparser.AliasedExpr // The original SQL expression for the aggregation + Func sqlparser.AggrFunc // The aggregation function (e.g., COUNT, SUM). If nil, it means AggregateAnyValue is used + OpCode opcode.AggregateOpcode // The opcode representing the type of aggregation being performed - // OriginalOpCode will contain opcode.AggregateUnassigned unless we are changing opcode while pushing them down + // OriginalOpCode will contain opcode.AggregateUnassigned unless we are changing the opcode while pushing them down OriginalOpCode opcode.AggregateOpcode - Alias string + Alias string // The alias name for the aggregation result - // The index at which the user expects to see this aggregated function. Set to nil, if the user does not ask for it - // Only used in the old Horizon Planner - Index *int + Distinct bool // Whether the aggregation function is DISTINCT - Distinct bool - - // the offsets point to columns on the same aggregator - ColOffset int - WSOffset int + // Offsets pointing to columns within the same aggregator + ColOffset int // Offset for the column being aggregated + WSOffset int // Offset for the weight string of the column - SubQueryExpression []*SubQuery + SubQueryExpression []*SubQuery // Subqueries associated with this aggregation } ) @@ -442,14 +438,12 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte // Here we go over the expressions we are returning. Since we know we are aggregating, // all expressions have to be either grouping expressions or aggregate expressions. // If we find an expression that is neither, we treat is as a special aggregation function AggrRandom - for idx, expr := range qp.SelectExprs { + for _, expr := range qp.SelectExprs { aliasedExpr, err := expr.GetAliasedExpr() if err != nil { panic(err) } - idxCopy := idx - if !ContainsAggr(ctx, expr.Col) { getExpr, err := expr.GetExpr() if err != nil { @@ -457,7 +451,6 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte } if !qp.isExprInGroupByExprs(ctx, getExpr) { aggr := NewAggr(opcode.AggregateAnyValue, nil, aliasedExpr, aliasedExpr.ColumnName()) - aggr.Index = &idxCopy out = append(out, aggr) } continue @@ -466,14 +459,13 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte panic(vterrors.VT12001("in scatter query: complex aggregate expression")) } - sqlparser.CopyOnRewrite(aliasedExpr.Expr, qp.extractAggr(ctx, idx, aliasedExpr, addAggr, makeComplex), nil, nil) + sqlparser.CopyOnRewrite(aliasedExpr.Expr, qp.extractAggr(ctx, aliasedExpr, addAggr, makeComplex), nil, nil) } return } func (qp *QueryProjection) extractAggr( ctx *plancontext.PlanningContext, - idx int, aliasedExpr *sqlparser.AliasedExpr, addAggr func(a Aggr), makeComplex func(), @@ -489,7 +481,6 @@ func (qp *QueryProjection) extractAggr( ae = aliasedExpr } aggrFunc := createAggrFromAggrFunc(aggr, ae) - aggrFunc.Index = &idx addAggr(aggrFunc) return false } @@ -497,7 +488,6 @@ func (qp *QueryProjection) extractAggr( // If we are here, we have a function that is an aggregation but not parsed into an AggrFunc. // This is the case for UDFs - we have to be careful with these because we can't evaluate them in VTGate. aggr := NewAggr(opcode.AggregateUDF, nil, aeWrap(ex), "") - aggr.Index = &idx addAggr(aggr) return false } @@ -507,7 +497,6 @@ func (qp *QueryProjection) extractAggr( } if !qp.isExprInGroupByExprs(ctx, ex) { aggr := NewAggr(opcode.AggregateAnyValue, nil, aeWrap(ex), "") - aggr.Index = &idx addAggr(aggr) } return false diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 2fee79b6b9f..5a0aed3f10d 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -26,7 +26,6 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" - "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -374,7 +373,17 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql return result.(sqlparser.Expr) } -func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Expr, se SubQueryExpression, subqueries ...*SubQuery) sqlparser.Expr { +// rewriteColNameToArgument rewrites the column names in the expression to use the argument names instead +// this is used when we push an operator from above the subquery into the outer side of the subquery +func rewriteColNameToArgument( + ctx *plancontext.PlanningContext, + in sqlparser.Expr, // the expression to rewrite + se SubQueryExpression, // the subquery expression we are rewriting + subqueries ...*SubQuery, // the inner subquery operators +) sqlparser.Expr { + // the visitor function that will rewrite the expression tree + // it will be invoked on unqualified column names, and replace them with arguments + // when the column is representing a subquery rewriteIt := func(s string) sqlparser.SQLNode { var sq1, sq2 *SubQuery for _, sq := range se { @@ -404,11 +413,18 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp } return sqlparser.NewArgument(sq1.HasValuesName) default: - argType := evalengine.NewUnknownType() + // for scalar value subqueries, the argument is typed based on the first expression in the subquery + // so here we make an attempt at figuring out the type of the argument ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) - if isAe { - argType, _ = ctx.TypeForExpr(ae.Expr) + if !isAe { + return sqlparser.NewArgument(s) + } + + argType, found := ctx.TypeForExpr(ae.Expr) + if !found { + return sqlparser.NewArgument(s) } + arg := sqlparser.NewTypedArgument(s, argType.Type()) arg.Scale = argType.Scale() arg.Size = argType.Size() diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 1dcaaf87061..0544764b04f 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -671,7 +671,7 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0, nil), true } - return evalengine.Type{}, false + return evalengine.NewUnknownType(), false } // NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons