Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat update router training #87

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ or your own code where you want to use the results from optillm. You can use it

| Plugin | Slug | Description |
| ----------------------- | ------------------ | ---------------------------------------------------------------------------------------------- |
| Router | `router` | Uses the [optillm-bert-uncased](https://huggingface.co/codelion/optillm-bert-uncased) model to route requests to different approaches based on the user prompt |
| Memory | `memory` | Implements a short term memory layer, enables you to use unbounded context length with any LLM |
| Privacy | `privacy` | Anonymize PII data in request and deanonymize it back to original value in response |
| Read URLs | `readurls` | Reads all URLs found in the request, fetches the content at the URL and adds it to the context |
Expand Down
5 changes: 2 additions & 3 deletions optillm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import re
from concurrent.futures import ThreadPoolExecutor

# Import the LiteLLM wrapper
from optillm.litellm_wrapper import LiteLLMWrapper

# Import approach modules
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
Expand Down Expand Up @@ -74,6 +71,8 @@ def get_config():
azure_ad_token_provider=token_provider
)
else:
# Import the LiteLLM wrapper
from optillm.litellm_wrapper import LiteLLMWrapper
default_client = LiteLLMWrapper()
return default_client, API_KEY

Expand Down
130 changes: 130 additions & 0 deletions scripts/gen_optillm_ground_truth_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os
import json
import argparse
import asyncio
from tqdm import tqdm
from datasets import load_dataset
from openai import AsyncOpenAI
from typing import List, Dict, Any, Tuple
import random

# Approach slugs evaluated for every prompt. "none" is the baseline that calls
# the model directly; every other slug is prepended to the model name when the
# request is issued (see generate_response).
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]

# (config name, split name) pairs; each pair is passed to
# load_dataset("MixEval/MixEval", config, split=split_type) in generate_dataset.
DATASET_CONFIGS = [
    ("MixEval", "free_form"),
    ("MixEval", "multiple_choice"),
    ("MixEval_Hard", "free_form"),
    ("MixEval_Hard", "multiple_choice")
]

def construct_prompt(sample: Dict[str, Any], split_type: str) -> str:
    """Build the text prompt sent to the model for one dataset sample.

    Args:
        sample: Dataset row. Must contain ``"prompt"``; ``"context"`` is
            optional, and ``"options"`` is required when ``split_type`` is
            ``"multiple_choice"``.
        split_type: ``"multiple_choice"`` renders a numbered option list;
            any other value produces the free-form prompt.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``"prompt"`` (or ``"options"`` for multiple choice) is
            missing from ``sample``.
    """
    # A row can carry an explicit ``None`` context. ``.get(key, "")`` only
    # covers a *missing* key, so ``None`` would otherwise be rendered as the
    # literal text "None" inside the prompt.
    context = sample.get("context") or ""
    question = sample["prompt"]

    if split_type == "multiple_choice":
        numbered = "\n".join(
            f"{number}. {option}"
            for number, option in enumerate(sample["options"], start=1)
        )
        options_text = "\nOptions:\n" + numbered
        return (
            f"Context: {context}\n\nQuestion: {question}{options_text}"
            "\n\nProvide the correct answer from the options above."
        )

    return f"Context: {context}\n\nQuestion: {question}\n\nProvide your answer."

def is_correct_response(response: str, targets: List[str]) -> bool:
    """Return True when the response exactly matches one of the targets.

    Matching ignores case and surrounding whitespace; it is an exact
    comparison, not a substring or fuzzy match.

    Args:
        response: Model output text. ``None`` (no text content returned) is
            treated as an incorrect answer instead of raising.
        targets: Accepted answers. Non-string entries (for example integer
            answer indices) are compared by their string form; ``None``
            entries are ignored.

    Returns:
        True if any target equals the normalized response, else False.
    """
    if response is None:
        return False

    normalized = str(response).strip().lower()
    return any(
        str(target).strip().lower() == normalized
        for target in targets
        if target is not None
    )

async def generate_response(prompt: str, approach: str) -> Dict[str, Any]:
    """Request one chat completion for ``prompt`` using ``approach``.

    The ``"none"`` approach uses an ``AsyncOpenAI`` client constructed with no
    arguments and the plain ``gpt-4o-mini`` model. Any other approach is sent
    to ``http://localhost:8000/v1`` (presumably a locally running optillm
    proxy; confirm against deployment) with the approach slug prefixed to the
    model name.

    Args:
        prompt: User message content.
        approach: Approach slug, or ``"none"`` for the baseline.

    Returns:
        A dict with ``"content"`` (text of the first choice) and ``"tokens"``
        (completion token count reported in the response usage).
    """
    if approach != "none":
        client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1")
        model_name = f"{approach}-gpt-4o-mini"
    else:
        client = AsyncOpenAI()
        model_name = "gpt-4o-mini"

    completion = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )

    first_choice = completion.choices[0]
    return {
        "content": first_choice.message.content,
        "tokens": completion.usage.completion_tokens,
    }

