diff --git a/presto-sqlserver/src/test/java/io/prestosql/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java b/presto-sqlserver/src/test/java/io/prestosql/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java index b4fa68cc37bba..1d2ca168f1bad 100644 --- a/presto-sqlserver/src/test/java/io/prestosql/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java +++ b/presto-sqlserver/src/test/java/io/prestosql/plugin/sqlserver/TestSqlServerIntegrationSmokeTest.java @@ -13,6 +13,8 @@ */ package io.prestosql.plugin.sqlserver; +import io.prestosql.sql.planner.plan.AggregationNode; +import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.testing.AbstractTestIntegrationSmokeTest; import io.prestosql.testing.QueryRunner; import io.prestosql.testing.sql.TestTable; @@ -97,14 +99,40 @@ public void testAggregationPushdown() assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown(); - try (AutoCloseable ignoreTable = withTable("test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) { - sqlServer.execute("INSERT INTO test_aggregation_pushdown VALUES (100.000, 100000000.000000000)"); - sqlServer.execute("INSERT INTO test_aggregation_pushdown VALUES (123.321, 123456789.987654321)"); + try (AutoCloseable ignoreTable = withTable("test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10), varchar_column varchar(10))")) { + sqlServer.execute("INSERT INTO test_aggregation_pushdown VALUES (100.000, 100000000.000000000, 'ala')"); + sqlServer.execute("INSERT INTO test_aggregation_pushdown VALUES (123.321, 123456789.987654321, 'kot')"); assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); assertThat(query("SELECT max(short_decimal), max(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); assertThat(query("SELECT sum(short_decimal), sum(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); assertThat(query("SELECT avg(short_decimal), avg(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); + + // WHERE on aggregation column + assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM test_aggregation_pushdown WHERE short_decimal < 110 AND long_decimal < 124")).isCorrectlyPushedDown(); + // WHERE on non-aggregation column + assertThat(query("SELECT min(long_decimal) FROM test_aggregation_pushdown WHERE short_decimal < 110")).isCorrectlyPushedDown(); + // GROUP BY + assertThat(query("SELECT short_decimal, min(long_decimal) FROM test_aggregation_pushdown GROUP BY short_decimal")).isCorrectlyPushedDown(); + // GROUP BY with WHERE on both grouping and aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM test_aggregation_pushdown WHERE short_decimal < 110 AND long_decimal < 124" + " GROUP BY short_decimal")).isCorrectlyPushedDown(); + // GROUP BY with WHERE on grouping column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM test_aggregation_pushdown WHERE short_decimal < 110 GROUP BY short_decimal")).isCorrectlyPushedDown(); + // GROUP BY with WHERE on aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM test_aggregation_pushdown WHERE long_decimal < 124 GROUP BY short_decimal")).isCorrectlyPushedDown(); + // GROUP BY with WHERE on neither grouping nor aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM test_aggregation_pushdown WHERE varchar_column = 'ala' GROUP BY short_decimal")).isCorrectlyPushedDown(); + // aggregation on varchar column + assertThat(query("SELECT min(varchar_column) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); + // aggregation on varchar column with GROUPING + assertThat(query("SELECT short_decimal, min(varchar_column) FROM test_aggregation_pushdown GROUP BY short_decimal")).isCorrectlyPushedDown(); + // aggregation on varchar column with WHERE + assertThat(query("SELECT min(varchar_column) FROM test_aggregation_pushdown WHERE varchar_column ='ala'")).isCorrectlyPushedDown(); + + // not supported yet + assertThat(query("SELECT min(DISTINCT short_decimal) FROM test_aggregation_pushdown")).isNotFullyPushedDown(AggregationNode.class); + // TODO: Improve assertion framework. Here min(long_decimal) is pushed down. There remains ProjectNode above it which relates to DISTINCT in the query. + assertThat(query("SELECT DISTINCT short_decimal, min(long_decimal) FROM test_aggregation_pushdown GROUP BY short_decimal")).isNotFullyPushedDown(ProjectNode.class); } }