From 623252f2e47bb399a72ba85f1fa6edf9277d94e2 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Fri, 9 Aug 2024 19:37:09 +0800 Subject: [PATCH 1/7] feat(PromptConverter): default use keywords to specify message type --- src/handyllm/prompt_converter.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/handyllm/prompt_converter.py b/src/handyllm/prompt_converter.py index c80ef9a..2a59b14 100644 --- a/src/handyllm/prompt_converter.py +++ b/src/handyllm/prompt_converter.py @@ -99,25 +99,18 @@ def msgs2raw(msgs): role = message.get("role") content = message.get("content") tool_calls = message.get("tool_calls") - extra_properties = { - key: message[key] - for key in message - if key not in ["role", "content", "tool_calls"] - } + extras = [] + for key in message: + if key not in ["role", "content", "tool_calls"]: + extras.append(f"{key}={message[key]}") if tool_calls: - extra_properties["type"] = "tool_calls" + extras.append("tool") content = yaml_dump(tool_calls) elif isinstance(content, MutableSequence): - extra_properties["type"] = "content_array" + extras.append("array") content = yaml_dump(content) - if extra_properties: - extra = ( - " {" - + " ".join( - [f'{key}="{extra_properties[key]}"' for key in extra_properties] - ) - + "}" - ) + if extras: + extra = " {" + " ".join(extras) + "}" else: extra = "" messages.append(f"${role}${extra}\n{content}") @@ -140,7 +133,7 @@ def consume_stream2fd( fd.write(f"${role}$") # do not add newline if tool_call: if not role_completed: - fd.write(' {type="tool_calls"}\n') + fd.write(" {tool}\n") role_completed = True # dump tool calls fd.write(yaml_dump([tool_call])) From 58f2b191ec774ff82d6d83452e8f7cb1c4721afa Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Mon, 12 Aug 2024 20:24:01 +0800 Subject: [PATCH 2/7] feat(CacheManager): support 'only_load' option --- src/handyllm/cache_manager.py | 43 ++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/src/handyllm/cache_manager.py b/src/handyllm/cache_manager.py index 501e5d8..0b13f18 100644 --- a/src/handyllm/cache_manager.py +++ b/src/handyllm/cache_manager.py @@ -39,9 +39,6 @@ def _load_files( load_method: Optional[Union[Collection[Optional[StrHandler]], StrHandler]], infer_from_suffix: bool, ): - all_files_exist = all(Path(file).exists() for file in files) - if not all_files_exist: - return None if load_method is None: load_method = (None,) * len(files) if not isinstance(load_method, Collection): @@ -100,11 +97,18 @@ def _dump_files( class CacheManager: def __init__( - self, base_dir: PathType, enabled: bool = True, only_dump: bool = False + self, + base_dir: PathType, + enabled: bool = True, + only_dump: bool = False, + only_load: bool = False, ): self.base_dir = base_dir self.enabled = enabled self.only_dump = only_dump + self.only_load = only_load + if only_dump and only_load: + raise ValueError("only_dump and only_load cannot be True at the same time.") def cache( self, @@ -112,6 +116,7 @@ def cache( out: Union[PathType, Iterable[PathType]], enabled: Optional[bool] = None, only_dump: Optional[bool] = None, + only_load: Optional[bool] = None, dump_method: Optional[ Union[Collection[Optional[StringifyHandler]], StringifyHandler] ] = None, @@ -129,6 +134,10 @@ def cache( enabled = self.enabled if only_dump is None: only_dump = self.only_dump + if only_load is None: + only_load = self.only_load + if only_dump and only_load: + raise ValueError("only_dump and only_load cannot be True at the same time.") if not enabled: return func if isinstance(out, str) or isinstance(out, PathLike): @@ -139,9 +148,18 @@ def cache( @wraps(func) async def async_wrapped_func(*args: P.args, **kwargs: P.kwargs): if not only_dump: - results = _load_files(full_files, load_method, infer_from_suffix) - if results is not None: + non_exist_files = [ + file for file in full_files if not Path(file).exists() + ] + if len(non_exist_files) == 0: + results = _load_files( + full_files, load_method, infer_from_suffix + ) return cast(R, results) + elif only_load: + raise FileNotFoundError( + f"Cache files not found: {non_exist_files}" + ) results = await func(*args, **kwargs) _dump_files(results, full_files, dump_method, infer_from_suffix) return cast(R, results) @@ -152,9 +170,18 @@ async def async_wrapped_func(*args: P.args, **kwargs: P.kwargs): @wraps(func) def sync_wrapped_func(*args: P.args, **kwargs: P.kwargs): if not only_dump: - results = _load_files(full_files, load_method, infer_from_suffix) - if results is not None: + non_exist_files = [ + file for file in full_files if not Path(file).exists() + ] + if len(non_exist_files) == 0: + results = _load_files( + full_files, load_method, infer_from_suffix + ) return cast(R, results) + elif only_load: + raise FileNotFoundError( + f"Cache files not found: {non_exist_files}" + ) results = func(*args, **kwargs) _dump_files(results, full_files, dump_method, infer_from_suffix) return cast(R, results) From 655aa04da87a5f909d680a81a11b36715e0fda65 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Mon, 12 Aug 2024 20:49:41 +0800 Subject: [PATCH 3/7] feat(CacheManager): add ensure_dumped method Example scenario: The decorated function can be called multiple times, but we only want to check files existence without loading them multiple times (which is the case of `cache` method). --- src/handyllm/cache_manager.py | 48 +++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/handyllm/cache_manager.py b/src/handyllm/cache_manager.py index 0b13f18..d2d7332 100644 --- a/src/handyllm/cache_manager.py +++ b/src/handyllm/cache_manager.py @@ -187,3 +187,51 @@ def sync_wrapped_func(*args: P.args, **kwargs: P.kwargs): return cast(R, results) return cast(Callable[P, R], sync_wrapped_func) + + def ensure_dumped( + self, + func: Callable[P, R], + out: Union[PathType, Iterable[PathType]], + dump_method: Optional[ + Union[Collection[Optional[StringifyHandler]], StringifyHandler] + ] = None, + infer_from_suffix: bool = True, + ) -> Callable[P, None]: + """ + Ensure the output of the original function is cached. Will not + call the original function and will not load the files if they + exist. The decorated function returns None. + + Example scenario: + + The decorated function can be called multiple times, but we only + want to check files existence without loading them multiple times + (which is the case of `cache` method). + """ + if isinstance(out, str) or isinstance(out, PathLike): + out = [out] + full_files = [Path(self.base_dir, file) for file in out] + if iscoroutinefunction(func): + + @wraps(func) + async def async_wrapped_func(*args: P.args, **kwargs: P.kwargs): + non_exist_files = [ + file for file in full_files if not Path(file).exists() + ] + if len(non_exist_files) > 0: + results = await func(*args, **kwargs) + _dump_files(results, full_files, dump_method, infer_from_suffix) + + return cast(Callable[P, None], async_wrapped_func) + else: + + @wraps(func) + def sync_wrapped_func(*args: P.args, **kwargs: P.kwargs): + non_exist_files = [ + file for file in full_files if not Path(file).exists() + ] + if len(non_exist_files) > 0: + results = func(*args, **kwargs) + _dump_files(results, full_files, dump_method, infer_from_suffix) + + return cast(Callable[P, None], sync_wrapped_func) From 08dfa160649b44afeb6788554c56bf275fea939f Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Mon, 12 Aug 2024 20:52:18 +0800 Subject: [PATCH 4/7] test: add test for CacheManager.ensure_dumped() --- tests/test_cache.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_cache.py b/tests/test_cache.py index 1aeae99..4836d18 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -90,3 +90,16 @@ def dump_method(i: int): out2 = wrapped_func() assert out2 == 1 + + +def test_ensure_dumped(tmp_path: Path, capsys: CaptureFixture[str]): + cm = CacheManager(base_dir=tmp_path) + wrapped_func = cm.ensure_dumped(func=func, out="test.txt", dump_method=str) + out = wrapped_func() + assert (tmp_path / "test.txt").read_text() == "1" + assert out is None + assert capsys.readouterr().out == "In func\n" + + out2 = wrapped_func() + assert out2 is None + assert capsys.readouterr().out == "" From 3f74b96fec269681630606bbfeef74661e290a8f Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sat, 31 Aug 2024 19:30:03 +0800 Subject: [PATCH 5/7] feat(OpenAIClient): ensure only client-wide credentials are used; API runtime credentials are ignored --- src/handyllm/openai_client.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/handyllm/openai_client.py b/src/handyllm/openai_client.py index 263d419..67f03d5 100644 --- a/src/handyllm/openai_client.py +++ b/src/handyllm/openai_client.py @@ -85,6 +85,10 @@ class OpenAIClient: # set this to your endpoint manager endpoint_manager: Optional[EndpointManager] = None + # ensure only client-wide credentials are used; + # API runtime credentials are ignored + ensure_client_credentials: bool = False + def __init__( self, mode: Union[str, ClientMode] = ClientMode.SYNC, @@ -336,6 +340,13 @@ def _consume_kwargs(self, kwargs): else: engine = model dest_url = kwargs.pop("dest_url", dest_url) + + if self.ensure_client_credentials: + api_key = self.api_key + organization = self.organization + api_base = self.api_base + api_type = self.api_type + api_version = self.api_version return api_key, organization, api_base, api_type, api_version, engine, dest_url def _make_requestor( From 5bd6f5f5cec8c0efa2a51c8c64ebe16bff0d9228 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sat, 31 Aug 2024 19:58:03 +0800 Subject: [PATCH 6/7] fix(OpenAIClient): add missing OpenAI api_base and api_type when using client credentials --- src/handyllm/openai_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/handyllm/openai_client.py b/src/handyllm/openai_client.py index 67f03d5..c913250 100644 --- a/src/handyllm/openai_client.py +++ b/src/handyllm/openai_client.py @@ -344,8 +344,8 @@ def _consume_kwargs(self, kwargs): if self.ensure_client_credentials: api_key = self.api_key organization = self.organization - api_base = self.api_base - api_type = self.api_type + api_base = self.api_base or API_BASE_OPENAI + api_type = self.api_type or API_TYPE_OPENAI api_version = self.api_version return api_key, organization, api_base, api_type, api_version, engine, dest_url From ae2ede287cdc517873206b1e4a8ba4e3a7566537 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sat, 31 Aug 2024 20:01:16 +0800 Subject: [PATCH 7/7] test(OpenAIClient): add test for ensure_client_credentials flag --- tests/test_client.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index d6786b3..6f35d20 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -142,3 +142,15 @@ def test_chat_stream(): ): result += chunk.choices[0].delta["content"] assert result == "Hello" + + +def test_ensure_client_credentials(): + client = OpenAIClient() + client.api_key = "client_key" + assert ( + client.chat(messages=[], api_key="should_be_used").api_key == "should_be_used" + ) + client.ensure_client_credentials = True + assert ( + client.chat(messages=[], api_key="should_not_be_used").api_key == "client_key" + )