Skip to content

Commit

Permalink
New features of PromptConverter & CacheManager & OpenAIClient
Browse files Browse the repository at this point in the history
Merge pull request #47 from atomiechen/dev
  • Loading branch information
atomiechen authored Aug 31, 2024
2 parents 72dd0cc + ae2ede2 commit 5d9c7b5
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 24 deletions.
91 changes: 83 additions & 8 deletions src/handyllm/cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def _load_files(
load_method: Optional[Union[Collection[Optional[StrHandler]], StrHandler]],
infer_from_suffix: bool,
):
all_files_exist = all(Path(file).exists() for file in files)
if not all_files_exist:
return None
if load_method is None:
load_method = (None,) * len(files)
if not isinstance(load_method, Collection):
Expand Down Expand Up @@ -100,18 +97,26 @@ def _dump_files(

class CacheManager:
def __init__(
self, base_dir: PathType, enabled: bool = True, only_dump: bool = False
self,
base_dir: PathType,
enabled: bool = True,
only_dump: bool = False,
only_load: bool = False,
):
self.base_dir = base_dir
self.enabled = enabled
self.only_dump = only_dump
self.only_load = only_load
if only_dump and only_load:
raise ValueError("only_dump and only_load cannot be True at the same time.")

def cache(
self,
func: Callable[P, R],
out: Union[PathType, Iterable[PathType]],
enabled: Optional[bool] = None,
only_dump: Optional[bool] = None,
only_load: Optional[bool] = None,
dump_method: Optional[
Union[Collection[Optional[StringifyHandler]], StringifyHandler]
] = None,
Expand All @@ -129,6 +134,10 @@ def cache(
enabled = self.enabled
if only_dump is None:
only_dump = self.only_dump
if only_load is None:
only_load = self.only_load
if only_dump and only_load:
raise ValueError("only_dump and only_load cannot be True at the same time.")
if not enabled:
return func
if isinstance(out, str) or isinstance(out, PathLike):
Expand All @@ -139,9 +148,18 @@ def cache(
@wraps(func)
async def async_wrapped_func(*args: P.args, **kwargs: P.kwargs):
if not only_dump:
results = _load_files(full_files, load_method, infer_from_suffix)
if results is not None:
non_exist_files = [
file for file in full_files if not Path(file).exists()
]
if len(non_exist_files) == 0:
results = _load_files(
full_files, load_method, infer_from_suffix
)
return cast(R, results)
elif only_load:
raise FileNotFoundError(
f"Cache files not found: {non_exist_files}"
)
results = await func(*args, **kwargs)
_dump_files(results, full_files, dump_method, infer_from_suffix)
return cast(R, results)
Expand All @@ -152,11 +170,68 @@ async def async_wrapped_func(*args: P.args, **kwargs: P.kwargs):
@wraps(func)
def sync_wrapped_func(*args: P.args, **kwargs: P.kwargs):
if not only_dump:
results = _load_files(full_files, load_method, infer_from_suffix)
if results is not None:
non_exist_files = [
file for file in full_files if not Path(file).exists()
]
if len(non_exist_files) == 0:
results = _load_files(
full_files, load_method, infer_from_suffix
)
return cast(R, results)
elif only_load:
raise FileNotFoundError(
f"Cache files not found: {non_exist_files}"
)
results = func(*args, **kwargs)
_dump_files(results, full_files, dump_method, infer_from_suffix)
return cast(R, results)

return cast(Callable[P, R], sync_wrapped_func)

def ensure_dumped(
self,
func: Callable[P, R],
out: Union[PathType, Iterable[PathType]],
dump_method: Optional[
Union[Collection[Optional[StringifyHandler]], StringifyHandler]
] = None,
infer_from_suffix: bool = True,
) -> Callable[P, None]:
"""
Ensure the output of the original function is cached. Will not
call the original function and will not load the files if they
exist. The decorated function returns None.
Example scenario:
The decorated function can be called multiple times, but we only
want to check files existence without loading them multiple times
(which is the case of `cache` method).
"""
if isinstance(out, str) or isinstance(out, PathLike):
out = [out]
full_files = [Path(self.base_dir, file) for file in out]
if iscoroutinefunction(func):

@wraps(func)
async def async_wrapped_func(*args: P.args, **kwargs: P.kwargs):
non_exist_files = [
file for file in full_files if not Path(file).exists()
]
if len(non_exist_files) > 0:
results = await func(*args, **kwargs)
_dump_files(results, full_files, dump_method, infer_from_suffix)

return cast(Callable[P, None], async_wrapped_func)
else:

@wraps(func)
def sync_wrapped_func(*args: P.args, **kwargs: P.kwargs):
non_exist_files = [
file for file in full_files if not Path(file).exists()
]
if len(non_exist_files) > 0:
results = func(*args, **kwargs)
_dump_files(results, full_files, dump_method, infer_from_suffix)

return cast(Callable[P, None], sync_wrapped_func)
11 changes: 11 additions & 0 deletions src/handyllm/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class OpenAIClient:
# set this to your endpoint manager
endpoint_manager: Optional[EndpointManager] = None

# ensure only client-wide credentials are used;
# API runtime credentials are ignored
ensure_client_credentials: bool = False

def __init__(
self,
mode: Union[str, ClientMode] = ClientMode.SYNC,
Expand Down Expand Up @@ -336,6 +340,13 @@ def _consume_kwargs(self, kwargs):
else:
engine = model
dest_url = kwargs.pop("dest_url", dest_url)

if self.ensure_client_credentials:
api_key = self.api_key
organization = self.organization
api_base = self.api_base or API_BASE_OPENAI
api_type = self.api_type or API_TYPE_OPENAI
api_version = self.api_version
return api_key, organization, api_base, api_type, api_version, engine, dest_url

def _make_requestor(
Expand Down
25 changes: 9 additions & 16 deletions src/handyllm/prompt_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,25 +99,18 @@ def msgs2raw(msgs):
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls")
extra_properties = {
key: message[key]
for key in message
if key not in ["role", "content", "tool_calls"]
}
extras = []
for key in message:
if key not in ["role", "content", "tool_calls"]:
extras.append(f"{key}={message[key]}")
if tool_calls:
extra_properties["type"] = "tool_calls"
extras.append("tool")
content = yaml_dump(tool_calls)
elif isinstance(content, MutableSequence):
extra_properties["type"] = "content_array"
extras.append("array")
content = yaml_dump(content)
if extra_properties:
extra = (
" {"
+ " ".join(
[f'{key}="{extra_properties[key]}"' for key in extra_properties]
)
+ "}"
)
if extras:
extra = " {" + " ".join(extras) + "}"
else:
extra = ""
messages.append(f"${role}${extra}\n{content}")
Expand All @@ -140,7 +133,7 @@ def consume_stream2fd(
fd.write(f"${role}$") # do not add newline
if tool_call:
if not role_completed:
fd.write(' {type="tool_calls"}\n')
fd.write(" {tool}\n")
role_completed = True
# dump tool calls
fd.write(yaml_dump([tool_call]))
Expand Down
13 changes: 13 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,16 @@ def dump_method(i: int):

out2 = wrapped_func()
assert out2 == 1


def test_ensure_dumped(tmp_path: Path, capsys: CaptureFixture[str]):
    """ensure_dumped runs and dumps once, then becomes a no-op returning None."""
    manager = CacheManager(base_dir=tmp_path)
    decorated = manager.ensure_dumped(func=func, out="test.txt", dump_method=str)

    # First call: cache file absent, so the wrapped function executes and dumps.
    first = decorated()
    assert first is None
    assert (tmp_path / "test.txt").read_text() == "1"
    assert capsys.readouterr().out == "In func\n"

    # Second call: file exists, so the function is neither run nor loaded.
    second = decorated()
    assert second is None
    assert capsys.readouterr().out == ""
12 changes: 12 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,15 @@ def test_chat_stream():
):
result += chunk.choices[0].delta["content"]
assert result == "Hello"


def test_ensure_client_credentials():
    """Per-call api_key wins by default; ensure_client_credentials forces client-wide."""
    client = OpenAIClient()
    client.api_key = "client_key"

    # Default behavior: runtime credentials override client-wide ones.
    requestor = client.chat(messages=[], api_key="should_be_used")
    assert requestor.api_key == "should_be_used"

    # Flag enabled: runtime credentials are ignored in favor of the client's.
    client.ensure_client_credentials = True
    requestor = client.chat(messages=[], api_key="should_not_be_used")
    assert requestor.api_key == "client_key"

0 comments on commit 5d9c7b5

Please sign in to comment.