Skip to content

Commit

Permalink
Fix: Casting and formatting in CLI
Browse files Browse the repository at this point in the history
Closes #81
  • Loading branch information
ashvardanian committed Feb 20, 2024
1 parent c331893 commit fd923d1
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 76 deletions.
6 changes: 3 additions & 3 deletions src/ucall/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def pack(self, res):
if isinstance(res, Image.Image):
buf = BytesIO()
if not res.format:
res.format = 'tiff'
res.save(buf, res.format, compression='raw', compression_level=0)
res.format = "tiff"
res.save(buf, res.format, compression="raw", compression_level=0)

return buf.getvalue()

Expand All @@ -54,7 +54,7 @@ def wrapper(*args, **kwargs):
new_kwargs = {}

for arg, hint in zip(args, hints.values()):
assert isinstance(hint, type), 'Hint must be a type!'
assert isinstance(hint, type), "Hint must be a type!"
if isinstance(arg, bytes):
new_args.append(self.unpack(arg, hint))
else:
Expand Down
85 changes: 45 additions & 40 deletions src/ucall/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def get_kwargs(buffer):
if buffer is not None:
for arg in buffer:
sp = None
if '=' in arg:
sp = arg.split('=')
if "=" in arg:
sp = arg.split("=")
else:
raise KeyError('Missing key in kwarg argument')
raise KeyError("Missing key in kwarg argument")
kwargs[sp[0]] = sp[1]
return kwargs

Expand All @@ -26,33 +26,35 @@ def cast(value: str, type_name: Optional[str]):
if type_name is None:
if value.isdigit():
return int(value)
if value.replace('.', '', 1).isdigit():
if value.replace(".", "", 1).isdigit():
return float(value)
if value in ['True', 'False']:
return bool(value)
if value == "True":
return True
if value == "False":
return False
return value

type_name = type_name.lower()
if type_name == 'image':
if type_name == "image":
return Image.open(value)
if type_name == 'binary':
return open(value, 'rb').read()
if type_name == "binary":
return open(value, "rb").read()

return locate(type_name)(value)


def fix_types(args, kwargs):
"""Casts `args` and `kwargs` to expected types."""
for i in range(len(args)):
if ':' in args[i]:
val, tp = args[i].split(':')
if ":" in args[i]:
val, tp = args[i].split(":")
args[i] = cast(val, tp)
else:
args[i] = cast(args[i], None)
keys = list(kwargs.keys())
for k in keys:
if ':' in k:
key, tp = k.split(':')
if ":" in k:
key, tp = k.split(":")
val = kwargs.pop(k)
kwargs[key] = cast(val, tp)
else:
Expand All @@ -63,54 +65,57 @@ def add_specials(kwargs: dict, special: Optional[list[str]], type_name: str):
if special is None:
return
for x in special:
if not '=' in x:
raise KeyError(f'Missing key in {type_name} argument')
k, v = x.split('=')
kwargs[k + ':' + type_name] = v
if not "=" in x:
raise KeyError(f"Missing key in {type_name} argument")
k, v = x.split("=")
kwargs[k + ":" + type_name] = v


def cli():
parsed = get_parser().parse_args()
kwargs = get_kwargs(parsed.kwargs)
args = parsed.positional if parsed.positional else []

add_specials(kwargs, parsed.file, 'binary')
add_specials(kwargs, parsed.image, 'image')
add_specials(kwargs, parsed.file, "binary")
add_specials(kwargs, parsed.image, "image")

fix_types(args, kwargs)
client = Client(uri=parsed.uri, port=parsed.port, use_http=True)
res = getattr(client, parsed.method)(*args, **kwargs)

if parsed.format == 'raw':
if parsed.format == "raw":
print(json.dumps(res.data, indent=4))
else:
try:
print(getattr(res, parsed.format))
except Exception as err:
print('Error:', err)
print("Error:", err)


def get_parser():
parser = argparse.ArgumentParser(description='UCall Client CLI')
parser.add_argument('method', type=str, help='Method name')

parser.add_argument('--uri', type=str, default='localhost',
help='Server URI')
parser.add_argument('-p', '--port', type=int, default=8545,
help='Server port')

parser.add_argument('kwargs', nargs='*', help='KEY[:TYPE]=VALUE arguments')
parser.add_argument('-f', '--file', nargs='*', help='Binary files')
parser.add_argument('-i', '--image', nargs='*', help='Image files')

parser.add_argument('--positional', nargs='*',
help='Switch to positional arguments VALUE[:TYPE]')

parser.add_argument('--format', type=str,
choices=['json', 'bytes', 'numpy', 'image', 'raw'], default='raw',
help='How to parse and format the response')
parser = argparse.ArgumentParser(description="UCall Client CLI")
parser.add_argument("method", type=str, help="Method name")

parser.add_argument("--uri", type=str, default="localhost", help="Server URI")
parser.add_argument("-p", "--port", type=int, default=8545, help="Server port")

parser.add_argument("kwargs", nargs="*", help="KEY[:TYPE]=VALUE arguments")
parser.add_argument("-f", "--file", nargs="*", help="Binary files")
parser.add_argument("-i", "--image", nargs="*", help="Image files")

parser.add_argument(
"--positional", nargs="*", help="Switch to positional arguments VALUE[:TYPE]"
)

parser.add_argument(
"--format",
type=str,
choices=["json", "bytes", "numpy", "image", "raw"],
default="raw",
help="How to parse and format the response",
)
return parser


