Skip to content

Commit

Permalink
*: fix case-sensitivity issue with EQ/IN when query information_schem…
Browse files Browse the repository at this point in the history
…a.columns (#31463)

ref #31481
  • Loading branch information
hawkingrei authored Jan 10, 2022
1 parent 4baab3c commit a082abd
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 6 deletions.
6 changes: 6 additions & 0 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ Projection 12500.00 root ifnull(Column#10, 0)->Column#10
explain format = 'brief' select * from information_schema.columns;
id estRows task access object operator info
MemTableScan 10000.00 root table:COLUMNS
explain format = 'brief' select * from information_schema.columns where table_name = 'T1';
id estRows task access object operator info
MemTableScan 10000.00 root table:COLUMNS table_name:["t1"]
explain format = 'brief' select * from information_schema.columns where table_schema = 'TEST' and table_name = 'T1' and column_name = 'c1';
id estRows task access object operator info
MemTableScan 10000.00 root table:COLUMNS table_schema:["test"], table_name:["t1"], column_name:["c1"]
explain format = 'brief' select c2 = (select c2 from t2 where t1.c1 = t2.c1 order by c1 limit 1) from t1;
id estRows task access object operator info
Projection 10000.00 root eq(test.t1.c2, test.t2.c2)->Column#11
Expand Down
2 changes: 2 additions & 0 deletions cmd/explaintest/t/explain_easy.test
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ explain format = 'brief' select sum(t1.c1 in (select c1 from t2)) from t1;
explain format = 'brief' select c1 from t1 where c1 in (select c2 from t2);
explain format = 'brief' select (select count(1) k from t1 s where s.c1 = t1.c1 having k != 0) from t1;
explain format = 'brief' select * from information_schema.columns;
explain format = 'brief' select * from information_schema.columns where table_name = 'T1';
explain format = 'brief' select * from information_schema.columns where table_schema = 'TEST' and table_name = 'T1' and column_name = 'c1';
explain format = 'brief' select c2 = (select c2 from t2 where t1.c1 = t2.c1 order by c1 limit 1) from t1;
explain format = 'brief' select * from t1 order by c1 desc limit 1;
explain format = 'brief' select * from t4 use index(idx) where a > 1 and b > 1 and c > 1 limit 1;
Expand Down
5 changes: 3 additions & 2 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1715,8 +1715,9 @@ func (b *executorBuilder) buildMemTable(v *plannercore.PhysicalMemTable) Executo
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID()),
table: v.Table,
retriever: &hugeMemTableRetriever{
table: v.Table,
columns: v.Columns,
table: v.Table,
columns: v.Columns,
extractor: v.Extractor.(*plannercore.ColumnsTableExtractor),
},
}

Expand Down
28 changes: 24 additions & 4 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc
return nil
}

