From e1f0b8e2b2984c424b79f100e9d7f09f8ff060b0 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 00:04:01 +0800 Subject: [PATCH 1/6] fix: cannot yaml dump DictProxy --- src/handyllm/_io.py | 25 +++++++++++++++++++++++++ src/handyllm/hprompt.py | 15 +++------------ src/handyllm/prompt_converter.py | 12 ++++++------ 3 files changed, 34 insertions(+), 18 deletions(-) create mode 100644 src/handyllm/_io.py diff --git a/src/handyllm/_io.py b/src/handyllm/_io.py new file mode 100644 index 0000000..71d18ac --- /dev/null +++ b/src/handyllm/_io.py @@ -0,0 +1,25 @@ +from pathlib import Path +import yaml +from .response import DictProxy + + +# add multi representer for Path, for YAML serialization +class MySafeDumper(yaml.SafeDumper): + pass + + +MySafeDumper.add_multi_representer( + Path, lambda dumper, data: dumper.represent_str(str(data)) +) +MySafeDumper.add_multi_representer( + DictProxy, lambda dumper, data: dumper.represent_dict(data) +) + + +def yaml_dump(*args, **kwargs): + return yaml.dump(*args, Dumper=MySafeDumper, allow_unicode=True, **kwargs) + + +def yaml_load(stream): + return yaml.safe_load(stream) + diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index ccda4e9..dbd7f35 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -41,7 +41,6 @@ from abc import abstractmethod, ABC from contextlib import asynccontextmanager, contextmanager -import yaml import frontmatter from mergedeep import merge as merge_dict, Strategy from dotenv import load_dotenv @@ -60,6 +59,7 @@ from .run_config import RunConfig, RecordRequestMode, CredentialType, VarMapFileFormat from .types import PathType, SyncHandlerChat, SyncHandlerCompletions, VarMapType from .response import ChatChunk, ChatResponse, CompletionsChunk, CompletionsResponse +from ._io import MySafeDumper, yaml_load PromptType = TypeVar("PromptType", bound="HandyPrompt") @@ -71,15 +71,6 @@ handler = frontmatter.YAMLHandler() -# add multi representer for Path, for YAML serialization -class MySafeDumper(yaml.SafeDumper): - pass - - -MySafeDumper.add_multi_representer( - Path, lambda dumper, data: dumper.represent_str(str(data)) -) - p_var_map = re.compile(r"(%\w+%)") DEFAULT_CONFIG = RunConfig() @@ -464,7 +455,7 @@ def _prepare_run( if evaled_run_config.credential_type == CredentialType.JSON: credential_dict = json.load(fin) else: - credential_dict = yaml.safe_load(fin) + credential_dict = yaml_load(fin) # do not overwrite the existing request arguments for key, value in credential_dict.items(): if key not in evaled_request: @@ -1103,7 +1094,7 @@ def load_var_map( """ with open(path, "r", encoding="utf-8") as fin: if format in (VarMapFileFormat.JSON, VarMapFileFormat.YAML): - return yaml.safe_load(fin) + return yaml_load(fin) content = fin.read() substitute_map = {} blocks = p_var_map.split(content) diff --git a/src/handyllm/prompt_converter.py b/src/handyllm/prompt_converter.py index c60cdb7..526c118 100644 --- a/src/handyllm/prompt_converter.py +++ b/src/handyllm/prompt_converter.py @@ -2,9 +2,9 @@ import re from typing import IO, Generator, MutableMapping, MutableSequence, Optional -import yaml from .types import PathType, ShortChatChunk +from ._io import yaml_dump, yaml_load class PromptConverter: @@ -67,11 +67,11 @@ def raw2msgs(self, raw_prompt: str): if "type" in extra_properties: type_of_msg = extra_properties.pop("type") if type_of_msg == "tool_calls": - msg["tool_calls"] = yaml.safe_load(content) + msg["tool_calls"] = yaml_load(content) msg["content"] = None elif type_of_msg == "content_array": # parse content array - msg["content"] = yaml.safe_load(content) + msg["content"] = yaml_load(content) for key in extra_properties: msg[key] = extra_properties[key] msgs.append(msg) @@ -99,10 +99,10 @@ def msgs2raw(msgs): } if tool_calls: extra_properties["type"] = "tool_calls" - content = yaml.dump(tool_calls, allow_unicode=True) + content = yaml_dump(tool_calls) elif isinstance(content, MutableSequence): extra_properties["type"] = "content_array" - content = yaml.dump(content, allow_unicode=True) + content = yaml_dump(content) if extra_properties: extra = ( " {" @@ -136,7 +136,7 @@ def consume_stream2fd( fd.write(' {type="tool_calls"}\n') role_completed = True # dump tool calls - fd.write(yaml.dump([tool_call], allow_unicode=True)) + fd.write(yaml_dump([tool_call])) elif text: if not role_completed: fd.write("\n") From 0a691a8c8460cb258b13181007868c0274a3da3b Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 00:18:15 +0800 Subject: [PATCH 2/6] feat(response.DictProxy): raise AttributeError when key not in DictProxy --- src/handyllm/response.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/handyllm/response.py b/src/handyllm/response.py index 1d3e861..8c2f346 100644 --- a/src/handyllm/response.py +++ b/src/handyllm/response.py @@ -22,7 +22,9 @@ def __init__(self, *args, **kwargs): self[key] = self._wrap(value) def __getattr__(self, attr): - return self[attr] + if attr in self: + return self[attr] + raise AttributeError(f"Attribute {attr} not found") def __setattr__(self, attr, value): self[attr] = self._wrap(value) From 8e5cd51df984e4b417d8cb73f1403a452d471f02 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 00:25:14 +0800 Subject: [PATCH 3/6] feat(_io): allow *args and **kwargs --- src/handyllm/_io.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/handyllm/_io.py b/src/handyllm/_io.py index 71d18ac..e00b8a0 100644 --- a/src/handyllm/_io.py +++ b/src/handyllm/_io.py @@ -1,5 +1,7 @@ from pathlib import Path import yaml +from functools import wraps + from .response import DictProxy @@ -16,10 +18,13 @@ class MySafeDumper(yaml.SafeDumper): ) +@wraps(yaml.dump) def yaml_dump(*args, **kwargs): - return yaml.dump(*args, Dumper=MySafeDumper, allow_unicode=True, **kwargs) - + kwargs.setdefault("Dumper", MySafeDumper) + kwargs.setdefault("allow_unicode", True) + return yaml.dump(*args, **kwargs) -def yaml_load(stream): - return yaml.safe_load(stream) +@wraps(yaml.safe_load) +def yaml_load(*args, **kwargs): + return yaml.safe_load(*args, **kwargs) From d5501e155eafb09cb8d3ad02658066c0b3468f7c Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 00:27:04 +0800 Subject: [PATCH 4/6] refactor: use _io.yaml_dump and _io.yaml_load --- src/handyllm/cache_manager.py | 6 +++--- src/handyllm/endpoint_manager.py | 4 ++-- src/handyllm/openai_client.py | 5 ++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/handyllm/cache_manager.py b/src/handyllm/cache_manager.py index bec53bf..74ee0e1 100644 --- a/src/handyllm/cache_manager.py +++ b/src/handyllm/cache_manager.py @@ -7,16 +7,16 @@ from pathlib import Path from typing import Callable, Collection, Iterable, List, Optional, TypeVar, Union, cast from typing_extensions import ParamSpec -import yaml from .types import PathType, StrHandler, StringifyHandler +from ._io import yaml_load, yaml_dump def _suffix_loader(file: Path): with open(file, "r", encoding="utf-8") as f: # determine the format according to the file suffix if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"): - content = yaml.safe_load(f) + content = yaml_load(f) elif file.suffix.endswith(".json"): content = json.load(f) else: @@ -28,7 +28,7 @@ def _suffix_dumper(file: Path, content): with open(file, "w", encoding="utf-8") as f: # determine the format according to the file suffix if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"): - yaml.dump(content, f, default_flow_style=False, allow_unicode=True) + yaml_dump(content, f, default_flow_style=False) elif file.suffix.endswith(".json"): json.dump(content, f, ensure_ascii=False, indent=2) else: diff --git a/src/handyllm/endpoint_manager.py b/src/handyllm/endpoint_manager.py index 034a3fb..b9e0cc4 100644 --- a/src/handyllm/endpoint_manager.py +++ b/src/handyllm/endpoint_manager.py @@ -6,10 +6,10 @@ from threading import Lock from collections.abc import MutableSequence from typing import Iterable, Mapping, Optional, Union -import yaml from .types import PathType from ._utils import isiterable +from ._io import yaml_load class Endpoint: @@ -124,7 +124,7 @@ def load_from_list(self, obj: Iterable[Union[Mapping, Endpoint]], override=False def load_from(self, path: PathType, encoding="utf-8", override=False): with open(path, "r", encoding=encoding) as fin: - obj = yaml.safe_load(fin) + obj = yaml_load(fin) if isinstance(obj, Mapping): obj = obj.get("endpoints", None) if obj: diff --git a/src/handyllm/openai_client.py b/src/handyllm/openai_client.py index 28f6d32..5a66e36 100644 --- a/src/handyllm/openai_client.py +++ b/src/handyllm/openai_client.py @@ -12,8 +12,6 @@ from enum import Enum, auto import asyncio -import yaml - from .endpoint_manager import Endpoint, EndpointManager from .requestor import ( Requestor, @@ -37,6 +35,7 @@ TYPE_API_TYPES, ) from .types import PathType +from ._io import yaml_load RequestorType = TypeVar("RequestorType", bound="Requestor") @@ -145,7 +144,7 @@ def __init__( def load_from(self, path: PathType, encoding="utf-8", override=False): with open(path, "r", encoding=encoding) as fin: - obj = yaml.safe_load(fin) + obj = yaml_load(fin) if obj: self.load_from_obj(obj, override=override) From e63c42667185f3f52339ce83a8ff0056d8df7c53 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 00:44:19 +0800 Subject: [PATCH 5/6] feat: add json related io methods --- src/handyllm/_io.py | 20 ++++++++++++++++++++ src/handyllm/_utils.py | 4 ++-- src/handyllm/cache_manager.py | 7 +++---- src/handyllm/hprompt.py | 5 ++--- src/handyllm/openai_client.py | 8 ++++---- src/handyllm/requestor.py | 6 +++--- 6 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/handyllm/_io.py b/src/handyllm/_io.py index e00b8a0..8eeccc8 100644 --- a/src/handyllm/_io.py +++ b/src/handyllm/_io.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import yaml from functools import wraps @@ -28,3 +29,22 @@ def yaml_dump(*args, **kwargs): def yaml_load(*args, **kwargs): return yaml.safe_load(*args, **kwargs) +@wraps(json.dump) +def json_dump(*args, **kwargs): + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("indent", 2) + return json.dump(*args, **kwargs) + +@wraps(json.dumps) +def json_dumps(*args, **kwargs): + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("indent", 2) + return json.dumps(*args, **kwargs) + +@wraps(json.load) +def json_load(*args, **kwargs): + return json.load(*args, **kwargs) + +@wraps(json.loads) +def json_loads(*args, **kwargs): + return json.loads(*args, **kwargs) diff --git a/src/handyllm/_utils.py b/src/handyllm/_utils.py index 887e4e7..6629eb5 100644 --- a/src/handyllm/_utils.py +++ b/src/handyllm/_utils.py @@ -1,12 +1,12 @@ import collections.abc import copy -import json from urllib.parse import quote_plus import time import inspect from ._constants import API_TYPES_AZURE from .prompt_converter import PromptConverter +from ._io import json_dumps def get_request_url(request_url, api_type, api_version, engine): @@ -41,7 +41,7 @@ def wrap_log_input(input_content: str, log_marks, kwargs): input_lines = [str(item) for item in log_marks] else: input_lines = [str(log_marks)] - input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False)) + input_lines.append(json_dumps(arguments, indent=2, ensure_ascii=False)) input_lines.append(" INPUT START ".center(50, "-")) input_lines.append(input_content) input_lines.append(" INPUT END ".center(50, "-") + "\n") diff --git a/src/handyllm/cache_manager.py b/src/handyllm/cache_manager.py index 74ee0e1..501e5d8 100644 --- a/src/handyllm/cache_manager.py +++ b/src/handyllm/cache_manager.py @@ -2,14 +2,13 @@ from functools import wraps from inspect import iscoroutinefunction -import json from os import PathLike from pathlib import Path from typing import Callable, Collection, Iterable, List, Optional, TypeVar, Union, cast from typing_extensions import ParamSpec from .types import PathType, StrHandler, StringifyHandler -from ._io import yaml_load, yaml_dump +from ._io import json_load, json_dump, yaml_load, yaml_dump def _suffix_loader(file: Path): @@ -18,7 +17,7 @@ def _suffix_loader(file: Path): if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"): content = yaml_load(f) elif file.suffix.endswith(".json"): - content = json.load(f) + content = json_load(f) else: content = f.read() return content @@ -30,7 +29,7 @@ def _suffix_dumper(file: Path, content): if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"): yaml_dump(content, f, default_flow_style=False) elif file.suffix.endswith(".json"): - json.dump(content, f, ensure_ascii=False, indent=2) + json_dump(content, f) else: f.write(str(content)) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index dbd7f35..0729bcd 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -17,7 +17,6 @@ ] import inspect -import json import re import copy import io @@ -59,7 +58,7 @@ from .run_config import RunConfig, RecordRequestMode, CredentialType, VarMapFileFormat from .types import PathType, SyncHandlerChat, SyncHandlerCompletions, VarMapType from .response import ChatChunk, ChatResponse, CompletionsChunk, CompletionsResponse -from ._io import MySafeDumper, yaml_load +from ._io import MySafeDumper, json_load, yaml_load PromptType = TypeVar("PromptType", bound="HandyPrompt") @@ -453,7 +452,7 @@ def _prepare_run( evaled_run_config.credential_path, "r", encoding="utf-8" ) as fin: if evaled_run_config.credential_type == CredentialType.JSON: - credential_dict = json.load(fin) + credential_dict = json_load(fin) else: credential_dict = yaml_load(fin) # do not overwrite the existing request arguments diff --git a/src/handyllm/openai_client.py b/src/handyllm/openai_client.py index 5a66e36..263d419 100644 --- a/src/handyllm/openai_client.py +++ b/src/handyllm/openai_client.py @@ -7,7 +7,7 @@ from typing import Dict, Iterable, Mapping, Optional, TypeVar, Union import os -import json +from json import JSONDecodeError import time from enum import Enum, auto import asyncio @@ -35,7 +35,7 @@ TYPE_API_TYPES, ) from .types import PathType -from ._io import yaml_load +from ._io import yaml_load, json_loads RequestorType = TypeVar("RequestorType", bound="Requestor") @@ -257,8 +257,8 @@ def _infer_model_engine_map(self, model_engine_map=None): if not json_str: return None try: - return json.loads(json_str) - except json.JSONDecodeError: + return json_loads(json_str) + except JSONDecodeError: return None def _consume_kwargs(self, kwargs): diff --git a/src/handyllm/requestor.py b/src/handyllm/requestor.py index 5e42126..4d3f57c 100644 --- a/src/handyllm/requestor.py +++ b/src/handyllm/requestor.py @@ -20,7 +20,6 @@ ) import asyncio import logging -import json import time import requests import httpx @@ -33,6 +32,7 @@ ChatResponse, CompletionsResponse, ) +from ._io import json_loads module_logger = logging.getLogger(__name__) @@ -272,7 +272,7 @@ def _gen_stream_response(self, raw_response: requests.Response, prepare_ret): return if byte_line.startswith(b"data: "): line = byte_line[len(b"data: ") :].decode("utf-8") - yield json.loads(line) + yield json_loads(line) except Exception as e: if self._exception_callback: self._exception_callback(e, prepare_ret) @@ -393,7 +393,7 @@ async def _agen_stream_response(self, raw_response: httpx.Response, prepare_ret) return if raw_line.startswith("data: "): line = raw_line[len("data: ") :] - yield json.loads(line) + yield json_loads(line) except Exception as e: if self._exception_callback: self._exception_callback(e, prepare_ret) From 4b0fdff2598b2636cd31d020790b9d0726f856b1 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 6 Aug 2024 00:53:15 +0800 Subject: [PATCH 6/6] minor(_io): lint format --- src/handyllm/_io.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/handyllm/_io.py b/src/handyllm/_io.py index 8eeccc8..fa30cdd 100644 --- a/src/handyllm/_io.py +++ b/src/handyllm/_io.py @@ -25,26 +25,31 @@ def yaml_dump(*args, **kwargs): kwargs.setdefault("allow_unicode", True) return yaml.dump(*args, **kwargs) + @wraps(yaml.safe_load) def yaml_load(*args, **kwargs): return yaml.safe_load(*args, **kwargs) + @wraps(json.dump) def json_dump(*args, **kwargs): kwargs.setdefault("ensure_ascii", False) kwargs.setdefault("indent", 2) return json.dump(*args, **kwargs) + @wraps(json.dumps) def json_dumps(*args, **kwargs): kwargs.setdefault("ensure_ascii", False) kwargs.setdefault("indent", 2) return json.dumps(*args, **kwargs) + @wraps(json.load) def json_load(*args, **kwargs): return json.load(*args, **kwargs) + @wraps(json.loads) def json_loads(*args, **kwargs): return json.loads(*args, **kwargs)