Updated caching with max batching and added api_calls under stats
nikhil authored and committed on Nov 4, 2024
1 parent b7190e8 commit c6e58f3
Showing 3 changed files with 42 additions and 51 deletions.
54 changes: 27 additions & 27 deletions .github/tests/lm_tests.py
@@ -14,7 +14,7 @@
 lotus.logger.setLevel("DEBUG")
 
 # Environment flags to enable/disable tests
-ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true"
+ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "true").lower() == "true"
 ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true"
 
 MODEL_NAME_TO_ENABLED = {
@@ -77,18 +77,18 @@ def test_filter_caching(setup_models, model):
     user_instruction = "{Text} is a positive sentiment"
 
     # First call - should make API calls
-    initial_api_calls = gpt_4o_mini.api_calls
+    initial_api_calls = lm.stats.total_usage.api_calls
     filtered_df1 = df.sem_filter(user_instruction)
-    first_call_api_count = gpt_4o_mini.api_calls - initial_api_calls
+    first_call_api_count = lm.stats.total_usage.api_calls
 
     # Second call - should use cache
     filtered_df2 = df.sem_filter(user_instruction)
-    second_call_api_count = gpt_4o_mini.api_calls - (initial_api_calls + first_call_api_count)
+    second_call_api_count = lm.stats.total_usage.api_calls
 
     # Verify results are the same
     assert filtered_df1.equals(filtered_df2)
 
-    assert gpt_4o_mini.api_calls == 0
+    assert lm.stats.total_usage.api_calls == 0
     assert initial_api_calls == 0
     # Verify first call made API calls
     assert first_call_api_count == 0, "First call should make API calls"
@@ -97,34 +97,34 @@ def test_filter_caching(setup_models, model):
     assert second_call_api_count == 0, "Second call should use cache (no new API calls)"
 
 
-def test_filter_caching(setup_models):
-    gpt_4o_mini, _ = setup_models
-    lotus.settings.configure(lm=gpt_4o_mini)
+# def test_filter_caching(setup_models):
+# gpt_4o_mini, _ = setup_models
+# lotus.settings.configure(lm=gpt_4o_mini)
 
-    # Test filter operation on a dataframe
-    data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
-    df = pd.DataFrame(data)
-    user_instruction = "{Text} is a positive sentiment"
+# # Test filter operation on a dataframe
+# data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
+# df = pd.DataFrame(data)
+# user_instruction = "{Text} is a positive sentiment"
 
-    # First call - should make API calls
-    initial_api_calls = gpt_4o_mini.api_calls
-    filtered_df1 = df.sem_filter(user_instruction)
-    first_call_api_count = gpt_4o_mini.api_calls - initial_api_calls
+# # First call - should make API calls
+# initial_api_calls = gpt_4o_mini.api_calls
+# filtered_df1 = df.sem_filter(user_instruction)
+# first_call_api_count = gpt_4o_mini.api_calls - initial_api_calls
 
-    # Second call - should use cache
-    filtered_df2 = df.sem_filter(user_instruction)
-    second_call_api_count = gpt_4o_mini.api_calls - (initial_api_calls + first_call_api_count)
+# # Second call - should use cache
+# filtered_df2 = df.sem_filter(user_instruction)
+# second_call_api_count = gpt_4o_mini.api_calls - (initial_api_calls + first_call_api_count)
 
-    # Verify results are the same
-    assert filtered_df1.equals(filtered_df2)
+# # Verify results are the same
+# assert filtered_df1.equals(filtered_df2)
 
-    assert gpt_4o_mini.api_calls == 0
-    assert initial_api_calls == 0
-    # Verify first call made API calls
-    assert first_call_api_count == 0, "First call should make API calls"
+# assert gpt_4o_mini.api_calls == 0
+# assert initial_api_calls == 0
+# # Verify first call made API calls
+# assert first_call_api_count == 0, "First call should make API calls"
 
-    # Verify second call used cache (no new API calls)
-    assert second_call_api_count == 0, "Second call should use cache (no new API calls)"
+# # Verify second call used cache (no new API calls)
+# assert second_call_api_count == 0, "Second call should use cache (no new API calls)"
 
 
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
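A note on the reworked assertions above: API-call counts are now read from `lm.stats.total_usage` rather than from a counter on the model object. Below is a minimal standalone sketch of that flow, assuming the `LM` constructor accepts a model name and cache flag as in this revision and that a cache-enabled model is configured the way the `setup_models` fixture does; the exact counter values depend on whether the committed `.lotus_cache/cache.db` already covers these prompts.

```python
import pandas as pd

import lotus
from lotus.models import LM

# Assumed setup, mirroring the fixture in lm_tests.py; adjust model/cache args as needed.
lm = LM(model="gpt-4o-mini", cache=True)
lotus.settings.configure(lm=lm)

df = pd.DataFrame({"Text": ["I am really excited to go to class today!", "I am very sad"]})
user_instruction = "{Text} is a positive sentiment"

filtered_df1 = df.sem_filter(user_instruction)
first_call_api_count = lm.stats.total_usage.api_calls  # api_calls recorded after the first invocation

filtered_df2 = df.sem_filter(user_instruction)
second_call_api_count = lm.stats.total_usage.api_calls  # per the test above, 0 when the cache is hit

assert filtered_df1.equals(filtered_df2)
print(first_call_api_count, second_call_api_count)
```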
Binary file modified .lotus_cache/cache.db
Binary file not shown.
39 changes: 15 additions & 24 deletions lotus/models/lm.py
@@ -40,21 +40,11 @@ def __init__(
         self.tokenizer = tokenizer
         self.cache = cache
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
-
-        self.stats: LMStats = LMStats()
-
-    def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
-        all_kwargs = {**self.kwargs, **kwargs}
-        cache = kwargs.pop("cache", self.cache)
-
-        # Set top_logprobs if logprobs requested
-        if all_kwargs.get("logprobs", False):
-            all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10)
         self.history = []
         self.cache = cache
         logging.basicConfig(level=logging.INFO)
         self.logger = logging.getLogger(__name__)
 
         self.stats: LMStats = LMStats()
 
     def _messages_to_cache_key(self, messages):
         if isinstance(messages[0], list):
             return tuple(tuple(frozenset(m.items()) for m in msg) for msg in messages)
@@ -66,7 +56,7 @@ def _cache_key_to_messages(self, messages_tuple):
         return [dict(m) for m in messages_tuple]
 
     @lru_cache(maxsize=128)
-    def _cached_completion(self, messages_tuple, **kwargs):
+    def _cached_completion(self, messages_tuple, **all_kwargs):
         all_responses: list[ModelResponse] = []
         messages = self._cache_key_to_messages(messages_tuple)
         for i in range(0, len(messages), self.max_batch_size):
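Because `functools.lru_cache` requires hashable arguments, the code above converts message dicts into nested tuples of frozensets before calling `_cached_completion`, and back into dicts inside it. A self-contained sketch of that round-trip pattern (illustrative names, not the library's API; single-conversation case only):

```python
from functools import lru_cache

def messages_to_key(messages: list[dict[str, str]]) -> tuple:
    # frozenset(d.items()) is hashable, so the nested tuple can serve as an lru_cache key
    return tuple(frozenset(m.items()) for m in messages)

def key_to_messages(key: tuple) -> list[dict[str, str]]:
    # rebuild the original message dicts from the hashable key
    return [dict(m) for m in key]

@lru_cache(maxsize=128)
def cached_completion(key: tuple) -> str:
    messages = key_to_messages(key)
    # stand-in for the real batched LLM call; just echo the last user turn
    return f"echo: {messages[-1]['content']}"

msgs = [{"role": "user", "content": "hello"}]
first = cached_completion(messages_to_key(msgs))   # miss: computes the response
second = cached_completion(messages_to_key(msgs))  # hit: served from the in-memory cache
assert first == second
print(cached_completion.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
```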
@@ -94,19 +84,20 @@ def _batch_complete(self, messages, **kwargs):
     def __call__(self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any]
     ) -> LMOutput:
         kwargs = {**self.kwargs, **kwargs}
-        kwargs_for_batch = self._format_batch_kwargs(kwargs)
         cache = kwargs.pop("cache", self.cache)
-        self.stats.total_usage.api_calls = 0
+        self.stats.TotalUsage.api_calls = 0
 
+        all_kwargs = {**self.kwargs, **kwargs}
 
-        if kwargs.get("logprobs", False):
-            kwargs["top_logprobs"] = kwargs.get("top_logprobs", 10)
+        # Set top_logprobs if logprobs requested
+        if all_kwargs.get("logprobs", False):
+            all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10)
 
         all_responses: list[ModelResponse] = []
 
         if cache:
             messages_tuple = self._messages_to_cache_key(messages)
-            all_responses = self._cached_completion(messages_tuple, **kwargs_for_batch)
+            all_responses = self._cached_completion(messages_tuple, **all_kwargs)
         else:
             for i in range(0, len(messages), self.max_batch_size):
                 batch = messages[i : i + self.max_batch_size]
@@ -117,13 +108,13 @@ def __call__(self, messages: list[dict[str, str]] | list[list[dict[str, str]]],
                     **all_kwargs,  # type: ignore
                 )
                 all_responses.extend(responses)
-                self.stats.total_usage.api_calls += 1
+                self.stats.TotalUsage.api_calls += 1
 
-        self.logger.info(f"Making API call #{self.stats.total_usage.api_calls}")
-        outputs = [self._get_top_choice(resp) for resp in responses]
-        logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs.get("logprobs") else None
-
-
+        self.logger.info(f"Making API call #{self.stats.TotalUsage.api_calls}")
+        outputs = [self._get_top_choice(resp) for resp in all_responses]
+        logprobs = (
+            [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
+        )
 
         # throw errors, if any
         for resp in all_responses:
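On the non-cached path above, `__call__` walks the prompts in `max_batch_size` chunks and increments `api_calls` once per batch after each completion call. A rough standalone sketch of that accounting, with a dummy transport in place of the real batch completion; the `Stats`/`TotalUsage` layout here mirrors the attribute access in the test file and is an assumption, not the library's exact `LMStats` definition.

```python
from dataclasses import dataclass, field

@dataclass
class TotalUsage:
    api_calls: int = 0

@dataclass
class Stats:
    total_usage: TotalUsage = field(default_factory=TotalUsage)

def fake_batch_completion(batch: list[list[dict]]) -> list[str]:
    # stand-in for the real batched completion call: one "API call" per batch
    return [f"response to {msgs[-1]['content']}" for msgs in batch]

def call_in_batches(all_messages: list[list[dict]], max_batch_size: int, stats: Stats) -> list[str]:
    stats.total_usage.api_calls = 0  # reset per call, as the diff does
    responses: list[str] = []
    for i in range(0, len(all_messages), max_batch_size):
        batch = all_messages[i : i + max_batch_size]
        responses.extend(fake_batch_completion(batch))
        stats.total_usage.api_calls += 1  # count one API call per batch sent
    return responses

stats = Stats()
prompts = [[{"role": "user", "content": f"q{i}"}] for i in range(5)]
out = call_in_batches(prompts, max_batch_size=2, stats=stats)
print(len(out), stats.total_usage.api_calls)  # 5 responses, 3 batched calls
```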
