Skip to content

Commit

Permalink
Merge pull request #42 from atomiechen/fix_dictproxy_dump
Browse files Browse the repository at this point in the history
Fix dictproxy dump
  • Loading branch information
atomiechen authored Aug 5, 2024
2 parents 0c7a246 + 4b0fdff commit cdf8400
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 40 deletions.
55 changes: 55 additions & 0 deletions src/handyllm/_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
from pathlib import Path
import yaml
from functools import wraps

from .response import DictProxy


# add multi representer for Path, for YAML serialization
class MySafeDumper(yaml.SafeDumper):
pass


MySafeDumper.add_multi_representer(
Path, lambda dumper, data: dumper.represent_str(str(data))
)
MySafeDumper.add_multi_representer(
DictProxy, lambda dumper, data: dumper.represent_dict(data)
)


@wraps(yaml.dump)
def yaml_dump(*args, **kwargs):
kwargs.setdefault("Dumper", MySafeDumper)
kwargs.setdefault("allow_unicode", True)
return yaml.dump(*args, **kwargs)


@wraps(yaml.safe_load)
def yaml_load(*args, **kwargs):
return yaml.safe_load(*args, **kwargs)


@wraps(json.dump)
def json_dump(*args, **kwargs):
kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("indent", 2)
return json.dump(*args, **kwargs)


@wraps(json.dumps)
def json_dumps(*args, **kwargs):
kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("indent", 2)
return json.dumps(*args, **kwargs)


@wraps(json.load)
def json_load(*args, **kwargs):
return json.load(*args, **kwargs)


@wraps(json.loads)
def json_loads(*args, **kwargs):
return json.loads(*args, **kwargs)
4 changes: 2 additions & 2 deletions src/handyllm/_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import collections.abc
import copy
import json
from urllib.parse import quote_plus
import time
import inspect

from ._constants import API_TYPES_AZURE
from .prompt_converter import PromptConverter
from ._io import json_dumps


def get_request_url(request_url, api_type, api_version, engine):
Expand Down Expand Up @@ -41,7 +41,7 @@ def wrap_log_input(input_content: str, log_marks, kwargs):
input_lines = [str(item) for item in log_marks]
else:
input_lines = [str(log_marks)]
input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False))
input_lines.append(json_dumps(arguments, indent=2, ensure_ascii=False))
input_lines.append(" INPUT START ".center(50, "-"))
input_lines.append(input_content)
input_lines.append(" INPUT END ".center(50, "-") + "\n")
Expand Down
11 changes: 5 additions & 6 deletions src/handyllm/cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@

from functools import wraps
from inspect import iscoroutinefunction
import json
from os import PathLike
from pathlib import Path
from typing import Callable, Collection, Iterable, List, Optional, TypeVar, Union, cast
from typing_extensions import ParamSpec
import yaml

from .types import PathType, StrHandler, StringifyHandler
from ._io import json_load, json_dump, yaml_load, yaml_dump


def _suffix_loader(file: Path):
with open(file, "r", encoding="utf-8") as f:
# determine the format according to the file suffix
if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"):
content = yaml.safe_load(f)
content = yaml_load(f)
elif file.suffix.endswith(".json"):
content = json.load(f)
content = json_load(f)
else:
content = f.read()
return content
Expand All @@ -28,9 +27,9 @@ def _suffix_dumper(file: Path, content):
with open(file, "w", encoding="utf-8") as f:
# determine the format according to the file suffix
if file.suffix.endswith(".yaml") or file.suffix.endswith(".yml"):
yaml.dump(content, f, default_flow_style=False, allow_unicode=True)
yaml_dump(content, f, default_flow_style=False)
elif file.suffix.endswith(".json"):
json.dump(content, f, ensure_ascii=False, indent=2)
json_dump(content, f)
else:
f.write(str(content))

Expand Down
4 changes: 2 additions & 2 deletions src/handyllm/endpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from threading import Lock
from collections.abc import MutableSequence
from typing import Iterable, Mapping, Optional, Union
import yaml

from .types import PathType
from ._utils import isiterable
from ._io import yaml_load


class Endpoint:
Expand Down Expand Up @@ -124,7 +124,7 @@ def load_from_list(self, obj: Iterable[Union[Mapping, Endpoint]], override=False

def load_from(self, path: PathType, encoding="utf-8", override=False):
with open(path, "r", encoding=encoding) as fin:
obj = yaml.safe_load(fin)
obj = yaml_load(fin)
if isinstance(obj, Mapping):
obj = obj.get("endpoints", None)
if obj:
Expand Down
18 changes: 4 additions & 14 deletions src/handyllm/hprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
]

import inspect
import json
import re
import copy
import io
Expand All @@ -41,7 +40,6 @@
from abc import abstractmethod, ABC
from contextlib import asynccontextmanager, contextmanager

import yaml
import frontmatter
from mergedeep import merge as merge_dict, Strategy
from dotenv import load_dotenv
Expand All @@ -60,6 +58,7 @@
from .run_config import RunConfig, RecordRequestMode, CredentialType, VarMapFileFormat
from .types import PathType, SyncHandlerChat, SyncHandlerCompletions, VarMapType
from .response import ChatChunk, ChatResponse, CompletionsChunk, CompletionsResponse
from ._io import MySafeDumper, json_load, yaml_load


PromptType = TypeVar("PromptType", bound="HandyPrompt")
Expand All @@ -71,15 +70,6 @@
handler = frontmatter.YAMLHandler()


# add multi representer for Path, for YAML serialization
class MySafeDumper(yaml.SafeDumper):
pass


