Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Errors propagate through transaction (#247)
Browse files Browse the repository at this point in the history
## What is the goal of this PR?

We align with Client NodeJS version 2.6.1 (some of the work in typedb/typedb-driver-nodejs#197), which implements a better error propagation mechanism: when an exception occurs, we store it against all the transaction's active transmit queues to retrieve whenever the user tries to perform an operation in the transaction anywhere.

## What are the changes implemented in this PR?

* store errors received from gRPC against each receive queue
* return a new exception type, transaction is closed with errors, which throws the errors from all queues (note that this can be duplicate if there are multiple open transmit queues that have been given the same error)
* we clean up queues that are no longer needed, so we minimise the number of times the user sees the same exception
  • Loading branch information
flyingsilverfin authored Jan 20, 2022
1 parent 5acc0dd commit 56c990b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 43 deletions.
2 changes: 1 addition & 1 deletion typedb/connection/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False

def _raise_transaction_closed(self):
errors = self._bidirectional_stream.drain_errors()
errors = self._bidirectional_stream.get_errors()
if len(errors) == 0:
raise TypeDBClientException.of(TRANSACTION_CLOSED)
else:
Expand Down
18 changes: 12 additions & 6 deletions typedb/stream/bidirectional_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction
self._dispatcher.dispatch(req)
return ResponsePartIterator(request_id, self, self._dispatcher)

def done(self, request_id: UUID):
self._response_collector.remove(request_id)

def is_open(self) -> bool:
return self._is_open.get()

Expand All @@ -78,8 +81,9 @@ def fetch(self, request_id: UUID) -> Union[transaction_proto.Transaction.Res, tr
raise TypeDBClientException.of(TRANSACTION_CLOSED)
server_msg = next(self._response_iterator)
except RpcError as e:
self.close(e)
raise TypeDBClientException.of_rpc(e)
error = TypeDBClientException.of_rpc(e)
self.close(error)
raise error
except StopIteration:
self.close()
raise TypeDBClientException.of(TRANSACTION_CLOSED)
Expand All @@ -100,10 +104,10 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio
else:
raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id)

def drain_errors(self) -> List[RpcError]:
return self._response_collector.drain_errors()
def get_errors(self) -> List[TypeDBClientException]:
return self._response_collector.get_errors()

def close(self, error: RpcError = None):
def close(self, error: TypeDBClientException = None):
if self._is_open.compare_and_set(True, False):
self._response_collector.close(error)
try:
Expand All @@ -127,7 +131,9 @@ def __init__(self, request_id: UUID, stream: "BidirectionalStream"):
self._stream = stream

def get(self) -> T:
return self._stream.fetch(self._request_id)
value = self._stream.fetch(self._request_id)
self._stream.done(self._request_id)
return value


class RequestIterator(Iterator[Union[transaction_proto.Transaction.Req, StopIteration]]):
Expand Down
85 changes: 51 additions & 34 deletions typedb/stream/response_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,81 +21,98 @@

import queue
from threading import Lock
from typing import Generic, TypeVar, Dict, Optional, Union
from typing import Generic, TypeVar, Dict, Optional
from uuid import UUID

from grpc import RpcError

from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, ILLEGAL_STATE, \
TRANSACTION_CLOSED_WITH_ERRORS

R = TypeVar('R')


class ResponseCollector(Generic[R]):

def __init__(self):
self._collectors: Dict[UUID, ResponseCollector.Queue[R]] = {}
self._response_queues: Dict[UUID, ResponseCollector.Queue[R]] = {}
self._collectors_lock = Lock()

def new_queue(self, request_id: UUID):
with self._collectors_lock:
collector: ResponseCollector.Queue[R] = ResponseCollector.Queue()
self._collectors[request_id] = collector
self._response_queues[request_id] = collector
return collector

def get(self, request_id: UUID) -> Optional["ResponseCollector.Queue[R]"]:
return self._collectors.get(request_id)
return self._response_queues.get(request_id)

def remove(self, request_id: UUID):
with self._collectors_lock:
del self._response_queues[request_id]

def close(self, error: Optional[RpcError]):
def close(self, error: Optional[TypeDBClientException]):
with self._collectors_lock:
for collector in self._collectors.values():
for collector in self._response_queues.values():
collector.close(error)

def drain_errors(self) -> [RpcError]:
def get_errors(self) -> [TypeDBClientException]:
errors = []
with self._collectors_lock:
for collector in self._collectors.values():
errors.extend(collector.drain_errors())
for collector in self._response_queues.values():
error = collector.get_error()
if error is not None:
errors.append(error)
return errors


class Queue(Generic[R]):

def __init__(self):
self._response_queue: queue.Queue[Union[Response[R], Done]] = queue.Queue()
self._response_queue: queue.Queue[Response] = queue.Queue()
self._error: TypeDBClientException = None

def get(self, block: bool) -> R:
response = self._response_queue.get(block=block)
if response.message:
return response.message
elif response.error:
raise TypeDBClientException.of_rpc(response.error)
else:
if response.is_value():
return response.value
elif response.is_done() and self._error is None:
raise TypeDBClientException.of(TRANSACTION_CLOSED)
elif response.is_done() and self._error is not None:
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, self._error)
else:
raise TypeDBClientException.of(ILLEGAL_STATE)

def put(self, response: R):
self._response_queue.put(Response(response))
self._response_queue.put(ValueResponse(response))

def close(self, error: Optional[RpcError]):
self._response_queue.put(Done(error))
def close(self, error: Optional[TypeDBClientException]):
self._error = error
self._response_queue.put(DoneResponse())

def drain_errors(self) -> [RpcError]:
errors = []
while not self._response_queue.empty():
response = self._response_queue.get(block = False)
if response.error:
errors.append(response.error)
return errors
def get_error(self) -> TypeDBClientException:
return self._error


class Response:

class Response(Generic[R]):
def is_value(self):
return False

def is_done(self):
return False


class ValueResponse(Response, Generic[R]):

def __init__(self, value: R):
self.message = value
self.value = value

def is_value(self):
return True

class Done:

def __init__(self, error: Optional[RpcError]):
self.error = error
class DoneResponse(Response):

def __init__(self):
pass

def is_done(self):
return True
4 changes: 2 additions & 2 deletions typedb/stream/response_part_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
# specific language governing permissions and limitations
# under the License.
#
from enum import Enum
from typing import Iterator, TYPE_CHECKING
from uuid import UUID

import typedb_protocol.common.transaction_pb2 as transaction_proto

from enum import Enum
from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE
from typedb.common.rpc.request_builder import transaction_stream_req
from typedb.stream.request_transmitter import RequestTransmitter
Expand Down Expand Up @@ -78,6 +77,7 @@ def _has_next(self) -> bool:

def __next__(self) -> transaction_proto.Transaction.ResPart:
if not self._has_next():
self._bidirectional_stream.done(self._request_id)
raise StopIteration
self._state = ResponsePartIterator.State.EMPTY
return self._next

0 comments on commit 56c990b

Please sign in to comment.