diff --git a/.gitignore b/.gitignore
index c7a6e11..63f867f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,3 +131,7 @@ dmypy.json
 local/
 *ipynb
 query/
+
+
+*.bin
+*.gguf
\ No newline at end of file
diff --git a/README.md b/README.md
index 36bb9f3..e31de3b 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,52 @@ chem_model = ChemCrow(model="gpt-4-0613", temp=0.1, streaming=False)
 chem_model.run("What is the molecular weight of tylenol?")
 ```
 
+### 💻 Running with local LLMs
+
+ChemCrow also supports local LLMs, through either GPT4All or Hugging Face's [TGI](https://huggingface.co/docs/text-generation-inference/index).
+
+#### GPT4All
+
+To run using GPT4All, you will need to download one of the [supported models](https://gpt4all.io/index.html).
+
+```python
+from chemcrow.agents import ChemCrow
+
+chem_model = ChemCrow(
+    model_type='gpt4all',
+    model="./models/mistral-7b-instruct-v0.1.Q4_0.gguf",
+    temp=0.1,
+    max_tokens=100,
+    verbose=False,
+)
+```
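+
+The local model can then be queried in the same way as the OpenAI-backed agent. A minimal example, reusing the query from the start of this README:
+
+```python
+chem_model.run("What is the molecular weight of tylenol?")
+```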
+
+#### TGI
+
+The other option is Text Generation Inference (TGI), which lets you serve a model and run inference against it through an API.
+To deploy a model you will need Docker; run the server as explained [here](https://huggingface.co/docs/text-generation-inference/quicktour).
+
+```python
+from chemcrow.agents import ChemCrow
+
+agent = ChemCrow(
+    model_type='tgi',
+    model_server_url='http://server-ip-address:8080',
+    temp=0.3,
+    max_tokens=40,
+    max_iterations=3,
+).agent_executor
+```
+
+The advantages of TGI are improved efficiency and easy access to any model available on the Hugging Face Hub.
+
+
 ## ✅ Citation
 
 Bran, Andres M., et al. "ChemCrow: Augmenting large-language models with chemistry tools." arXiv preprint arXiv:2304.05376 (2023).
diff --git a/chemcrow/agents/chemcrow.py b/chemcrow/agents/chemcrow.py
index 4528887..10d0c7d 100644
--- a/chemcrow/agents/chemcrow.py
+++ b/chemcrow/agents/chemcrow.py
@@ -1,64 +1,114 @@
-from typing import Optional
-
-import langchain
+import os
 from dotenv import load_dotenv
+from typing import Optional, Dict, Literal
+import langchain
+import nest_asyncio
 from langchain import PromptTemplate, chains
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from pydantic import ValidationError
 from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
+
 from .prompts import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, REPHRASE_TEMPLATE, SUFFIX
 from .tools import make_tools
 
-def _make_llm(model, temp, api_key, streaming: bool = False):
-    if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
-        llm = langchain.chat_models.ChatOpenAI(
-            temperature=temp,
-            model_name=model,
-            request_timeout=1000,
-            streaming=streaming,
-            callbacks=[StreamingStdOutCallbackHandler()],
-            openai_api_key=api_key,
-        )
-    elif model.startswith("text-"):
-        llm = langchain.OpenAI(
-            temperature=temp,
-            model_name=model,
-            streaming=streaming,
-            callbacks=[StreamingStdOutCallbackHandler()],
-            openai_api_key=api_key,
+def _make_llm(
+    model_type: Literal["openai", "tgi", "gpt4all"],
+    model_server_url: Optional[str],
+    verbose,
+    api_key,
+    **kwargs
+):
+    if model_type == "openai":
+        load_dotenv()
+        try:
+            llm = langchain.chat_models.ChatOpenAI(
+                temperature=kwargs['temp'],
+                model_name=kwargs['model'],
+                request_timeout=1000,
+                streaming=verbose,
+                callbacks=[StreamingStdOutCallbackHandler()] if verbose else [],
+                openai_api_key=api_key
+            )
+        except ValidationError:
+            raise ValueError("Invalid OpenAI API key")
+
+    elif model_type == "tgi":
+        from langchain.llms import HuggingFaceTextGenInference
+        llm = HuggingFaceTextGenInference(
+            inference_server_url=model_server_url,
+            max_new_tokens=kwargs['max_tokens'],
+            top_k=10,
+            top_p=0.95,
+            typical_p=0.95,
+            temperature=kwargs['temp'],
+            repetition_penalty=1.03,
         )
-    else:
-        raise ValueError(f"Invalid model name: {model}")
+
+    elif model_type == "gpt4all":
+        from langchain.llms import GPT4All
+        model = kwargs['model']
+        if isinstance(model, str):
+            if os.path.exists(model) and model.endswith(".gguf"):
+                llm = GPT4All(
+                    model=model,
+                    max_tokens=kwargs['max_tokens'],
+                    temp=kwargs['temp'],
+                    verbose=False
+                )
+            else:
+                raise ValueError("Couldn't load model. Only models in .gguf format are currently supported.")
+        else:
+            raise ValueError(f"Invalid model name: {model}")
     return llm
 
+
 class ChemCrow:
     def __init__(
         self,
+        model_type: str = 'openai',
+        model_server_url: Optional[str] = None,
         tools=None,
         model="gpt-4-0613",
         tools_model="gpt-3.5-turbo-0613",
         temp=0.1,
+        max_tokens: int = 4096,
         max_iterations=40,
         verbose=True,
         streaming: bool = True,
-        openai_api_key: Optional[str] = None,
-        api_keys: dict = {},
+        openai_api_key: str = '',
+        api_keys: Dict[str, str] = {},
     ):
         """Initialize ChemCrow agent."""
 
-        load_dotenv()
-        try:
-            self.llm = _make_llm(model, temp, openai_api_key, streaming)
-        except ValidationError:
-            raise ValueError("Invalid OpenAI API key")
+        self.llm = _make_llm(
+            model_type,
+            model_server_url,
+            verbose,
+            openai_api_key,
+            model=model,
+            max_tokens=max_tokens,
+            temp=temp
+        )
 
         if tools is None:
             api_keys["OPENAI_API_KEY"] = openai_api_key
-            tools_llm = _make_llm(tools_model, temp, openai_api_key, streaming)
-            tools = make_tools(tools_llm, api_keys=api_keys, verbose=verbose)
+            tools_llm = _make_llm(
+                model_type,
+                model_server_url,
+                verbose,
+                openai_api_key,
+                model=model,
+                max_tokens=max_tokens,
+                temp=temp
+            )
+            tools = make_tools(
+                tools_llm,
+                api_keys=api_keys,
+                verbose=verbose
+            )
 
         # Initialize agent
         self.agent_executor = RetryAgentExecutor.from_agent_and_tools(