
feat: use auto gen seed when using LLM cache #441

Merged · 15 commits · Oct 21, 2024
10 changes: 10 additions & 0 deletions rdagent/core/conf.py
@@ -15,6 +15,16 @@ class RDAgentSettings(BaseSettings):
# TODO: (xiao) think it can be a separate config.
log_trace_path: str | None = None

# Behavior of returning answers to the same question when caching is enabled
use_auto_chat_cache_seed_gen: bool = False
"""
    `_create_chat_completion_inner_function` provides a feature to pass in a seed that affects the cache hash key.
    We want an auto seed generator to supply a different default seed to `_create_chat_completion_inner_function`
    when no seed is given.
    As a result, the cache only hits when the same question is asked in the same round.
"""
init_chat_cache_seed: int = 42

# azure document intelligence configs
azure_document_intelligence_key: str = ""
azure_document_intelligence_endpoint: str = ""
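For context, a minimal usage sketch (not part of the diff; it only uses names introduced in this PR): enabling the toggle makes repeated identical prompts draw different default cache seeds, while re-seeding the generator replays the same sequence, as the new tests verify.

```python
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import LLM_CACHE_SEED_GEN

# Opt in: identical prompts now get different auto-generated cache seeds,
# so they no longer collapse onto a single cached answer.
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True

# Re-seeding the generator replays the same seed sequence, which makes a
# cached run reproducible (see the tests added in this PR).
LLM_CACHE_SEED_GEN.set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed)
```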
44 changes: 43 additions & 1 deletion rdagent/core/utils.py
@@ -5,6 +5,7 @@
import json
import multiprocessing as mp
import pickle
import random
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, NoReturn, cast
@@ -86,11 +87,48 @@ class of `class_path`
return getattr(module, class_name)


class CacheSeedGen:
"""
    A global seed generator that produces a sequence of seeds.
    It supports the feature described by `use_auto_chat_cache_seed_gen`.

NOTE:
- This seed is specifically for the cache and is different from a regular seed.
- If the cache is removed, setting the same seed will not produce the same QA trace.
"""

def __init__(self) -> None:
self.set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed)

def set_seed(self, seed: int) -> None:
random.seed(seed)

def get_next_seed(self) -> int:
"""generate next random int"""
return random.randint(0, 10000) # noqa: S311


LLM_CACHE_SEED_GEN = CacheSeedGen()


def _subprocess_wrapper(f: Callable, seed: int, args: list) -> Any:
"""
    A function wrapper that ensures the subprocess starts with a fixed seed.
"""

LLM_CACHE_SEED_GEN.set_seed(seed)
return f(*args)


def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) -> list:
    """Use multiprocessing to call the functions in func_calls with the given parameters.
    The result equals `[f(*args) for f, args in func_calls]`.
    Multiprocessing is not used if `n=1`.

    NOTE:
    This cooperates with the chat_cache_seed feature:
    we ensure the same seed trace is obtained even when the calls run in multiple subprocesses.

Parameters
----------
func_calls : List[Tuple[Callable, Tuple]]
@@ -105,8 +143,12 @@ def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) ->
"""
if n == 1:
return [f(*args) for f, args in func_calls]

with mp.Pool(processes=max(1, min(n, len(func_calls)))) as pool:
results = [pool.apply_async(f, args) for f, args in func_calls]
results = [
pool.apply_async(_subprocess_wrapper, args=(f, LLM_CACHE_SEED_GEN.get_next_seed(), args))
for f, args in func_calls
]
return [result.get() for result in results]


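As an illustration of the seed handoff (a runnable sketch, not part of the diff; `draw` is a stand-in worker and assumes `rdagent` is importable): the parent process draws one cache seed per call before dispatching, so the seed trace depends only on the outer seed, not on subprocess scheduling.

```python
import random

from rdagent.core.utils import LLM_CACHE_SEED_GEN, multiprocessing_wrapper


def draw(_tag: str) -> int:
    # _subprocess_wrapper pins the cache seed in this subprocess before calling us,
    # so the value depends only on the seed drawn in the parent process.
    return random.randint(0, 10000)


if __name__ == "__main__":
    LLM_CACHE_SEED_GEN.set_seed(10)
    first = multiprocessing_wrapper([(draw, (f"call-{i}",)) for i in range(4)], n=4)

    LLM_CACHE_SEED_GEN.set_seed(10)
    replay = multiprocessing_wrapper([(draw, (f"call-{i}",)) for i in range(4)], n=4)

    # Re-seeding replays the same per-call seed trace; entries within one pass
    # typically differ from each other because each call drew its own seed.
    assert first == replay
```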
9 changes: 7 additions & 2 deletions rdagent/oai/llm_utils.py
@@ -3,6 +3,7 @@
import hashlib
import json
import os
import random
import re
import sqlite3
import ssl
@@ -16,7 +17,8 @@
import numpy as np
import tiktoken

from rdagent.core.utils import SingletonBaseClass
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass
from rdagent.log import LogColors
from rdagent.log import rdagent_logger as logger
from rdagent.oai.llm_conf import LLM_SETTINGS
@@ -594,7 +596,10 @@ def _create_chat_completion_inner_function(  # noqa: C901, PLR0912, PLR0915
To make retries useful, we need to enable a seed.
        This seed is different from `self.chat_seed` for GPT. It is for the local cache mechanism that RD-Agent enables.
"""
# TODO: we can add this function back to avoid so much `LLM_SETTINGS.log_llm_chat_content`
if seed is None and RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen:
seed = LLM_CACHE_SEED_GEN.get_next_seed()

# TODO: we can add this function back to avoid so much `self.cfg.log_llm_chat_content`
if LLM_SETTINGS.log_llm_chat_content:
logger.info(self._build_log_messages(messages), tag="llm_messages")
# TODO: fail to use loguru adaptor due to stream response
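Viewed in isolation, the added branch behaves like the sketch below (`resolve_cache_seed` is a hypothetical name used only for illustration; in the diff the logic sits inline in `_create_chat_completion_inner_function`).

```python
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import LLM_CACHE_SEED_GEN


def resolve_cache_seed(seed: int | None) -> int | None:
    """Return the seed that takes part in the cache hash key (illustrative only)."""
    if seed is None and RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen:
        # Draw a fresh default seed so identical prompts in the same run
        # map to different cache entries.
        return LLM_CACHE_SEED_GEN.get_next_seed()
    return seed
```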
129 changes: 129 additions & 0 deletions test/oai/test_completion.py
@@ -5,6 +5,14 @@
from rdagent.oai.llm_utils import APIBackend


def _worker(system_prompt, user_prompt):
api = APIBackend()
return api.build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)


class TestChatCompletion(unittest.TestCase):
def test_chat_completion(self) -> None:
system_prompt = "You are a helpful assistant."
@@ -45,6 +53,127 @@ def test_chat_multi_round(self) -> None:
response2 = session.build_chat_completion(user_prompt=user_prompt_2)
assert response2 is not None

def test_chat_cache(self) -> None:
"""
Tests:
        - Single process, ask the same question, cache enabled
        - 2 passes
        - the cache is not missed & the same question gets different answers.
"""
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import LLM_CACHE_SEED_GEN
from rdagent.oai.llm_conf import LLM_SETTINGS

system_prompt = "You are a helpful assistant."
user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"

origin_value = (
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
)

LLM_SETTINGS.use_chat_cache = True
LLM_SETTINGS.dump_chat_cache = True

RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True

LLM_CACHE_SEED_GEN.set_seed(10)
response1 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
response2 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)

LLM_CACHE_SEED_GEN.set_seed(20)
response3 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
response4 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)

LLM_CACHE_SEED_GEN.set_seed(10)
response5 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
response6 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)

# Reset, for other tests
(
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
) = origin_value

assert (
response1 != response3 and response2 != response4
), "Responses sequence should be determined by 'init_chat_cache_seed'"
assert (
response1 == response5 and response2 == response6
), "Responses sequence should be determined by 'init_chat_cache_seed'"
assert (
response1 != response2 and response3 != response4 and response5 != response6
), "Same question should get different response when use_auto_chat_cache_seed_gen=True"

def test_chat_cache_multiprocess(self) -> None:
"""
Tests:
        - Multiple processes, ask the same question, cache enabled
        - 2 passes
        - the cache is not missed & the same question gets different answers.
"""
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import LLM_CACHE_SEED_GEN, multiprocessing_wrapper
from rdagent.oai.llm_conf import LLM_SETTINGS

system_prompt = "You are a helpful assistant."
user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"

origin_value = (
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
)

LLM_SETTINGS.use_chat_cache = True
LLM_SETTINGS.dump_chat_cache = True

RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True

func_calls = [(_worker, (system_prompt, user_prompt)) for _ in range(4)]

LLM_CACHE_SEED_GEN.set_seed(10)
responses1 = multiprocessing_wrapper(func_calls, n=4)
LLM_CACHE_SEED_GEN.set_seed(20)
responses2 = multiprocessing_wrapper(func_calls, n=4)
LLM_CACHE_SEED_GEN.set_seed(10)
responses3 = multiprocessing_wrapper(func_calls, n=4)

# Reset, for other tests
(
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
) = origin_value
for i in range(len(func_calls)):
assert (
responses1[i] != responses2[i] and responses1[i] == responses3[i]
), "Responses sequence should be determined by 'init_chat_cache_seed'"
for j in range(i + 1, len(func_calls)):
assert (
responses1[i] != responses1[j] and responses2[i] != responses2[j]
), "Same question should get different response when use_auto_chat_cache_seed_gen=True"


if __name__ == "__main__":
unittest.main()