if __name__ == '__main__':
if __name__ == "__main__":
cli()
75 changes: 42 additions & 33 deletions src/ucall/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ def _receive_all(sock, buffer_size=4096):
body = None
content_len = -1

if header == b'HTTP':
while b'\r\n\r\n' not in header:
if header == b"HTTP":
while b"\r\n\r\n" not in header:
chunk = sock.recv(1024)
if not chunk:
break
header += chunk

header, body = header.split(b'\r\n\r\n', 1)
header, body = header.split(b"\r\n\r\n", 1)

pref = b'Content-Length:'
pref = b"Content-Length:"
for line in header.splitlines():
if line.startswith(pref):
content_len = int(line[len(pref):].strip())
content_len = int(line[len(pref) :].strip())
break
else:
body = header
Expand All @@ -52,11 +52,11 @@ def __init__(self, data):
@property
def json(self) -> Union[bool, int, float, str, dict, list, tuple]:
self.raise_for_status()
return self.data['result']
return self.data["result"]

def raise_for_status(self):
if 'error' in self.data:
raise RuntimeError(self.data['error'])
if "error" in self.data:
raise RuntimeError(self.data["error"])

@property
def bytes(self) -> bytes:
Expand Down Expand Up @@ -91,54 +91,57 @@ def _pack_bytes(self, buffer):
def _pack_pillow(self, image):
buf = BytesIO()
if not image.format:
image.format = 'tiff'
image.save(buf, image.format, compression='raw', compression_level=0)
image.format = "tiff"
image.save(buf, image.format, compression="raw", compression_level=0)
buf.seek(0)
return base64.b64encode(buf.getvalue()).decode()

def pack(self, req):
keys = None
if isinstance(req['params'], dict):
keys = req['params'].keys()
if isinstance(req["params"], dict):
keys = req["params"].keys()
else:
keys = range(0, len(req['params']))
keys = range(0, len(req["params"]))

for k in keys:
if isinstance(req['params'][k], np.ndarray):
req['params'][k] = self._pack_numpy(req['params'][k])
if isinstance(req["params"][k], np.ndarray):
req["params"][k] = self._pack_numpy(req["params"][k])

elif isinstance(req['params'][k], Image.Image):
req['params'][k] = self._pack_pillow(req['params'][k])
elif isinstance(req["params"][k], Image.Image):
req["params"][k] = self._pack_pillow(req["params"][k])

elif isinstance(req['params'][k], bytes):
req['params'][k] = self._pack_bytes(req['params'][k])
elif isinstance(req["params"][k], bytes):
req["params"][k] = self._pack_bytes(req["params"][k])

return req


class Client:
"""JSON-RPC Client that uses classic sync Python `requests` to pass JSON calls over HTTP"""

def __init__(self, uri: str = '127.0.0.1', port: int = 8545, use_http: bool = True) -> None:
def __init__(
self, uri: str = "127.0.0.1", port: int = 8545, use_http: bool = True
) -> None:
self.uri = uri
self.port = port
self.use_http = use_http
self.sock = None
self.http_template = f'POST / HTTP/1.1\r\nHost: {uri}:{port}\r\nUser-Agent: py-ucall\r\nAccept: */*\r\nConnection: keep-alive\r\nContent-Length: %i\r\nContent-Type: application/json\r\n\r\n'
self.http_template = f"POST / HTTP/1.1\r\nHost: {uri}:{port}\r\nUser-Agent: py-ucall\r\nAccept: */*\r\nConnection: keep-alive\r\nContent-Length: %i\r\nContent-Type: application/json\r\n\r\n"

def __getattr__(self, name):
def call(*args, **kwargs):
params = kwargs
if len(args) != 0:
assert len(
kwargs) == 0, 'Can\'t mix positional and keyword parameters!'
assert len(kwargs) == 0, "Can't mix positional and keyword parameters!"
params = args

return self.__call__({
'method': name,
'params': params,
'jsonrpc': '2.0',
})
return self.__call__(
{
"method": name,
"params": params,
"jsonrpc": "2.0",
}
)

return call

Expand All @@ -156,15 +159,15 @@ def _socket_is_closed(self) -> bool:
return True
try:
buf = self.sock.recv(1, socket.MSG_PEEK | socket.MSG_DONTWAIT)
if buf == b'':
if buf == b"":
return True
except BlockingIOError as exc:
if exc.errno != errno.EAGAIN:
raise
return False

def _send(self, json_data: dict):
json_data['id'] = random.randint(1, 2**16)
json_data["id"] = random.randint(1, 2**16)
req_obj = Request(json_data)
request = json.dumps(req_obj.packed)
if self.use_http:
Expand All @@ -185,8 +188,13 @@ def __call__(self, jsonrpc: object) -> Response:

class ClientTLS(Client):
def __init__(
self, uri: str = '127.0.0.1', port: int = 8545, ssl_context: ssl.SSLContext = None,
allow_self_signed: bool = False, enable_session_resumption: bool = True) -> None:
self,
uri: str = "127.0.0.1",
port: int = 8545,
ssl_context: ssl.SSLContext = None,
allow_self_signed: bool = False,
enable_session_resumption: bool = True,
) -> None:

super().__init__(uri, port, use_http=True)

Expand All @@ -205,7 +213,8 @@ def _make_socket(self):
return
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock = self.ssl_context.wrap_socket(
self.sock, server_hostname=self.uri, session=self.session)
self.sock, server_hostname=self.uri, session=self.session
)
self.sock.connect((self.uri, self.port))
if self.session_resumption:
self.session = self.sock.session
Expand Down

0 comments on commit fd923d1

Please sign in to comment.