func (e *hugeMemTableRetriever) setDataForColumns(ctx context.Context, sctx sessionctx.Context) error {
func (e *hugeMemTableRetriever) setDataForColumns(ctx context.Context, sctx sessionctx.Context, extractor *plannercore.ColumnsTableExtractor) error {
checker := privilege.GetPrivilegeManager(sctx)
e.rows = e.rows[:0]
batch := 1024
Expand All @@ -678,7 +678,7 @@ func (e *hugeMemTableRetriever) setDataForColumns(ctx context.Context, sctx sess
}
}

e.dataForColumnsInTable(ctx, sctx, schema, table, priv)
e.dataForColumnsInTable(ctx, sctx, schema, table, priv, extractor)
if len(e.rows) >= batch {
return nil
}
Expand All @@ -688,15 +688,34 @@ func (e *hugeMemTableRetriever) setDataForColumns(ctx context.Context, sctx sess
return nil
}

func (e *hugeMemTableRetriever) dataForColumnsInTable(ctx context.Context, sctx sessionctx.Context, schema *model.DBInfo, tbl *model.TableInfo, priv mysql.PrivilegeType) {
func (e *hugeMemTableRetriever) dataForColumnsInTable(ctx context.Context, sctx sessionctx.Context, schema *model.DBInfo, tbl *model.TableInfo, priv mysql.PrivilegeType, extractor *plannercore.ColumnsTableExtractor) {
if err := tryFillViewColumnType(ctx, sctx, sctx.GetInfoSchema().(infoschema.InfoSchema), schema.Name, tbl); err != nil {
sctx.GetSessionVars().StmtCtx.AppendWarning(err)
return
}
var tableSchemaFilterEnable,
tableNameFilterEnable, columnsFilterEnable bool
if !extractor.SkipRequest {
tableSchemaFilterEnable = extractor.TableSchema.Count() > 0
tableNameFilterEnable = extractor.TableName.Count() > 0
columnsFilterEnable = extractor.ColumnName.Count() > 0
}
for i, col := range tbl.Columns {
if col.Hidden {
continue
}
if !extractor.SkipRequest {
if tableSchemaFilterEnable && !extractor.TableSchema.Exist(schema.Name.L) {
continue
}
if tableNameFilterEnable && !extractor.TableName.Exist(tbl.Name.L) {
continue
}
if columnsFilterEnable && !extractor.ColumnName.Exist(col.Name.L) {
continue
}
}

var charMaxLen, charOctLen, numericPrecision, numericScale, datetimePrecision interface{}
colLen, decimal := col.Flen, col.Decimal
defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(col.Tp)
Expand Down Expand Up @@ -2512,6 +2531,7 @@ func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx.

type hugeMemTableRetriever struct {
dummyCloser
extractor *plannercore.ColumnsTableExtractor
table *model.TableInfo
columns []*model.ColumnInfo
retrieved bool
Expand Down Expand Up @@ -2540,7 +2560,7 @@ func (e *hugeMemTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Co
var err error
switch e.table.Name.O {
case infoschema.TableColumns:
err = e.setDataForColumns(ctx, sctx)
err = e.setDataForColumns(ctx, sctx, e.extractor)
}
if err != nil {
return nil, err
Expand Down
2 changes: 2 additions & 0 deletions infoschema/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ func TestInfoSchemaFieldValue(t *testing.T) {
tk.MustExec("create table t ( s set('a','bc','def','ghij') default NULL, e1 enum('a', 'ab', 'cdef'), s2 SET('1','2','3','4','1585','ONE','TWO','Y','N','THREE'))")
tk.MustQuery("select column_name, character_maximum_length from information_schema.columns where table_schema=Database() and table_name = 't' and column_name = 's'").Check(
testkit.Rows("s 13"))
tk.MustQuery("select column_name, character_maximum_length from information_schema.columns where table_schema=Database() and table_name = 't' and column_name = 'S'").Check(
testkit.Rows("s 13"))
tk.MustQuery("select column_name, character_maximum_length from information_schema.columns where table_schema=Database() and table_name = 't' and column_name = 's2'").Check(
testkit.Rows("s2 30"))
tk.MustQuery("select column_name, character_maximum_length from information_schema.columns where table_schema=Database() and table_name = 't' and column_name = 'e1'").Check(
Expand Down
2 changes: 2 additions & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4341,6 +4341,8 @@ func (b *PlanBuilder) buildMemTable(_ context.Context, dbName model.CIStr, table
p.Extractor = &StatementsSummaryExtractor{}
case infoschema.TableTiKVRegionPeers:
p.Extractor = &TikvRegionPeersExtractor{}
case infoschema.TableColumns:
p.Extractor = &ColumnsTableExtractor{}
}
}
return p, nil
Expand Down
52 changes: 52 additions & 0 deletions planner/core/memtable_predicate_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1468,3 +1468,55 @@ func (e *TikvRegionPeersExtractor) explainInfo(p *PhysicalMemTable) string {
}
return s
}

// ColumnsTableExtractor is used to extract some predicates of columns table.
type ColumnsTableExtractor struct {
extractHelper

// SkipRequest means the where clause always false, we don't need to request any component
SkipRequest bool

TableSchema set.StringSet

TableName set.StringSet
// ColumnName represents all column name we should filter in memtable.
ColumnName set.StringSet
}

// Extract implements the MemTablePredicateExtractor Extract interface
func (e *ColumnsTableExtractor) Extract(_ sessionctx.Context,
schema *expression.Schema,
names []*types.FieldName,
predicates []expression.Expression,
) (remained []expression.Expression) {
remained, tableSchemaSkipRequest, tableSchema := e.extractCol(schema, names, predicates, "table_schema", true)
remained, tableNameSkipRequest, tableName := e.extractCol(schema, names, remained, "table_name", true)
remained, columnNameSkipRequest, columnName := e.extractCol(schema, names, remained, "column_name", true)
e.SkipRequest = columnNameSkipRequest || tableSchemaSkipRequest || tableNameSkipRequest
e.ColumnName = columnName
e.TableName = tableName
e.TableSchema = tableSchema
return remained
}

func (e *ColumnsTableExtractor) explainInfo(p *PhysicalMemTable) string {
if e.SkipRequest {
return "skip_request:true"
}
r := new(bytes.Buffer)
if len(e.TableSchema) > 0 {
r.WriteString(fmt.Sprintf("table_schema:[%s], ", extractStringFromStringSet(e.TableSchema)))
}
if len(e.TableName) > 0 {
r.WriteString(fmt.Sprintf("table_name:[%s], ", extractStringFromStringSet(e.TableName)))
}
if len(e.ColumnName) > 0 {
r.WriteString(fmt.Sprintf("column_name:[%s], ", extractStringFromStringSet(e.ColumnName)))
}
// remove the last ", " in the message info
s := r.String()
if len(s) > 2 {
return s[:len(s)-2]
}
return s
}
64 changes: 64 additions & 0 deletions planner/core/memtable_predicate_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1552,3 +1552,67 @@ func TestTikvRegionPeersExtractor(t *testing.T) {
}
}
}

func TestColumns(t *testing.T) {
store, dom, clean := testkit.CreateMockStoreAndDomain(t)
defer clean()

se, err := session.CreateSession4Test(store)
require.NoError(t, err)

var cases = []struct {
sql string
columnName set.StringSet
tableSchema set.StringSet
tableName set.StringSet
skipRequest bool
}{
{
sql: `select * from INFORMATION_SCHEMA.COLUMNS where column_name='T';`,
columnName: set.NewStringSet("t"),
},
{
sql: `select * from INFORMATION_SCHEMA.COLUMNS where table_schema='TEST';`,
tableSchema: set.NewStringSet("test"),
},
{
sql: `select * from INFORMATION_SCHEMA.COLUMNS where table_name='TEST';`,
tableName: set.NewStringSet("test"),
},
{
sql: "select * from information_schema.COLUMNS where table_name in ('TEST','t') and column_name in ('A','b')",
columnName: set.NewStringSet("a", "b"),
tableName: set.NewStringSet("test", "t"),
},
{
sql: `select * from information_schema.COLUMNS where table_name='a' and table_name in ('a', 'B');`,
tableName: set.NewStringSet("a"),
},
{
sql: `select * from information_schema.COLUMNS where table_name='a' and table_name='B';`,
skipRequest: true,
},
}
parser := parser.New()
for _, ca := range cases {
logicalMemTable := getLogicalMemTable(t, dom, se, parser, ca.sql)
require.NotNil(t, logicalMemTable.Extractor)

columnsTableExtractor := logicalMemTable.Extractor.(*plannercore.ColumnsTableExtractor)
require.Equal(t, ca.skipRequest, columnsTableExtractor.SkipRequest, "SQL: %v", ca.sql)

require.Equal(t, ca.columnName.Count(), columnsTableExtractor.ColumnName.Count())
if ca.columnName.Count() > 0 && columnsTableExtractor.ColumnName.Count() > 0 {
require.EqualValues(t, ca.columnName, columnsTableExtractor.ColumnName, "SQL: %v", ca.sql)
}

require.Equal(t, ca.tableSchema.Count(), columnsTableExtractor.TableSchema.Count())
if ca.tableSchema.Count() > 0 && columnsTableExtractor.TableSchema.Count() > 0 {
require.EqualValues(t, ca.tableSchema, columnsTableExtractor.TableSchema, "SQL: %v", ca.sql)
}
require.Equal(t, ca.tableName.Count(), columnsTableExtractor.TableName.Count())
if ca.tableName.Count() > 0 && columnsTableExtractor.TableName.Count() > 0 {
require.EqualValues(t, ca.tableName, columnsTableExtractor.TableName, "SQL: %v", ca.sql)
}
}
}
2 changes: 2 additions & 0 deletions privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2923,7 +2923,9 @@ func TestIssue28675(t *testing.T) {
tk.MustExec("grant select on test.v to test_user")
require.True(t, tk.Session().Auth(&auth.UserIdentity{Username: "test_user", Hostname: "localhost"}, nil, nil))
tk.MustQuery("select count(*) from information_schema.columns where table_schema='test' and table_name='v'").Check(testkit.Rows("1"))
tk.MustQuery("select count(*) from information_schema.columns where table_schema='Test' and table_name='V'").Check(testkit.Rows("1"))
tk.MustQuery("select privileges from information_schema.columns where table_schema='test' and table_name='v'").Check(testkit.Rows("select,update"))
tk.MustQuery("select privileges from information_schema.columns where table_schema='Test' and table_name='V'").Check(testkit.Rows("select,update"))
require.Equal(t, 1, len(tk.MustQuery("desc test.v").Rows()))
require.Equal(t, 1, len(tk.MustQuery("explain test.v").Rows()))
}
Expand Down

0 comments on commit a082abd

Please sign in to comment.