Skip to content

Commit

Permalink
Execute should block until at least one row is received
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Sep 16, 2022
1 parent 4e7f6be commit 805b35d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 21 deletions.
13 changes: 6 additions & 7 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def test_execute_many_without_params(trino_connection):
cur = trino_connection.cursor()
cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
cur.fetchall()
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
with pytest.raises(TrinoUserError) as e:
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
cur.fetchall()
assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)

Expand Down Expand Up @@ -883,13 +883,12 @@ def test_transaction_autocommit(trino_connection_in_autocommit):
with trino_connection_in_autocommit as connection:
connection.start_transaction()
cur = connection.cursor()
cur.execute(
"""
CREATE TABLE memory.default.nation
AS SELECT * from tpch.tiny.nation
""")

with pytest.raises(TrinoUserError) as transaction_error:
cur.execute(
"""
CREATE TABLE memory.default.nation
AS SELECT * from tpch.tiny.nation
""")
cur.fetchall()
assert "Catalog only supports writes using autocommit: memory" \
in str(transaction_error.value)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def sample_post_response_data():
"""

yield {
"nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1",
"nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1",
"id": "20210817_140827_00000_arvdv",
"taskDownloadUris": [],
"infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv",
"infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv",
"stats": {
"scheduled": False,
"runningSplits": 0,
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,9 +892,7 @@ def test_trino_result_response_headers():
'X-Trino-Fake-2': 'two',
})

result = TrinoResult(
query=mock_trino_query,
)
result = TrinoResult(query=mock_trino_query, rows=[])
assert result.response_headers == mock_trino_query.response_headers


Expand Down
34 changes: 28 additions & 6 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,28 @@ def test_http_session_is_defaulted_when_not_specified(mock_client):


@httprettified
def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data)

# bind post statement
# bind post statement to submit query
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
Expand Down Expand Up @@ -108,21 +115,29 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):


@httprettified
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data):
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data,
sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data)

# bind post statement
# bind post statement to submit query
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
Expand Down Expand Up @@ -166,21 +181,28 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post


@httprettified
def test_token_retrieved_once_when_multithreaded(sample_post_response_data):
def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data)

# bind post statement
# bind post statement to submit query
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
Expand Down
17 changes: 14 additions & 3 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,20 @@ class TrinoResult(object):
https://docs.python.org/3/library/stdtypes.html#generator-types
"""

def __init__(self, query, rows=None):
def __init__(self, query, rows: List[Any]):
self._query = query
# Initial rows from the first POST request
self._rows = rows
self._rownumber = 0

@property
def rows(self):
return self._rows

@rows.setter
def rows(self, rows):
self._rows = rows

@property
def rownumber(self) -> int:
return self._rownumber
Expand Down Expand Up @@ -650,7 +658,7 @@ def columns(self):
while not self._columns and not self.finished and not self.cancelled:
# Columns are not returned immediately after query is submitted.
# Continue fetching data until columns information is available and push fetched rows into buffer.
self._result._rows += self.fetch()
self._result.rows += self.fetch()
return self._columns

@property
Expand Down Expand Up @@ -695,8 +703,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
self._finished = True

rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows

self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
return self._result

def _update_state(self, status):
Expand Down

0 comments on commit 805b35d

Please sign in to comment.