From 70259f25e5186068bb2d527338c2532e656e846c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Tue, 28 May 2024 14:11:08 +0800 Subject: [PATCH] fix structure to make the cast non-breaking (#350) --- README.md | 5 +- example/Params.py | 2 +- .../SessionPoolExample.py | 0 nebula3/gclient/net/Connection.py | 73 +------------------ nebula3/gclient/net/base.py | 70 ++++++++++++++++++ tests/test_parameter.py | 8 +- 6 files changed, 78 insertions(+), 80 deletions(-) rename SessionPoolExample.py => example/SessionPoolExample.py (100%) diff --git a/README.md b/README.md index a9586aea..59b36a9b 100644 --- a/README.md +++ b/README.md @@ -132,13 +132,12 @@ params = { "ids": ["player100", "player101"], # second query } -# both session_pool and session support `.execute_parameter()` -resp = client.execute_parameter( +resp = client.execute_py_params( "RETURN abs($p1)+3 AS col1, (toBoolean($p2) AND false) AS col2, toLower($p3)+1 AS col3", params, ) -resp = client.execute_parameter( +resp = client.execute_py_params( "MATCH (v) WHERE id(v) in $ids RETURN id(v) AS vertex_id", params, ) diff --git a/example/Params.py b/example/Params.py index bfb47f96..a1f4d382 100644 --- a/example/Params.py +++ b/example/Params.py @@ -51,7 +51,7 @@ "p4": ["Bob", "Lily"], } -resp = client.execute_parameter( +resp = client.execute_py_params( "MATCH (v) WHERE id(v) in $p4 RETURN id(v) AS vertex_id", params_premitive, ) diff --git a/SessionPoolExample.py b/example/SessionPoolExample.py similarity index 100% rename from SessionPoolExample.py rename to example/SessionPoolExample.py diff --git a/nebula3/gclient/net/Connection.py b/nebula3/gclient/net/Connection.py index 5f993c90..a1fda96f 100644 --- a/nebula3/gclient/net/Connection.py +++ b/nebula3/gclient/net/Connection.py @@ -4,12 +4,9 @@ # # This source code is licensed under Apache 2.0 License. -import datetime import time import ssl -from typing import Any - from nebula3.fbthrift.transport import ( TSocket, TSSLSocket, @@ -20,7 +17,7 @@ from nebula3.fbthrift.transport.TTransport import TTransportException from nebula3.fbthrift.protocol import THeaderProtocol, TBinaryProtocol -from nebula3.common.ttypes import ErrorCode, Value, NList, Date, Time, DateTime +from nebula3.common.ttypes import ErrorCode from nebula3.graph import GraphService from nebula3.graph.ttypes import VerifyClientVersionReq from nebula3.logger import logger @@ -193,72 +190,6 @@ def execute(self, session_id, stmt): """ return self.execute_parameter(session_id, stmt, None) - @staticmethod - def _cast_value(value: Any) -> Value: - """ - Cast the value to nebula Value type - ref: https://github.com/vesoft-inc/nebula/blob/master/src/common/datatypes/Value.cpp - :param value: the value to be casted - :return: the casted value - """ - if isinstance(value, Value): - return value - casted_value = Value() - if isinstance(value, bool): - casted_value.set_bVal(value) - elif isinstance(value, int): - casted_value.set_iVal(value) - elif isinstance(value, str): - casted_value.set_sVal(value) - elif isinstance(value, float): - casted_value.set_fVal(value) - elif isinstance(value, datetime.date): - date_value = Date(year=value.year, month=value.month, day=value.day) - casted_value.set_dVal(date_value) - elif isinstance(value, datetime.time): - time_value = Time( - hour=value.hour, - minute=value.minute, - sec=value.second, - microsec=value.microsecond, - ) - casted_value.set_tVal(time_value) - elif isinstance(value, datetime.datetime): - datetime_value = DateTime( - year=value.year, - month=value.month, - day=value.day, - hour=value.hour, - minute=value.minute, - sec=value.second, - microsec=value.microsecond, - ) - casted_value.set_dtVal(datetime_value) - # TODO: add support for GeoSpatial - else: - raise TypeError(f"Unsupported type: {type(value)}") - return casted_value - - @staticmethod - def _build_byte_param(params: dict) -> dict: - byte_params = {} - for k, v in params.items(): - if isinstance(v, Value): - byte_params[k] = v - elif str(type(v)).startswith("nebula3.common.ttypes"): - byte_params[k] = v - elif isinstance(v, list): - byte_list = [] - for item in v: - byte_list.append(Connection._cast_value(item)) - nlist = NList(values=byte_list) - byte_params[k] = nlist - elif isinstance(v, dict): - # TODO: add support for NMap - raise TypeError("Unsupported type: dict") - else: - byte_params[k] = Connection._cast_value(v) - return byte_params def execute_parameter(self, session_id, stmt, params): """execute interface with session_id and ngql @@ -267,8 +198,6 @@ def execute_parameter(self, session_id, stmt, params): :param params: parameter map :return: ExecutionResponse """ - if params is not None: - params = Connection._build_byte_param(params) try: resp = self._connection.executeWithParameter(session_id, stmt, params) return resp diff --git a/nebula3/gclient/net/base.py b/nebula3/gclient/net/base.py index b80320b0..0faa3479 100644 --- a/nebula3/gclient/net/base.py +++ b/nebula3/gclient/net/base.py @@ -1,6 +1,8 @@ +import datetime from abc import abstractmethod from typing import Dict, Any, Optional from nebula3.data.ResultSet import ResultSet +from nebula3.common.ttypes import ErrorCode, Value, NList, Date, Time, DateTime class BaseExecutor: @@ -21,3 +23,71 @@ def execute(self, stmt: str) -> ResultSet: def execute_json(self, stmt: str) -> bytes: return self.execute_json_with_parameter(stmt, None) + + def execute_py_params(self, stmt: str, params: Optional[Dict[str, Any]]) -> ResultSet: + """**Recommended** Execute a statement with parameters in Python type instead of thrift type.""" + return self.execute_parameter(stmt, _build_byte_param(params)) + + +def _build_byte_param(params: dict) -> dict: + byte_params = {} + for k, v in params.items(): + if isinstance(v, Value): + byte_params[k] = v + elif str(type(v)).startswith("nebula3.common.ttypes"): + byte_params[k] = v + elif isinstance(v, list): + byte_list = [] + for item in v: + byte_list.append(_cast_value(item)) + nlist = NList(values=byte_list) + byte_params[k] = nlist + elif isinstance(v, dict): + # TODO: add support for NMap + raise TypeError("Unsupported type: dict") + else: + byte_params[k] = _cast_value(v) + return byte_params + +def _cast_value(value: Any) -> Value: + """ + Cast the value to nebula Value type + ref: https://github.com/vesoft-inc/nebula/blob/master/src/common/datatypes/Value.cpp + :param value: the value to be casted + :return: the casted value + """ + casted_value = Value() + if isinstance(value, bool): + casted_value.set_bVal(value) + elif isinstance(value, int): + casted_value.set_iVal(value) + elif isinstance(value, str): + casted_value.set_sVal(value) + elif isinstance(value, float): + casted_value.set_fVal(value) + elif isinstance(value, datetime.date): + date_value = Date(year=value.year, month=value.month, day=value.day) + casted_value.set_dVal(date_value) + elif isinstance(value, datetime.time): + time_value = Time( + hour=value.hour, + minute=value.minute, + sec=value.second, + microsec=value.microsecond, + ) + casted_value.set_tVal(time_value) + elif isinstance(value, datetime.datetime): + datetime_value = DateTime( + year=value.year, + month=value.month, + day=value.day, + hour=value.hour, + minute=value.minute, + sec=value.second, + microsec=value.microsecond, + ) + casted_value.set_dtVal(datetime_value) + # TODO: add support for GeoSpatial + else: + raise TypeError(f"Unsupported type: {type(value)}") + return casted_value \ No newline at end of file diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 79ffb666..2e6c6517 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -96,7 +96,7 @@ def test_parameter(self): assert 'bob1' == resp.row_values(0)[2].as_string() # same test with premitive params - resp = client.execute_parameter( + resp = client.execute_py_params( 'RETURN abs($p1)+3 AS col1, (toBoolean($p2) and false) AS col2, toLower($p3)+1 AS col3', self.params_premitive, ) @@ -127,7 +127,7 @@ def test_parameter(self): self.params, ) assert not resp.is_succeeded() - resp = client.execute_parameter( + resp = client.execute_py_params( '$p1=go from "Bob" over like yield like._dst;', self.params_premitive, ) @@ -137,7 +137,7 @@ def test_parameter(self): self.params, ) assert not resp.is_succeeded() - resp = client.execute_parameter( + resp = client.execute_py_params( 'go from $p3 over like yield like._dst;', self.params_premitive, ) @@ -163,7 +163,7 @@ def test_parameter(self): ) assert not resp.is_succeeded() - resp = client.execute_parameter( + resp = client.execute_py_params( "MATCH (v) WHERE id(v) in $p4 RETURN id(v) AS vertex_id", self.params_premitive, )