feat: use auto gen seed when using LLM cache #441

Merged · 15 commits · Oct 21, 2024
49 changes: 3 additions & 46 deletions rdagent/core/conf.py
@@ -14,60 +14,17 @@ class RDAgentSettings(BaseSettings):
# Log configs
# TODO: (xiao) think it can be a separate config.
log_trace_path: str | None = None
log_llm_chat_content: bool = True

use_azure: bool = False
use_azure_token_provider: bool = False
managed_identity_client_id: str | None = None
max_retry: int = 10
retry_wait_seconds: int = 1
dump_chat_cache: bool = False
use_chat_cache: bool = False

# Behavior of returning answers to the same question when caching is enabled
use_auto_chat_cache_seed_gen: bool = False
"""
`_create_chat_completion_inner_function` provdies a feature to pass in a seed to affect the cache hash key
We weant to enable a auto seed generator to get different default seed for `_create_chat_completion_inner_function` if seed is not given.
We want to enable a auto seed generator to get different default seed for `_create_chat_completion_inner_function`
if seed is not given.
So the cache will only not miss you ask the same question on same round.
"""
init_chat_cache_seed: int = 42

dump_embedding_cache: bool = False
use_embedding_cache: bool = False
prompt_cache_path: str = str(Path.cwd() / "prompt_cache.db")
session_cache_folder_location: str = str(Path.cwd() / "session_cache_folder/")
max_past_message_include: int = 10

# Chat configs
openai_api_key: str = "" # TODO: simplify the key design.
chat_openai_api_key: str = ""
chat_azure_api_base: str = ""
chat_azure_api_version: str = ""
chat_model: str = "gpt-4-turbo"
chat_max_tokens: int = 3000
chat_temperature: float = 0.5
chat_stream: bool = True
chat_seed: int | None = None
chat_frequency_penalty: float = 0.0
chat_presence_penalty: float = 0.0
chat_token_limit: int = (
100000 # 100000 is the maximum limit of gpt4, which might increase in the future version of gpt
)
default_system_prompt: str = "You are an AI assistant who helps to answer user's questions."

# Embedding configs
embedding_openai_api_key: str = ""
embedding_azure_api_base: str = ""
embedding_azure_api_version: str = ""
embedding_model: str = ""
embedding_max_str_num: int = 50

# offline llama2 related config
use_llama2: bool = False
llama2_ckpt_dir: str = "Llama-2-7b-chat"
llama2_tokenizer_path: str = "Llama-2-7b-chat/tokenizer.model"
llams2_max_batch_size: int = 8

# azure document intelligence configs
azure_document_intelligence_key: str = ""
azure_document_intelligence_endpoint: str = ""
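
For readers skimming the config change: the behaviour described in the docstring above boils down to folding a per-round seed into the local cache key. Below is a minimal, self-contained sketch of that idea; it is not RD-Agent's actual cache code, and the dict-based store and function names are invented for illustration only.

```python
# Minimal sketch of the behaviour described in the docstring above -- NOT the
# repository's actual cache code. The cache key includes a per-round seed, so the
# same question only hits the cache within the same round, while replaying the
# same seed sequence reproduces the same answers.
import random

_cache: dict[tuple[str, int], str] = {}


def cached_ask(question: str, seed: int) -> str:
    key = (question, seed)
    if key not in _cache:
        _cache[key] = f"fresh answer for seed {seed}"  # stand-in for a real LLM call
    return _cache[key]


random.seed(42)  # plays the role of init_chat_cache_seed
round_a, round_b = random.randint(0, 10000), random.randint(0, 10000)

first = cached_ask("same question", round_a)   # miss -> new answer
second = cached_ask("same question", round_b)  # different round -> independent cache entry

random.seed(42)  # replaying the seed sequence reproduces the first round's key
assert cached_ask("same question", random.randint(0, 10000)) == first
```

Replaying the same `init_chat_cache_seed` therefore reproduces the whole answer sequence, which is exactly what the tests at the bottom of this PR assert.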
16 changes: 7 additions & 9 deletions rdagent/core/utils.py
@@ -5,14 +5,15 @@
import json
import multiprocessing as mp
import pickle
import random
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, NoReturn, cast

from filelock import FileLock
from fuzzywuzzy import fuzz # type: ignore[import-untyped]

from rdagent.oai.llm_utils import APIBackend
from rdagent.core.conf import RD_AGENT_SETTINGS


class RDAgentException(Exception): # noqa: N818
@@ -86,14 +87,12 @@ class of `class_path`
return getattr(module, class_name)


def _subprocess_wrapper(f, seed):
def _subprocess_wrapper(f: Callable, seed: int, args: list) -> Any:
"""
A function wrapper to ensure the subprocess starts from a fixed seed.
"""
def _f(*args, **kwargs):
APIBackend.cache_seed_gen.set_seed(seed)
return f(*args, **kwargs)
return _f
random.seed(seed)
return f(*args)


def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) -> list:
@@ -119,10 +118,9 @@ def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) ->
"""
if n == 1:
return [f(*args) for f, args in func_calls]

with mp.Pool(processes=n) as pool:
with mp.Pool(processes=max(1, min(n, len(func_calls)))) as pool:
results = [
pool.apply_async(_subprocess_wrapper(f, APIBackend.cache_seed_gen.get_next_seed()), args=args)
pool.apply_async(_subprocess_wrapper, args=(f, random.randint(0, 10000), args)) # noqa: S311
Contributor:

Please use get_next_seed; we may use another random seed mechanism if we want to resolve the conflict with random.

Contributor:

Why should we use # noqa: S311?

Collaborator Author:

S311 flags standard pseudo-random generators as unsuitable for cryptographic purposes; I don't think we need that requirement here.

Collaborator Author:

> Please use get_next_seed; we may use another random seed mechanism if we want to resolve the conflict with random.

APIBackend cannot be called here because a circular import occurs.

Contributor:

Import it inside the function to avoid the circular import.

for f, args in func_calls
]
return [result.get() for result in results]
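
The thread above converges on two points: prefer `get_next_seed()` over calling `random.randint` directly, and avoid the circular import by deferring the `APIBackend` import into the function. One way that suggestion could look is sketched below; this illustrates the review feedback rather than the code that was ultimately merged.

```python
# Sketch of the review suggestion (not the merged code): defer the APIBackend
# import into the wrapper so rdagent.core.utils no longer imports
# rdagent.oai.llm_utils at module load time, which breaks the import cycle.
from collections.abc import Callable
from typing import Any


def _subprocess_wrapper(f: Callable, seed: int, args: list) -> Any:
    """Seed the subprocess's cache seed generator before calling f."""
    from rdagent.oai.llm_utils import APIBackend  # local import avoids the circular import

    APIBackend.cache_seed_gen.set_seed(seed)
    return f(*args)


# In the parent process, the per-task seed would then come from the shared generator:
#     pool.apply_async(_subprocess_wrapper, args=(f, APIBackend.cache_seed_gen.get_next_seed(), args))
```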
18 changes: 9 additions & 9 deletions rdagent/oai/llm_utils.py
@@ -1,10 +1,9 @@
from __future__ import annotations

import random
import datetime
import hashlib
import json
import os
import random
import re
import sqlite3
import ssl
@@ -18,7 +17,7 @@
import numpy as np
import tiktoken

from rdagent.core.conf import RD_AGENT_SETTINGS, RDAgentSettings
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import SingletonBaseClass
from rdagent.log import LogColors
from rdagent.log import rdagent_logger as logger
@@ -251,13 +250,14 @@ class CacheSeedGen:
- This seed is specifically for the cache and is different from a regular seed.
- If the cache is removed, setting the same seed will not produce the same QA trace.
"""
def __init__(self):
self.set_seed(RDAgentSettings.init_chat_cache_seed)

def set_seed(self, seed: int):
def __init__(self) -> None:
self.set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed)

def set_seed(self, seed: int) -> None:
random.seed(seed)

def get_next_seed(self):
def get_next_seed(self) -> int:
"""generate next random int"""
return random.randint(0, 10000)

@@ -618,11 +618,11 @@ def _create_chat_completion_inner_function(  # noqa: C901, PLR0912, PLR0915
To make retries useful, we need to enable a seed.
This seed is different from `self.chat_seed` for GPT; it is for the local cache mechanism used by RD-Agent.
"""
if seed is None and RDAgentSettings.use_auto_chat_cache_seed_gen:
if seed is None and RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen:
seed = self.cache_seed_gen.get_next_seed()

# TODO: we can add this function back to avoid so much `self.cfg.log_llm_chat_content`
if self.cfg.log_llm_chat_content:
if LLM_SETTINGS.log_llm_chat_content:
logger.info(self._build_log_messages(messages), tag="llm_messages")
# TODO: fail to use loguru adaptor due to stream response
input_content_json = json.dumps(messages)
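
For context on the snippet above: `self.chat_seed` (when set) is forwarded to the model provider, while the auto-generated cache seed only changes the locally computed cache key and never leaves the machine. The sketch below illustrates that split with invented helper names; the real hashing lives in the lines collapsed by this diff view.

```python
# Illustrative only: the two seeds serve different purposes. The chat seed is part
# of the request payload; the cache seed is mixed into the *local* hash key.
import hashlib
import json


def local_cache_key(messages: list[dict], cache_seed: int | None) -> str:
    # Only the local key depends on cache_seed; the API payload does not.
    payload = json.dumps(messages)
    return hashlib.md5(f"{payload}|{cache_seed}".encode()).hexdigest()


def api_payload(messages: list[dict], chat_seed: int | None) -> dict:
    # chat_seed (if any) is sent to the provider to make sampling more deterministic.
    body: dict = {"messages": messages}
    if chat_seed is not None:
        body["seed"] = chat_seed
    return body
```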
123 changes: 122 additions & 1 deletion test/oai/test_completion.py
@@ -5,6 +5,14 @@
from rdagent.oai.llm_utils import APIBackend


def _worker(system_prompt, user_prompt):
api = APIBackend()
return api.build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)


class TestChatCompletion(unittest.TestCase):
def test_chat_completion(self) -> None:
system_prompt = "You are a helpful assistant."
@@ -45,13 +53,126 @@ def test_chat_multi_round(self) -> None:
response2 = session.build_chat_completion(user_prompt=user_prompt_2)
assert response2 is not None

def test_chat_cache(self):
def test_chat_cache(self) -> None:
"""
Tests:
- Single process, ask the same question, enable cache
- 2 passes
- cache is not missed & the same question gets different answers.
"""
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import multiprocessing_wrapper
from rdagent.oai.llm_conf import LLM_SETTINGS

system_prompt = "You are a helpful assistant."
user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"

origin_value = (
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
)

LLM_SETTINGS.use_chat_cache = True
LLM_SETTINGS.dump_chat_cache = True

RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True

APIBackend.cache_seed_gen.set_seed(10)
response1 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
response2 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)

APIBackend.cache_seed_gen.set_seed(20)
response3 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
response4 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)

APIBackend.cache_seed_gen.set_seed(10)
response5 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)
response6 = APIBackend().build_messages_and_create_chat_completion(
system_prompt=system_prompt,
user_prompt=user_prompt,
)

# Reset, for other tests
(
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
) = origin_value

assert (
response1 != response3 and response2 != response4
), "Responses sequence should be determined by 'init_chat_cache_seed'"
assert (
response1 == response5 and response2 == response6
), "Responses sequence should be determined by 'init_chat_cache_seed'"
assert (
response1 != response2 and response3 != response4 and response5 != response6
), "Same question should get different response when use_auto_chat_cache_seed_gen=True"

def test_chat_cache_multiprocess(self) -> None:
"""
Tests:
- Multi-process, ask the same question, enable cache
- 2 passes
- cache is not missed & the same question gets different answers.
"""
from rdagent.core.conf import RD_AGENT_SETTINGS
from rdagent.core.utils import multiprocessing_wrapper
from rdagent.oai.llm_conf import LLM_SETTINGS

system_prompt = "You are a helpful assistant."
user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"

origin_value = (
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
)

LLM_SETTINGS.use_chat_cache = True
LLM_SETTINGS.dump_chat_cache = True

RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True

func_calls = [(_worker, (system_prompt, user_prompt)) for _ in range(4)]

APIBackend.cache_seed_gen.set_seed(10)
responses1 = multiprocessing_wrapper(func_calls, n=4)
APIBackend.cache_seed_gen.set_seed(20)
responses2 = multiprocessing_wrapper(func_calls, n=4)
APIBackend.cache_seed_gen.set_seed(10)
responses3 = multiprocessing_wrapper(func_calls, n=4)

# Reset, for other tests
(
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
) = origin_value
for i in range(len(func_calls)):
assert (
responses1[i] != responses2[i] and responses1[i] == responses3[i]
), "Responses sequence should be determined by 'init_chat_cache_seed'"
for j in range(i + 1, len(func_calls)):
assert (
responses1[i] != responses1[j] and responses2[i] != responses2[j]
), "Same question should get different response when use_auto_chat_cache_seed_gen=True"


if __name__ == "__main__":