diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java new file mode 100644 index 0000000000000..a1b6e50e48b2a --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java @@ -0,0 +1,371 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.nativeworker; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.eventlistener.EventListener; +import com.facebook.presto.split.PageSourceManager; +import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; +import com.facebook.presto.sql.planner.NodePartitioningManager; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.MaterializedRow; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.transaction.TransactionManager; +import org.testcontainers.containers.BindMode; +import org.testcontainers.containers.Container; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.wait.strategy.Wait; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; +import java.util.stream.Collectors; + +import static org.testng.Assert.fail; + +public class ContainerQueryRunner + implements QueryRunner +{ + private static final Network network = Network.newNetwork(); + private static final String PRESTO_COORDINATOR_IMAGE = System.getProperty("coordinatorImage", "presto-coordinator:latest"); + private static final String PRESTO_WORKER_IMAGE = System.getProperty("workerImage", "presto-worker:latest"); + private static final String BASE_DIR = System.getProperty("user.dir"); + private final GenericContainer coordinator; + private final GenericContainer worker; + + public ContainerQueryRunner() + throws InterruptedException + { + coordinator = new GenericContainer<>(PRESTO_COORDINATOR_IMAGE) + .withExposedPorts(8081) + .withNetwork(network).withNetworkAliases("presto-coordinator") + .withFileSystemBind(BASE_DIR + "/testcontainers/coordinator/etc", "/opt/presto-server/etc", BindMode.READ_WRITE) + .withFileSystemBind(BASE_DIR + "/testcontainers/coordinator/entrypoint.sh", "/opt/entrypoint.sh", BindMode.READ_ONLY) + .waitingFor(Wait.forLogMessage(".*======== SERVER STARTED ========.*", 1)) + .withStartupTimeout(Duration.ofSeconds(120)); + + worker = new GenericContainer<>(PRESTO_WORKER_IMAGE) + .withExposedPorts(7777) + .withNetwork(network).withNetworkAliases("presto-worker") + .withFileSystemBind(BASE_DIR + "/testcontainers/nativeworker/velox-etc", "/opt/presto-server/etc", BindMode.READ_ONLY) + .withFileSystemBind(BASE_DIR + "/testcontainers/nativeworker/entrypoint.sh", "/opt/entrypoint.sh", BindMode.READ_ONLY) + .waitingFor(Wait.forLogMessage(".*Announcement succeeded: HTTP 202.*", 1)); + + coordinator.start(); + worker.start(); + + TimeUnit.SECONDS.sleep(20); + } + + public static MaterializedResult toMaterializedResult(String csvData) + { + List rows = new ArrayList<>(); + List columnTypes = new ArrayList<>(); + + // Split the CSV data into lines + String[] lines = csvData.split("\n"); + + // Parse all rows and collect them + List allRows = parseCsvLines(lines); + + // Infer column types based on the maximum columns found + int maxColumns = allRows.stream().mapToInt(row -> row.length).max().orElse(0); + for (int i = 0; i < maxColumns; i++) { + final int columnIndex = i; // Make index effectively final + columnTypes.add(inferType(allRows.stream().map(row -> columnIndex < row.length ? row[columnIndex] : "").collect(Collectors.toList()))); + } + + // Convert all rows to MaterializedRow + for (String[] columns : allRows) { + List values = new ArrayList<>(); + for (int i = 0; i < columnTypes.size(); i++) { + values.add(i < columns.length ? convertToType(columns[i], columnTypes.get(i)) : null); + } + rows.add(new MaterializedRow(1, values)); + } + + // Create and return the MaterializedResult + return new MaterializedResult(rows, columnTypes); + } + + private static List parseCsvLines(String[] lines) + { + List allRows = new ArrayList<>(); + List currentRow = new ArrayList<>(); + StringBuilder currentField = new StringBuilder(); + boolean insideQuotes = false; + + for (String line : lines) { + for (int i = 0; i < line.length(); i++) { + char ch = line.charAt(i); + if (ch == '"') { + // Handle double quotes inside quoted string + if (insideQuotes && i + 1 < line.length() && line.charAt(i + 1) == '"') { + currentField.append(ch); + i++; + } + else { + insideQuotes = !insideQuotes; + } + } + else if (ch == ',' && !insideQuotes) { + currentRow.add(currentField.toString()); + currentField.setLength(0); // Clear the current field + } + else { + currentField.append(ch); + } + } + if (insideQuotes) { + currentField.append('\n'); // Add newline for multiline fields + } + else { + currentRow.add(currentField.toString()); + currentField.setLength(0); // Clear the current field + allRows.add(currentRow.toArray(new String[0])); + currentRow.clear(); + } + } + if (!currentRow.isEmpty()) { + currentRow.add(currentField.toString()); + allRows.add(currentRow.toArray(new String[0])); + } + return allRows; + } + + private static Type inferType(List values) + { + boolean isBigint = true; + boolean isDouble = true; + boolean isBoolean = true; + + for (String value : values) { + if (!value.matches("^-?\\d+$")) { + isBigint = false; + } + if (!value.matches("^-?\\d+\\.\\d+$")) { + isDouble = false; + } + if (!value.equalsIgnoreCase("true") && !value.equalsIgnoreCase("false")) { + isBoolean = false; + } + } + + if (isBigint) { + return BigintType.BIGINT; + } + else if (isDouble) { + return DoubleType.DOUBLE; + } + else if (isBoolean) { + return BooleanType.BOOLEAN; + } + else { + return VarcharType.VARCHAR; + } + } + + private static Object convertToType(String value, Type type) + { + if (type.equals(VarcharType.VARCHAR)) { + return value; + } + else if (type.equals(BigintType.BIGINT)) { + return Long.parseLong(value); + } + else if (type.equals(DoubleType.DOUBLE)) { + return Double.parseDouble(value); + } + else if (type.equals(BooleanType.BOOLEAN)) { + return Boolean.parseBoolean(value); + } + else { + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + public Container.ExecResult executeQuery(String sql) + { + String[] command = { + "/opt/presto-cli", + "--server", + "presto-coordinator:8081", + "--execute", + sql + }; + + Container.ExecResult execResult = null; + try { + execResult = coordinator.execInContainer(command); + } + catch (IOException e) { + throw new RuntimeException(e); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + + if (execResult.getExitCode() != 0) { + String errorDetails = "Stdout: " + execResult.getStdout() + "\nStderr: " + execResult.getStderr(); + fail("Presto CLI exited with error code: " + execResult.getExitCode() + "\n" + errorDetails); + } + return execResult; + } + + @Override + public void close() + { + coordinator.stop(); + worker.stop(); + } + + @Override + public TransactionManager getTransactionManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public Metadata getMetadata() + { + throw new UnsupportedOperationException(); + } + + @Override + public SplitManager getSplitManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public PageSourceManager getPageSourceManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public NodePartitioningManager getNodePartitioningManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectorPlanOptimizerManager getPlanOptimizerManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public StatsCalculator getStatsCalculator() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getEventListener() + { + throw new UnsupportedOperationException(); + } + + @Override + public TestingAccessControlManager getAccessControl() + { + throw new UnsupportedOperationException(); + } + + @Override + public MaterializedResult execute(String sql) + { + throw new UnsupportedOperationException(); + } + + @Override + public MaterializedResult execute(Session session, String sql, List resultTypes) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listTables(Session session, String catalog, String schema) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean tableExists(Session session, String table) + { + throw new UnsupportedOperationException(); + } + + @Override + public void installPlugin(Plugin plugin) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createCatalog(String catalogName, String connectorName, Map properties) + { + throw new UnsupportedOperationException(); + } + + @Override + public void loadFunctionNamespaceManager(String functionNamespaceManagerName, String catalogName, Map properties) + { + throw new UnsupportedOperationException(); + } + + @Override + public Lock getExclusiveLock() + { + throw new UnsupportedOperationException(); + } + + @Override + public int getNodeCount() + { + throw new UnsupportedOperationException(); + } + + @Override + public Session getDefaultSession() + { + return null; + } + + @Override + public MaterializedResult execute(Session session, String sql) + { + Container.ExecResult execResult = executeQuery(sql); + return toMaterializedResult(execResult.getStdout()); + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWithContainers.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWithContainers.java index 2a631141766c8..17af0be7309db 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWithContainers.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWithContainers.java @@ -13,102 +13,65 @@ */ package com.facebook.presto.nativeworker; -import org.testcontainers.containers.BindMode; -import org.testcontainers.containers.Container; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.Network; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; +import com.facebook.presto.tests.AbstractTestQueryFramework; import org.testng.annotations.Test; -import java.io.IOException; -import java.time.Duration; - -import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; - public class TestPrestoNativeWithContainers + extends AbstractTestQueryFramework { - private static final String PRESTO_COORDINATOR_IMAGE = System.getProperty("coordinatorImage", "presto-coordinator:latest"); - private static final String PRESTO_WORKER_IMAGE = System.getProperty("workerImage", "presto-worker:latest"); - private static final String BASE_DIR = System.getProperty("user.dir"); - private static final Network network = Network.newNetwork(); - private GenericContainer coordinator; - private GenericContainer worker; - - @BeforeClass - public void setUp() - throws InterruptedException - { - coordinator = new GenericContainer<>(PRESTO_COORDINATOR_IMAGE) - .withExposedPorts(8081) - .withNetwork(network).withNetworkAliases("presto-coordinator") - .withFileSystemBind(BASE_DIR + "/testcontainers/coordinator/etc", "/opt/presto-server/etc", BindMode.READ_WRITE) - .withFileSystemBind(BASE_DIR + "/testcontainers/coordinator/entrypoint.sh", "/opt/entrypoint.sh", BindMode.READ_ONLY) - .waitingFor(Wait.forLogMessage(".*======== SERVER STARTED ========.*", 1)) - .withStartupTimeout(Duration.ofSeconds(120)); - - worker = new GenericContainer<>(PRESTO_WORKER_IMAGE) - .withExposedPorts(7777) - .withNetwork(network).withNetworkAliases("presto-worker") - .withFileSystemBind(BASE_DIR + "/testcontainers/nativeworker/velox-etc", "/opt/presto-server/etc", BindMode.READ_ONLY) - .withFileSystemBind(BASE_DIR + "/testcontainers/nativeworker/entrypoint.sh", "/opt/entrypoint.sh", BindMode.READ_ONLY) - .waitingFor(Wait.forLogMessage(".*Announcement succeeded: HTTP 202.*", 1)); - - coordinator.start(); - worker.start(); - } - - @AfterClass - public void tearDown() + @Override + protected ContainerQueryRunner createQueryRunner() + throws Exception { - coordinator.stop(); - worker.stop(); - } - - private Container.ExecResult executeQuery(String sql) - throws IOException, InterruptedException - { - // Command to run inside the coordinator container using the presto-cli. - String[] command = { - "/opt/presto-cli", - "--server", - "presto-coordinator:8081", - "--execute", - sql - }; - - Container.ExecResult execResult = coordinator.execInContainer(command); - if (execResult.getExitCode() != 0) { - String errorDetails = "Stdout: " + execResult.getStdout() + "\nStderr: " + execResult.getStderr(); - fail("Presto CLI exited with error code: " + execResult.getExitCode() + "\n" + errorDetails); - } - return execResult; + return new ContainerQueryRunner(); } @Test public void testBasics() - throws IOException, InterruptedException { - String selectRuntimeNodes = "select * from system.runtime.nodes"; - executeQuery(selectRuntimeNodes); - String showCatalogs = "show catalogs"; - executeQuery(showCatalogs); - String showSession = "show session"; - executeQuery(showSession); + computeActual("select * from system.runtime.nodes"); + computeActual("show catalogs"); } @Test public void testFunctions() - throws IOException, InterruptedException { - String countValues = "SELECT COUNT(*) FROM (VALUES 1, 0, 0, 2, 3, 3) as t(x)"; - Container.ExecResult countResult = executeQuery(countValues); - assertTrue(countResult.getStdout().contains("6"), "Count is incorrect."); + assertQuery("SELECT COUNT(*) FROM (VALUES 1, 0, 0, 2, 3, 3) as t(x)", "SELECT 6"); + assertQuery("SELECT array_sort(ARRAY [5, 20, null, 5, 3, 50])", "SELECT ARRAY[3, 5, 5, 20, 50, null]"); + } - String sqlArrayIntegers = "SELECT array_sort(ARRAY [5, 20, null, 5, 3, 50])"; - Container.ExecResult execResultIntegers = executeQuery(sqlArrayIntegers); - assertTrue(execResultIntegers.getStdout().contains("[3, 5, 5, 20, 50, null]"), "Integer array not sorted correctly."); + @Test + public void testUnnest() + { + assertQuery("SELECT 1 FROM (VALUES (ARRAY[1])) AS t (a) CROSS JOIN UNNEST(a)", "SELECT 1"); + assertQuery("SELECT x[1] FROM UNNEST(ARRAY[ARRAY[1, 2, 3]]) t(x)", "SELECT 1"); + assertQuery("SELECT x[1][2] FROM UNNEST(ARRAY[ARRAY[ARRAY[1, 2, 3]]]) t(x)", "SELECT 2"); + assertQuery("SELECT x[2] FROM UNNEST(ARRAY[MAP(ARRAY[1,2], ARRAY['hello', 'hi'])]) t(x)", "SELECT 'hi'"); + assertQuery("SELECT * FROM UNNEST(ARRAY[1, 2, 3])", "SELECT * FROM VALUES (1), (2), (3)"); + assertQuery("SELECT a FROM UNNEST(ARRAY[1, 2, 3]) t(a)", "SELECT * FROM VALUES (1), (2), (3)"); + assertQuery("SELECT a, b FROM UNNEST(ARRAY[1, 2], ARRAY[3, 4]) t(a, b)", "SELECT * FROM VALUES (1, 3), (2, 4)"); + assertQuery("SELECT a FROM UNNEST(ARRAY[1, 2, 3], ARRAY[4, 5]) t(a, b)", "SELECT * FROM VALUES 1, 2, 3"); + assertQuery("SELECT count(*) FROM UNNEST(ARRAY[1, 2, 3], ARRAY[4, 5])", "SELECT 3"); + assertQuery("SELECT a FROM UNNEST(ARRAY['kittens', 'puppies']) t(a)", "SELECT * FROM VALUES ('kittens'), ('puppies')"); + assertQuery( + "WITH unioned AS ( SELECT 1 UNION ALL SELECT 2 ) SELECT * FROM unioned CROSS JOIN UNNEST(ARRAY[3]) steps (step)", + "SELECT * FROM (VALUES (1, 3), (2, 3))"); + assertQuery("SELECT c " + + "FROM UNNEST(ARRAY[1, 2, 3], ARRAY[4, 5]) t(a, b) " + + "CROSS JOIN (values (8), (9)) t2(c)", + "SELECT * FROM VALUES 8, 8, 8, 9, 9, 9"); + assertQuery("SELECT * FROM UNNEST(ARRAY[0, 1]) CROSS JOIN UNNEST(ARRAY[0, 1]) CROSS JOIN UNNEST(ARRAY[0, 1])", + "SELECT * FROM VALUES (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)"); + assertQuery("SELECT * FROM UNNEST(ARRAY[0, 1]), UNNEST(ARRAY[0, 1]), UNNEST(ARRAY[0, 1])", + "SELECT * FROM VALUES (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)"); + assertQuery("SELECT a, b FROM UNNEST(MAP(ARRAY[1,2], ARRAY['cat', 'dog'])) t(a, b)", "SELECT * FROM VALUES (1, 'cat'), (2, 'dog')"); + assertQuery("SELECT 1 FROM (VALUES (ARRAY[1])) AS t (a) CROSS JOIN UNNEST(a) WITH ORDINALITY", "SELECT 1"); + assertQuery("SELECT * FROM UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY", "SELECT * FROM VALUES (1, 1), (2, 2), (3, 3)"); + assertQuery("SELECT b FROM UNNEST(ARRAY[10, 20, 30]) WITH ORDINALITY t(a, b)", "SELECT * FROM VALUES (1), (2), (3)"); + assertQuery("SELECT a, b FROM UNNEST(ARRAY['kittens', 'puppies']) WITH ORDINALITY t(a, b)", "SELECT * FROM VALUES ('kittens', 1), ('puppies', 2)"); + assertQuery("SELECT c " + + "FROM UNNEST(ARRAY[1, 2, 3], ARRAY[4, 5]) WITH ORDINALITY t(a, b, c) " + + "CROSS JOIN (values (8), (9)) t2(d)", + "SELECT * FROM VALUES 1, 1, 2, 2, 3, 3"); } }