def rank_responses(responses: List[Dict[str, Any]], targets: List[str]) -> List[int]:
    """Order response indices from best to worst.

    Correct responses come before incorrect ones; within each group the
    response with fewer tokens comes first. Ties keep their original order
    because the sort is stable.

    Args:
        responses: Dicts with at least ``"content"`` and ``"tokens"`` keys.
        targets: Accepted answers used to judge correctness.

    Returns:
        Indices into ``responses``, best first.
    """
    # Judge each response exactly once, in input order.
    correctness = [is_correct_response(item["content"], targets) for item in responses]

    def sort_key(position: int):
        # ``False`` sorts before ``True``, so negate to put correct answers first.
        return (not correctness[position], responses[position]["tokens"])

    return sorted(range(len(responses)), key=sort_key)

async def process_sample(sample: Dict[str, Any], split_type: str) -> Dict[str, Any]:
    """Run every approach on one sample and rank the outcomes.

    Approaches are queried one after another, in ``APPROACHES`` order. Each
    result dict then receives a zero-based ``"rank"`` (0 is best) derived from
    ``rank_responses`` against ``sample["target"]``.

    Args:
        sample: Dataset row; must provide ``"target"`` plus whatever
            ``construct_prompt`` needs.
        split_type: Split name forwarded to ``construct_prompt``.

    Returns:
        A dict with the ``"prompt"`` text and the per-approach ``"results"``.
    """
    prompt = construct_prompt(sample, split_type)

    # Sequential on purpose: each request is awaited before the next starts.
    results = [
        {"approach": name, **(await generate_response(prompt, name))}
        for name in APPROACHES
    ]

    ordering = rank_responses(results, sample["target"])
    for position, result_index in enumerate(ordering):
        results[result_index]["rank"] = position

    return {"prompt": prompt, "results": results}

async def generate_dataset(num_samples: int, output_file: str):
    """Generate the ground-truth dataset and write it as JSON Lines.

    For every entry in ``DATASET_CONFIGS`` the matching MixEval split is
    loaded, the first samples are processed with ``process_sample``, and each
    result is written as one JSON object per line.

    Args:
        num_samples: Total sample budget, divided evenly across the
            configurations (at least one sample per configuration).
        output_file: Path of the JSONL file to create or overwrite.

    A failure while processing a single sample is reported on stdout and the
    sample is skipped; the run continues with the next one.
    """
    # Loop-invariant: every configuration receives the same share.
    samples_per_config = max(1, num_samples // len(DATASET_CONFIGS))

    with open(output_file, "w", encoding="utf-8") as f:
        for config, split_type in DATASET_CONFIGS:
            print(f"Processing {config} - {split_type}")
            dataset = load_dataset("MixEval/MixEval", config, split=split_type)

            # Never ask select() for indices past the end of a short split;
            # that call sits outside the per-sample try and would abort the
            # whole run.
            sample_count = min(samples_per_config, len(dataset))

            for sample in tqdm(dataset.select(range(sample_count)),
                               total=sample_count,
                               desc=f"{config}-{split_type}"):
                try:
                    result = await process_sample(sample, split_type)
                    f.write(json.dumps(result) + "\n")
                except Exception as e:
                    # Deliberate best effort: one bad sample must not stop
                    # the remaining samples from being generated.
                    print(f"Error processing sample: {str(e)}")

def main():
    """Parse command-line options and run the dataset generation.

    Options:
        --num_samples: total number of samples, split across configurations
            (default 100).
        --output_file: destination JSONL path
            (default ``optillm_ground_truth_dataset.jsonl``).
    """
    cli = argparse.ArgumentParser(description="Generate OptILM Ground Truth dataset")
    cli.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Total number of samples to process (divided among configurations)",
    )
    cli.add_argument(
        "--output_file",
        type=str,
        default="optillm_ground_truth_dataset.jsonl",
        help="Output file path",
    )
    options = cli.parse_args()

    destination = options.output_file
    asyncio.run(generate_dataset(options.num_samples, destination))
    print(f"Dataset generated and saved to {destination}")


if __name__ == "__main__":
    main()
