diff --git a/src/handyllm/_io.py b/src/handyllm/_io.py new file mode 100644 index 0000000..fa30cdd --- /dev/null +++ b/src/handyllm/_io.py @@ -0,0 +1,55 @@ +import json +from pathlib import Path +import yaml +from functools import wraps + +from .response import DictProxy + + +# add multi representer for Path, for YAML serialization +class MySafeDumper(yaml.SafeDumper): + pass + + +MySafeDumper.add_multi_representer( + Path, lambda dumper, data: dumper.represent_str(str(data)) +) +MySafeDumper.add_multi_representer( + DictProxy, lambda dumper, data: dumper.represent_dict(data) +) + + +@wraps(yaml.dump) +def yaml_dump(*args, **kwargs): + kwargs.setdefault("Dumper", MySafeDumper) + kwargs.setdefault("allow_unicode", True) + return yaml.dump(*args, **kwargs) + + +@wraps(yaml.safe_load) +def yaml_load(*args, **kwargs): + return yaml.safe_load(*args, **kwargs) + + +@wraps(json.dump) +def json_dump(*args, **kwargs): + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("indent", 2) + return json.dump(*args, **kwargs) + + +@wraps(json.dumps) +def json_dumps(*args, **kwargs): + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("indent", 2) + return json.dumps(*args, **kwargs) + + +@wraps(json.load) +def json_load(*args, **kwargs): + return json.load(*args, **kwargs) + + +@wraps(json.loads) +def json_loads(*args, **kwargs): + return json.loads(*args, **kwargs) diff --git a/src/handyllm/_utils.py b/src/handyllm/_utils.py index 887e4e7..6629eb5 100644 --- a/src/handyllm/_utils.py +++ b/src/handyllm/_utils.py @@ -1,12 +1,12 @@ import collections.abc import copy -import json from urllib.parse import quote_plus import time import inspect from ._constants import API_TYPES_AZURE from .prompt_converter import PromptConverter +from ._io import json_dumps def get_request_url(request_url, api_type, api_version, engine): @@ -41,7 +41,7 @@ def wrap_log_input(input_content: str, log_marks, kwargs): input_lines = [str(item) for item in log_marks] else: input_lines = [str(log_marks)] - input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False)) + input_lines.append(json_dumps(arguments, indent=2, ensure_ascii=False)) input_lines.append(" INPUT START ".center(50, "-")) input_lines.append(input_content) input_lines.append(" INPUT END ".center(50, "-") + "\n") diff --git a/src/handyllm/cache_manager.py b/src/handyllm/cache_manager.py index bec53bf..501e5d8 100644 --- a/src/handyllm/cache_manager.py +++ b/src/handyllm/cache_manager.py @@ -2,23 +2,22 @@ from functools import wraps from inspect import iscoroutinefunction -import json from os import PathLike from pathlib import Path from typing import Callable, Collection, Iterable, List, Optional, TypeVar, Union, cast from typing_extensions import ParamSpec -import yaml from .types import PathType, StrHandler, StringifyHandler +from ._io import json_load, json_dump, yaml_load, yaml_dump def _suffix_loader(file: Path): with open(file, "r", encoding="utf-8") as f: # determine the format according to the file suffix if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"): - content = yaml.safe_load(f) + content = yaml_load(f) elif file.suffix.endswith(".json"): - content = json.load(f) + content = json_load(f) else: content = f.read() return content @@ -28,9 +27,9 @@ def _suffix_dumper(file: Path, content): with open(file, "w", encoding="utf-8") as f: # determine the format according to the file suffix if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"): - yaml.dump(content, f, default_flow_style=False, allow_unicode=True) + yaml_dump(content, f, default_flow_style=False) elif file.suffix.endswith(".json"): - json.dump(content, f, ensure_ascii=False, indent=2) + json_dump(content, f) else: f.write(str(content)) diff --git a/src/handyllm/endpoint_manager.py b/src/handyllm/endpoint_manager.py index 034a3fb..b9e0cc4 100644 --- a/src/handyllm/endpoint_manager.py +++ b/src/handyllm/endpoint_manager.py @@ -6,10 +6,10 @@ from threading import Lock from collections.abc import MutableSequence from typing import Iterable, Mapping, Optional, Union -import yaml from .types import PathType from ._utils import isiterable +from ._io import yaml_load class Endpoint: @@ -124,7 +124,7 @@ def load_from_list(self, obj: Iterable[Union[Mapping, Endpoint]], override=False def load_from(self, path: PathType, encoding="utf-8", override=False): with open(path, "r", encoding=encoding) as fin: - obj = yaml.safe_load(fin) + obj = yaml_load(fin) if isinstance(obj, Mapping): obj = obj.get("endpoints", None) if obj: diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index ccda4e9..0729bcd 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -17,7 +17,6 @@ ] import inspect -import json import re import copy import io @@ -41,7 +40,6 @@ from abc import abstractmethod, ABC from contextlib import asynccontextmanager, contextmanager -import yaml import frontmatter from mergedeep import merge as merge_dict, Strategy from dotenv import load_dotenv @@ -60,6 +58,7 @@ from .run_config import RunConfig, RecordRequestMode, CredentialType, VarMapFileFormat from .types import PathType, SyncHandlerChat, SyncHandlerCompletions, VarMapType from .response import ChatChunk, ChatResponse, CompletionsChunk, CompletionsResponse +from ._io import MySafeDumper, json_load, yaml_load PromptType = TypeVar("PromptType", bound="HandyPrompt") @@ -71,15 +70,6 @@ handler = frontmatter.YAMLHandler() -# add multi representer for Path, for YAML serialization -class MySafeDumper(yaml.SafeDumper): - pass - - -MySafeDumper.add_multi_representer( - Path, lambda dumper, data: dumper.represent_str(str(data)) -) - p_var_map = re.compile(r"(%\w+%)") DEFAULT_CONFIG = RunConfig() @@ -462,9 +452,9 @@ def _prepare_run( evaled_run_config.credential_path, "r", encoding="utf-8" ) as fin: if evaled_run_config.credential_type == CredentialType.JSON: - credential_dict = json.load(fin) + credential_dict = json_load(fin) else: - credential_dict = yaml.safe_load(fin) + credential_dict = yaml_load(fin) # do not overwrite the existing request arguments for key, value in credential_dict.items(): if key not in evaled_request: @@ -1103,7 +1093,7 @@ def load_var_map( """ with open(path, "r", encoding="utf-8") as fin: if format in (VarMapFileFormat.JSON, VarMapFileFormat.YAML): - return yaml.safe_load(fin) + return yaml_load(fin) content = fin.read() substitute_map = {} blocks = p_var_map.split(content) diff --git a/src/handyllm/openai_client.py b/src/handyllm/openai_client.py index 28f6d32..263d419 100644 --- a/src/handyllm/openai_client.py +++ b/src/handyllm/openai_client.py @@ -7,13 +7,11 @@ from typing import Dict, Iterable, Mapping, Optional, TypeVar, Union import os -import json +from json import JSONDecodeError import time from enum import Enum, auto import asyncio -import yaml - from .endpoint_manager import Endpoint, EndpointManager from .requestor import ( Requestor, @@ -37,6 +35,7 @@ TYPE_API_TYPES, ) from .types import PathType +from ._io import yaml_load, json_loads RequestorType = TypeVar("RequestorType", bound="Requestor") @@ -145,7 +144,7 @@ def __init__( def load_from(self, path: PathType, encoding="utf-8", override=False): with open(path, "r", encoding=encoding) as fin: - obj = yaml.safe_load(fin) + obj = yaml_load(fin) if obj: self.load_from_obj(obj, override=override) @@ -258,8 +257,8 @@ def _infer_model_engine_map(self, model_engine_map=None): if not json_str: return None try: - return json.loads(json_str) - except json.JSONDecodeError: + return json_loads(json_str) + except JSONDecodeError: return None def _consume_kwargs(self, kwargs): diff --git a/src/handyllm/prompt_converter.py b/src/handyllm/prompt_converter.py index c60cdb7..526c118 100644 --- a/src/handyllm/prompt_converter.py +++ b/src/handyllm/prompt_converter.py @@ -2,9 +2,9 @@ import re from typing import IO, Generator, MutableMapping, MutableSequence, Optional -import yaml from .types import PathType, ShortChatChunk +from ._io import yaml_dump, yaml_load class PromptConverter: @@ -67,11 +67,11 @@ def raw2msgs(self, raw_prompt: str): if "type" in extra_properties: type_of_msg = extra_properties.pop("type") if type_of_msg == "tool_calls": - msg["tool_calls"] = yaml.safe_load(content) + msg["tool_calls"] = yaml_load(content) msg["content"] = None elif type_of_msg == "content_array": # parse content array - msg["content"] = yaml.safe_load(content) + msg["content"] = yaml_load(content) for key in extra_properties: msg[key] = extra_properties[key] msgs.append(msg) @@ -99,10 +99,10 @@ def msgs2raw(msgs): } if tool_calls: extra_properties["type"] = "tool_calls" - content = yaml.dump(tool_calls, allow_unicode=True) + content = yaml_dump(tool_calls) elif isinstance(content, MutableSequence): extra_properties["type"] = "content_array" - content = yaml.dump(content, allow_unicode=True) + content = yaml_dump(content) if extra_properties: extra = ( " {" @@ -136,7 +136,7 @@ def consume_stream2fd( fd.write(' {type="tool_calls"}\n') role_completed = True # dump tool calls - fd.write(yaml.dump([tool_call], allow_unicode=True)) + fd.write(yaml_dump([tool_call])) elif text: if not role_completed: fd.write("\n") diff --git a/src/handyllm/requestor.py b/src/handyllm/requestor.py index 5e42126..4d3f57c 100644 --- a/src/handyllm/requestor.py +++ b/src/handyllm/requestor.py @@ -20,7 +20,6 @@ ) import asyncio import logging -import json import time import requests import httpx @@ -33,6 +32,7 @@ ChatResponse, CompletionsResponse, ) +from ._io import json_loads module_logger = logging.getLogger(__name__) @@ -272,7 +272,7 @@ def _gen_stream_response(self, raw_response: requests.Response, prepare_ret): return if byte_line.startswith(b"data: "): line = byte_line[len(b"data: ") :].decode("utf-8") - yield json.loads(line) + yield json_loads(line) except Exception as e: if self._exception_callback: self._exception_callback(e, prepare_ret) @@ -393,7 +393,7 @@ async def _agen_stream_response(self, raw_response: httpx.Response, prepare_ret) return if raw_line.startswith("data: "): line = raw_line[len("data: ") :] - yield json.loads(line) + yield json_loads(line) except Exception as e: if self._exception_callback: self._exception_callback(e, prepare_ret) diff --git a/src/handyllm/response.py b/src/handyllm/response.py index 1d3e861..8c2f346 100644 --- a/src/handyllm/response.py +++ b/src/handyllm/response.py @@ -22,7 +22,9 @@ def __init__(self, *args, **kwargs): self[key] = self._wrap(value) def __getattr__(self, attr): - return self[attr] + if attr in self: + return self[attr] + raise AttributeError(f"Attribute {attr} not found") def __setattr__(self, attr, value): self[attr] = self._wrap(value)