MySafeDumper.add_multi_representer(
Path, lambda dumper, data: dumper.represent_str(str(data))
)

p_var_map = re.compile(r"(%\w+%)")

DEFAULT_CONFIG = RunConfig()
Expand Down Expand Up @@ -462,9 +452,9 @@ def _prepare_run(
evaled_run_config.credential_path, "r", encoding="utf-8"
) as fin:
if evaled_run_config.credential_type == CredentialType.JSON:
credential_dict = json.load(fin)
credential_dict = json_load(fin)
else:
credential_dict = yaml.safe_load(fin)
credential_dict = yaml_load(fin)
# do not overwrite the existing request arguments
for key, value in credential_dict.items():
if key not in evaled_request:
Expand Down Expand Up @@ -1103,7 +1093,7 @@ def load_var_map(
"""
with open(path, "r", encoding="utf-8") as fin:
if format in (VarMapFileFormat.JSON, VarMapFileFormat.YAML):
return yaml.safe_load(fin)
return yaml_load(fin)
content = fin.read()
substitute_map = {}
blocks = p_var_map.split(content)
Expand Down
11 changes: 5 additions & 6 deletions src/handyllm/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@

from typing import Dict, Iterable, Mapping, Optional, TypeVar, Union
import os
import json
from json import JSONDecodeError
import time
from enum import Enum, auto
import asyncio

import yaml

from .endpoint_manager import Endpoint, EndpointManager
from .requestor import (
Requestor,
Expand All @@ -37,6 +35,7 @@
TYPE_API_TYPES,
)
from .types import PathType
from ._io import yaml_load, json_loads


RequestorType = TypeVar("RequestorType", bound="Requestor")
Expand Down Expand Up @@ -145,7 +144,7 @@ def __init__(

def load_from(self, path: PathType, encoding="utf-8", override=False):
with open(path, "r", encoding=encoding) as fin:
obj = yaml.safe_load(fin)
obj = yaml_load(fin)
if obj:
self.load_from_obj(obj, override=override)

Expand Down Expand Up @@ -258,8 +257,8 @@ def _infer_model_engine_map(self, model_engine_map=None):
if not json_str:
return None
try:
return json.loads(json_str)
except json.JSONDecodeError:
return json_loads(json_str)
except JSONDecodeError:
return None

def _consume_kwargs(self, kwargs):
Expand Down
12 changes: 6 additions & 6 deletions src/handyllm/prompt_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import re
from typing import IO, Generator, MutableMapping, MutableSequence, Optional
import yaml

from .types import PathType, ShortChatChunk
from ._io import yaml_dump, yaml_load


class PromptConverter:
Expand Down Expand Up @@ -67,11 +67,11 @@ def raw2msgs(self, raw_prompt: str):
if "type" in extra_properties:
type_of_msg = extra_properties.pop("type")
if type_of_msg == "tool_calls":
msg["tool_calls"] = yaml.safe_load(content)
msg["tool_calls"] = yaml_load(content)
msg["content"] = None
elif type_of_msg == "content_array":
# parse content array
msg["content"] = yaml.safe_load(content)
msg["content"] = yaml_load(content)
for key in extra_properties:
msg[key] = extra_properties[key]
msgs.append(msg)
Expand Down Expand Up @@ -99,10 +99,10 @@ def msgs2raw(msgs):
}
if tool_calls:
extra_properties["type"] = "tool_calls"
content = yaml.dump(tool_calls, allow_unicode=True)
content = yaml_dump(tool_calls)
elif isinstance(content, MutableSequence):
extra_properties["type"] = "content_array"
content = yaml.dump(content, allow_unicode=True)
content = yaml_dump(content)
if extra_properties:
extra = (
" {"
Expand Down Expand Up @@ -136,7 +136,7 @@ def consume_stream2fd(
fd.write(' {type="tool_calls"}\n')
role_completed = True
# dump tool calls
fd.write(yaml.dump([tool_call], allow_unicode=True))
fd.write(yaml_dump([tool_call]))
elif text:
if not role_completed:
fd.write("\n")
Expand Down
6 changes: 3 additions & 3 deletions src/handyllm/requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
import asyncio
import logging
import json
import time
import requests
import httpx
Expand All @@ -33,6 +32,7 @@
ChatResponse,
CompletionsResponse,
)
from ._io import json_loads


module_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -272,7 +272,7 @@ def _gen_stream_response(self, raw_response: requests.Response, prepare_ret):
return
if byte_line.startswith(b"data: "):
line = byte_line[len(b"data: ") :].decode("utf-8")
yield json.loads(line)
yield json_loads(line)
except Exception as e:
if self._exception_callback:
self._exception_callback(e, prepare_ret)
Expand Down Expand Up @@ -393,7 +393,7 @@ async def _agen_stream_response(self, raw_response: httpx.Response, prepare_ret)
return
if raw_line.startswith("data: "):
line = raw_line[len("data: ") :]
yield json.loads(line)
yield json_loads(line)
except Exception as e:
if self._exception_callback:
self._exception_callback(e, prepare_ret)
Expand Down
4 changes: 3 additions & 1 deletion src/handyllm/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def __init__(self, *args, **kwargs):
self[key] = self._wrap(value)

def __getattr__(self, attr):
return self[attr]
if attr in self:
return self[attr]
raise AttributeError(f"Attribute {attr} not found")

def __setattr__(self, attr, value):
self[attr] = self._wrap(value)
Expand Down

0 comments on commit cdf8400

Please sign in to comment.