diff --git a/pkg/expression/expr_to_pb_test.go b/pkg/expression/expr_to_pb_test.go index 6b32b453f8fb4..e083d35c3abaa 100644 --- a/pkg/expression/expr_to_pb_test.go +++ b/pkg/expression/expr_to_pb_test.go @@ -1291,6 +1291,44 @@ func TestExprPushDownToFlash(t *testing.T) { require.NoError(t, err) exprs = append(exprs, function) + // CastIntAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), intColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // CastRealAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), realColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // CastDecimalAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), decimalColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // CastStringAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), stringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), binaryStringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // CastTimeAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), datetimeColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // CastDurationAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), durationColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // CastJsonAsJson + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeJSON), jsonColumn) + require.NoError(t, err) + exprs = append(exprs, function) + pushed, remained = PushDownExprs(ctx, exprs, client, kv.TiFlash) require.Len(t, pushed, len(exprs)) require.Len(t, remained, 0) diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index 64ace07c7c123..d6015049022ee 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -1237,6 +1237,9 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool { return function.GetArgs()[0].GetType().GetType() != mysql.TypeYear case tipb.ScalarFuncSig_CastTimeAsDuration: return retType.GetType() == mysql.TypeDuration + case tipb.ScalarFuncSig_CastIntAsJson, tipb.ScalarFuncSig_CastRealAsJson, tipb.ScalarFuncSig_CastDecimalAsJson, tipb.ScalarFuncSig_CastStringAsJson, + tipb.ScalarFuncSig_CastTimeAsJson, tipb.ScalarFuncSig_CastDurationAsJson, tipb.ScalarFuncSig_CastJsonAsJson: + return true } case ast.DateAdd, ast.AddDate: switch function.Function.PbCode() {