diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 1dc8f05a..d616b0f0 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -177,6 +177,19 @@ def test_conjunctions(trino_connection): assert len(rows) == 1 +@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +def test_finished_state(trino_connection): + _, conn = trino_connection + metadata = sqla.MetaData() + queries = sqla.Table('queries', metadata, schema='runtime', autoload_with=conn) + s = sqla.select(queries.c.state).where(queries.c.query == "SELECT version()") + result = conn.execute(s) + rows = result.fetchall() + assert len(rows) > 0 + for row in rows: + assert row['state'] == 'FINISHED' + + @pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) def test_textual_sql(trino_connection): _, conn = trino_connection diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index c7056cc4..90cabe8f 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -317,7 +317,7 @@ def _get_server_version_info(self, connection: Connection) -> Any: query = "SELECT version()" try: res = connection.execute(sql.text(query)) - version = res.scalar() + version = res.scalar_one() return tuple([version]) except exc.ProgrammingError as e: logger.debug(f"Failed to get server version: {e.orig.message}")