54 changes: 49 additions & 5 deletions scripts/train_optillm_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __getitem__(self, idx):
}

def load_and_preprocess_data(tokenizer):
dataset = load_dataset('json', data_files='optillm_dataset.jsonl')
dataset = load_dataset('json', data_files='optillm_combined_dataset.jsonl')

data_items = []

Expand Down Expand Up @@ -290,11 +290,54 @@ def main(args):
best_model.eval()

test_prompts = [
# Linear Programming (likely MCTS or Z3)
"Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0",
# Graph Theory (likely MCTS or RTO)
"Find the shortest path between nodes A and B in the given graph",
# Recursive Problem (likely MOA or COT)
"Solve the Tower of Hanoi problem with 4 disks",
# Number Theory (likely NONE or Z3)
"Determine if the given number is prime",
"Find all possible combinations of coins that sum up to $1"
# Combinatorics (likely MCTS or BON)
"Find all possible combinations of coins that sum up to $1",
# Symbolic Mathematics (likely Z3 or LEAP)
"Solve the equation: 2x^3 - 5x^2 + 3x - 7 = 0",
# Natural Language Processing (likely PVG or SELF_CONSISTENCY)
"Summarize the main points of the given article in three sentences",
# Computer Vision (likely RSTAR or PVG)
"Describe the contents of the image, including any text present",
# Game Theory (likely MCTS or BON)
"Find the Nash equilibrium for the prisoner's dilemma game",
# Constraint Satisfaction (likely Z3 or PLANSEARCH)
"Solve the Sudoku puzzle given the following initial configuration",
# Optimization (likely MCTS or RSTAR)
"Find the optimal route for a salesperson visiting 10 cities",
# Logical Reasoning (likely COT_REFLECTION or SELF_CONSISTENCY)
"If all A are B, and some B are C, what can we conclude about A and C?",
# Time Series Analysis (likely RSTAR or PVG)
"Predict the stock price for the next week given the past year's data",
# Robotics (likely MCTS or RTO)
"Plan a path for a robot to navigate through a room with obstacles",
# Natural Language Understanding (likely PVG or LEAP)
"Identify the sentiment and main topics in the following customer review",
# Theorem Proving (likely Z3 or COT_REFLECTION)
"Prove that the square root of 2 is irrational",
# Reinforcement Learning (likely MCTS or RSTAR)
"Design a policy for an agent to maximize its score in a given game environment",
# Information Retrieval (likely PVG or SELF_CONSISTENCY)
"Find the most relevant documents in the corpus for the given query",
# Cryptography (likely Z3 or LEAP)
"Decrypt the following message encrypted with a simple substitution cipher",
# Quantum Computing (likely NONE or Z3)
"Simulate a quantum circuit with 3 qubits and measure the output",
# Computer Graphics (likely RSTAR or PVG)
"Generate a 3D model of a house based on the given floor plan",
# Bioinformatics (likely Z3 or LEAP)
"Find potential binding sites for a given protein sequence in a DNA strand",
# Automated Reasoning (likely COT_REFLECTION or Z3)
"Given a set of logical statements, determine if the conclusion follows",
# Natural Language Generation (likely PVG or SELF_CONSISTENCY)
"Write a short story in the style of Edgar Allan Poe about a haunted lighthouse"
]

effort_levels = [0.0, 0.2, 0.5, 0.8, 1.0]
Expand All @@ -310,13 +353,14 @@ def main(args):
parser = argparse.ArgumentParser(description="Train OptILM classifier")
parser.add_argument("--model_name", type=str, default="google-bert/bert-large-uncased", help="Pretrained model name")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
parser.add_argument("--learning_rate", type=float, default=1e-6, help="Learning rate")
parser.add_argument("--num_epochs", type=int, default=10, help="Maximum number of training epochs")
parser.add_argument("--learning_rate", type=float, default=5e-7, help="Learning rate")
parser.add_argument("--num_epochs", type=int, default=20, help="Maximum number of training epochs")
parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
parser.add_argument("--hub_model_id", type=str, help="Model ID for Hugging Face Hub")
parser.add_argument("--k_folds", type=int, default=5, help="Number of folds for cross-validation")
parser.add_argument("--patience", type=int, default=3, help="Number of epochs to wait for improvement before early stopping")
parser.add_argument("--clip_value", type=float, default=1.0, help="Gradient clipping value")

args = parser.parse_args()
main(args)
main(args)

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="optillm",
version="0.0.8",
version="0.0.9",
packages=find_packages(),
py_modules=['optillm'],
package_data={
Expand Down