Skip to content

Commit

Permalink
fix structure to make the cast non-breaking (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
BeautyyuYanli authored May 28, 2024
1 parent c777db7 commit 70259f2
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 80 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,12 @@ params = {
"ids": ["player100", "player101"], # second query
}

# both session_pool and session support `.execute_parameter()`
resp = client.execute_parameter(
resp = client.execute_py_params(
"RETURN abs($p1)+3 AS col1, (toBoolean($p2) AND false) AS col2, toLower($p3)+1 AS col3",
params,
)

resp = client.execute_parameter(
resp = client.execute_py_params(
"MATCH (v) WHERE id(v) in $ids RETURN id(v) AS vertex_id",
params,
)
Expand Down
2 changes: 1 addition & 1 deletion example/Params.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"p4": ["Bob", "Lily"],
}

resp = client.execute_parameter(
resp = client.execute_py_params(
"MATCH (v) WHERE id(v) in $p4 RETURN id(v) AS vertex_id",
params_premitive,
)
File renamed without changes.
73 changes: 1 addition & 72 deletions nebula3/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
#
# This source code is licensed under Apache 2.0 License.

import datetime
import time
import ssl

from typing import Any

from nebula3.fbthrift.transport import (
TSocket,
TSSLSocket,
Expand All @@ -20,7 +17,7 @@
from nebula3.fbthrift.transport.TTransport import TTransportException
from nebula3.fbthrift.protocol import THeaderProtocol, TBinaryProtocol

from nebula3.common.ttypes import ErrorCode, Value, NList, Date, Time, DateTime
from nebula3.common.ttypes import ErrorCode
from nebula3.graph import GraphService
from nebula3.graph.ttypes import VerifyClientVersionReq
from nebula3.logger import logger
Expand Down Expand Up @@ -193,72 +190,6 @@ def execute(self, session_id, stmt):
"""
return self.execute_parameter(session_id, stmt, None)

@staticmethod
def _cast_value(value: Any) -> Value:
"""
Cast the value to nebula Value type
ref: https://github.com/vesoft-inc/nebula/blob/master/src/common/datatypes/Value.cpp
:param value: the value to be casted
:return: the casted value
"""
if isinstance(value, Value):
return value
casted_value = Value()
if isinstance(value, bool):
casted_value.set_bVal(value)
elif isinstance(value, int):
casted_value.set_iVal(value)
elif isinstance(value, str):
casted_value.set_sVal(value)
elif isinstance(value, float):
casted_value.set_fVal(value)
elif isinstance(value, datetime.date):
date_value = Date(year=value.year, month=value.month, day=value.day)
casted_value.set_dVal(date_value)
elif isinstance(value, datetime.time):
time_value = Time(
hour=value.hour,
minute=value.minute,
sec=value.second,
microsec=value.microsecond,
)
casted_value.set_tVal(time_value)
elif isinstance(value, datetime.datetime):
datetime_value = DateTime(
year=value.year,
month=value.month,
day=value.day,
hour=value.hour,
minute=value.minute,
sec=value.second,
microsec=value.microsecond,
)
casted_value.set_dtVal(datetime_value)
# TODO: add support for GeoSpatial
else:
raise TypeError(f"Unsupported type: {type(value)}")
return casted_value

@staticmethod
def _build_byte_param(params: dict) -> dict:
byte_params = {}
for k, v in params.items():
if isinstance(v, Value):
byte_params[k] = v
elif str(type(v)).startswith("nebula3.common.ttypes"):
byte_params[k] = v
elif isinstance(v, list):
byte_list = []
for item in v:
byte_list.append(Connection._cast_value(item))
nlist = NList(values=byte_list)
byte_params[k] = nlist
elif isinstance(v, dict):
# TODO: add support for NMap
raise TypeError("Unsupported type: dict")
else:
byte_params[k] = Connection._cast_value(v)
return byte_params

def execute_parameter(self, session_id, stmt, params):
"""execute interface with session_id and ngql
Expand All @@ -267,8 +198,6 @@ def execute_parameter(self, session_id, stmt, params):
:param params: parameter map
:return: ExecutionResponse
"""
if params is not None:
params = Connection._build_byte_param(params)
try:
resp = self._connection.executeWithParameter(session_id, stmt, params)
return resp
Expand Down
70 changes: 70 additions & 0 deletions nebula3/gclient/net/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
from abc import abstractmethod
from typing import Dict, Any, Optional
from nebula3.data.ResultSet import ResultSet
from nebula3.common.ttypes import ErrorCode, Value, NList, Date, Time, DateTime


class BaseExecutor:
Expand All @@ -21,3 +23,71 @@ def execute(self, stmt: str) -> ResultSet:

def execute_json(self, stmt: str) -> bytes:
return self.execute_json_with_parameter(stmt, None)

def execute_py_params(self, stmt: str, params: Optional[Dict[str, Any]]) -> ResultSet:
"""**Recommended** Execute a statement with parameters in Python type instead of thrift type."""
return self.execute_parameter(stmt, _build_byte_param(params))


def _build_byte_param(params: dict) -> dict:
byte_params = {}
for k, v in params.items():
if isinstance(v, Value):
byte_params[k] = v
elif str(type(v)).startswith("nebula3.common.ttypes"):
byte_params[k] = v
elif isinstance(v, list):
byte_list = []
for item in v:
byte_list.append(_cast_value(item))
nlist = NList(values=byte_list)
byte_params[k] = nlist
elif isinstance(v, dict):
# TODO: add support for NMap
raise TypeError("Unsupported type: dict")
else:
byte_params[k] = _cast_value(v)
return byte_params

def _cast_value(value: Any) -> Value:
"""
Cast the value to nebula Value type
ref: https://github.com/vesoft-inc/nebula/blob/master/src/common/datatypes/Value.cpp
:param value: the value to be casted
:return: the casted value
"""
casted_value = Value()
if isinstance(value, bool):
casted_value.set_bVal(value)
elif isinstance(value, int):
casted_value.set_iVal(value)
elif isinstance(value, str):
casted_value.set_sVal(value)
elif isinstance(value, float):
casted_value.set_fVal(value)
elif isinstance(value, datetime.date):
date_value = Date(year=value.year, month=value.month, day=value.day)
casted_value.set_dVal(date_value)
elif isinstance(value, datetime.time):
time_value = Time(
hour=value.hour,
minute=value.minute,
sec=value.second,
microsec=value.microsecond,
)
casted_value.set_tVal(time_value)
elif isinstance(value, datetime.datetime):
datetime_value = DateTime(
year=value.year,
month=value.month,
day=value.day,
hour=value.hour,
minute=value.minute,
sec=value.second,
microsec=value.microsecond,
)
casted_value.set_dtVal(datetime_value)
# TODO: add support for GeoSpatial
else:
raise TypeError(f"Unsupported type: {type(value)}")
return casted_value
8 changes: 4 additions & 4 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_parameter(self):
assert 'bob1' == resp.row_values(0)[2].as_string()

# same test with premitive params
resp = client.execute_parameter(
resp = client.execute_py_params(
'RETURN abs($p1)+3 AS col1, (toBoolean($p2) and false) AS col2, toLower($p3)+1 AS col3',
self.params_premitive,
)
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_parameter(self):
self.params,
)
assert not resp.is_succeeded()
resp = client.execute_parameter(
resp = client.execute_py_params(
'$p1=go from "Bob" over like yield like._dst;',
self.params_premitive,
)
Expand All @@ -137,7 +137,7 @@ def test_parameter(self):
self.params,
)
assert not resp.is_succeeded()
resp = client.execute_parameter(
resp = client.execute_py_params(
'go from $p3 over like yield like._dst;',
self.params_premitive,
)
Expand All @@ -163,7 +163,7 @@ def test_parameter(self):
)
assert not resp.is_succeeded()

resp = client.execute_parameter(
resp = client.execute_py_params(
"MATCH (v) WHERE id(v) in $p4 RETURN id(v) AS vertex_id",
self.params_premitive,
)
Expand Down

0 comments on commit 70259f2

Please sign in to comment.