diff --git a/.env.example b/.env.example index 9d8af6d..ed2c0ee 100644 --- a/.env.example +++ b/.env.example @@ -4,12 +4,21 @@ OPENAI_API_KEY="sk-..." # https://console.anthropic.com/account/keys ANTHROPIC_API_KEY="sk-ant-..." -DEFAULT_SERVICE="openai/chat/gpt-3.5-turbo" +# when things aren't specified, these defaults will kick in: +#DEFAULT_SERVICE="anthropic/complete/claude-v1.3-100k" +#DEFAULT_SERVICE="openai/chat/gpt-4" +#DEFAULT_SERVICE="openai/chat/gpt-3.5-turbo" + +# WARNING/TODO/FIXME: do not set the options below when running pytest +# +#DEFAULT_TEMPERATURE=0.6 +#DEFAULT_MAX_TOKENS=1337 +#DEFAULT_TOP_K=-1 +#DEFAULT_TOP_P=-1 # your Google Cloud Project ID or number # environment default used is not set -GOOGLE_PROJECT="project-name-id" +#GOOGLE_PROJECT="project-name-id" # the Vertex AI region you will use # defaults to us-central1 -GOOGLE_LOCATION="us-central1" - +#GOOGLE_LOCATION="us-central1" diff --git a/README.md b/README.md index 38734f1..9c241d1 100644 --- a/README.md +++ b/README.md @@ -468,6 +468,12 @@ We'd love your help in making Prr even better! To contribute, please follow thes 5. Push the branch to your fork 6. Create a new Pull Request +### Running unit tests + +```sh +$ pytest +``` + ## License **prr** - Prompt Runner is released under the [MIT License](/LICENSE). diff --git a/examples/code/html_boilerplate b/examples/code/_html_boilerplate similarity index 100% rename from examples/code/html_boilerplate rename to examples/code/_html_boilerplate diff --git a/examples/code/html_boilerplate.yaml b/examples/code/html_boilerplate.yaml index e9fe6ff..74aa653 100644 --- a/examples/code/html_boilerplate.yaml +++ b/examples/code/html_boilerplate.yaml @@ -1,6 +1,6 @@ version: 1 prompt: - content_file: 'html_boilerplate' + content_file: '_html_boilerplate' services: gpt4_temp7: model: 'openai/chat/gpt-4' diff --git a/examples/configured/chihuahua.yaml b/examples/configured/chihuahua.yaml index 3d34b43..d09957a 100644 --- a/examples/configured/chihuahua.yaml +++ b/examples/configured/chihuahua.yaml @@ -52,4 +52,4 @@ expect: min_response_length: 100 max_response_length: 200 match: - name: /independent/i \ No newline at end of file + name: /independent/i diff --git a/examples/shebang/write_tests b/examples/shebang/write_tests new file mode 100755 index 0000000..da217b9 --- /dev/null +++ b/examples/shebang/write_tests @@ -0,0 +1,5 @@ +#!/usr/bin/env prr script + +Write tests using pytest for the code below: + +{% include prompt_args %} diff --git a/prr/__main__.py b/prr/__main__.py index 1200366..da5f269 100755 --- a/prr/__main__.py +++ b/prr/__main__.py @@ -4,10 +4,10 @@ import os import sys -from prr.config import load_config - -from .utils.run import RunPromptCommand -from .utils.watch import WatchPromptCommand +from prr.commands.run import RunPromptCommand +from prr.commands.watch import WatchPromptCommand +from prr.prompt.model_options import ModelOptions +from prr.utils.config import load_config config = load_config() @@ -38,13 +38,43 @@ def add_common_args(_parser): action="store_true", default=False, ) + _parser.add_argument( "--service", "-s", - help="Service to use if none is configured (defaults to DEFAULT_SERVICE environment variable)", - default=config["DEFAULT_SERVICE"], + help="Service to use if none is configured (defaults to DEFAULT_SERVICE)", + default=config.get("DEFAULT_SERVICE"), type=str, ) + + _parser.add_argument( + "--temperature", + "-t", + help="Temperature (defaults to DEFAULT_TEMPERATURE)", + type=float, + ) + + _parser.add_argument( + "--max_tokens", + 
"-mt", + help="Max tokens to use (defaults to DEFAULT_MAX_TOKENS)", + type=int, + ) + + _parser.add_argument( + "--top_p", + "-tp", + help="Sets a cumulative probability threshold for selecting candidate tokens, where only tokens with a cumulative probability higher than the threshold are considered, allowing for flexible control over the diversity of the generated output (defaults to DEFAULT_TOP_P).", + type=int, + ) + + _parser.add_argument( + "--top_k", + "-tk", + help="Determines the number of top-scoring candidate tokens to consider at each decoding step, effectively limiting the diversity of the generated output (defaults to DEFAULT_TOP_K)", + type=int, + ) + _parser.add_argument( "--quiet", "-q", diff --git a/prr/utils/run.py b/prr/commands/run.py similarity index 72% rename from prr/utils/run.py rename to prr/commands/run.py index 0f464fa..e3228c4 100755 --- a/prr/utils/run.py +++ b/prr/commands/run.py @@ -8,8 +8,8 @@ from rich.console import Console from rich.panel import Panel -# from prr.config import config from prr.prompt import Prompt +from prr.prompt.prompt_loader import PromptConfigLoader from prr.runner import Runner console = Console(log_time=False, log_path=False) @@ -19,7 +19,7 @@ class RunPromptCommand: def __init__(self, args, prompt_args=None): self.args = args self.prompt_args = prompt_args - self.prompt = None + self.prompt_config = None if self.args["quiet"]: self.console = Console(file=StringIO()) @@ -27,7 +27,7 @@ def __init__(self, args, prompt_args=None): self.console = Console(log_time=False, log_path=False) self.load_prompt_for_path() - self.runner = Runner(self.prompt) + self.runner = Runner(self.prompt_config) def print_run_results(self, result, run_save_directory): request = result.request @@ -57,11 +57,11 @@ def print_run_results(self, result, run_save_directory): Panel("[green]" + response.response_content.strip() + "[/green]") ) - completion = f"[blue]Completion length[/blue]: {len(response.response_content)} bytes" - tokens_used = f"[blue]Tokens used[/blue]: {response.tokens_used()}" - elapsed_time = ( - f"[blue]Elapsed time[/blue]: {round(result.elapsed_time, 2)}s" - ) + completion = f"[blue]Completion length[/blue]: {len(response.response_content)} bytes" + tokens_used = f"[blue]Tokens used[/blue]: {response.tokens_used()}" + elapsed_time = ( + f"[blue]Elapsed time[/blue]: {round(result.elapsed_time, 2)}s" + ) self.console.log(f"{completion} {tokens_used} {elapsed_time}") @@ -69,7 +69,10 @@ def print_run_results(self, result, run_save_directory): self.console.log(f"💾 {run_save_directory}") def run_prompt_on_service(self, service_name, save=False): - service_config = self.prompt.config_for_service(service_name) + # TODO/FIXME: doing all this here just to get the actual options + # calculated after command line, defaults, config, etc + service_config = self.prompt_config.service_with_name(service_name) + service_config.process_option_overrides(self.args) options = service_config.options with self.console.status( @@ -82,24 +85,22 @@ def run_prompt_on_service(self, service_name, save=False): status.update(status="running model", spinner="dots8Bit") - result, run_save_directory = self.runner.run_service(service_name, save) + result, run_save_directory = self.runner.run_service( + service_name, self.args, save + ) self.print_run_results(result, run_save_directory) def load_prompt_for_path(self): prompt_path = self.args["prompt_path"] - if not os.path.exists(prompt_path) or not os.access(prompt_path, os.R_OK): - self.console.log( - f":x: Prompt file 
{prompt_path} is not accessible, giving up." - ) - exit(-1) - self.console.log(f":magnifying_glass_tilted_left: Reading {prompt_path}") - self.prompt = Prompt(prompt_path, self.prompt_args) + + loader = PromptConfigLoader() + self.prompt_config = loader.load_from_path(prompt_path) def run_prompt(self): - services_to_run = self.prompt.configured_service_names() + services_to_run = self.prompt_config.configured_services() if services_to_run == []: if self.args["service"]: @@ -109,14 +110,14 @@ ) else: self.console.log( - f":x: No services configured for prompt {self.args['prompt_path']}, nor given in command-line. Not even in .env!" + f":x: No services configured for prompt {self.args['prompt_path']} in ~/.prr_rc, nor given on the command line." ) exit(-1) else: self.console.log(f":racing_car: Running services: {services_to_run}") if not self.args["abbrev"]: - self.console.log(Panel(self.prompt.text())) + self.console.log(Panel(self.prompt_config.template_text())) - for service_name in services_to_run: - self.run_prompt_on_service(service_name, self.args["log"]) + for service_name in services_to_run: + self.run_prompt_on_service(service_name, self.args["log"]) diff --git a/prr/utils/watch.py b/prr/commands/watch.py similarity index 89% rename from prr/utils/watch.py rename to prr/commands/watch.py index 5b9457c..b108290 100755 --- a/prr/utils/watch.py +++ b/prr/commands/watch.py @@ -4,8 +4,9 @@ import os import time +from prr.commands.run import RunPromptCommand from prr.prompt import Prompt -from prr.utils.run import RunPromptCommand +from prr.prompt.prompt_loader import PromptConfigLoader def timestamp_for_file(path): @@ -34,9 +35,10 @@ def update_timestamps(self, ready_timestamps=None): self.file_timestamps = self.current_timestamps() def setup_files_to_monitor(self): - prompt = Prompt(self.args["prompt_path"], self.prompt_args) - self.files = [prompt.path] - self.files.extend(prompt.dependency_files) + loader = PromptConfigLoader() + prompt_config = loader.load_from_path(self.args["prompt_path"]) + + self.files = loader.file_dependencies self.update_timestamps() def files_changed(self): diff --git a/prr/options.py b/prr/options.py deleted file mode 100644 index ae6a0ed..0000000 --- a/prr/options.py +++ /dev/null @@ -1,34 +0,0 @@ -DEFAULT_OPTIONS = {"temperature": 1.0, "top_k": -1, "top_p": -1, "max_tokens": 4000} - -ALLOWED_OPTIONS = DEFAULT_OPTIONS.keys() - - -class ModelOptions: - def __init__(self, options={}): - self.options_set = [] - self.update_options(DEFAULT_OPTIONS) - self.update_options(options) - - def update_options(self, options): - for key in options.keys(): - if key in ALLOWED_OPTIONS: - if key not in self.options_set: - self.options_set.append(key) - setattr(self, key, options[key]) - - def description(self): - return " ".join([f"{key}={self.option(key)}" for key in self.options_set]) - - def option(self, key): - return getattr(self, key) - - def __repr__(self): - return self.description() - - def to_dict(self): - dict = {} - - for key in self.options_set: - dict[key] = self.option(key) - - return dict diff --git a/prr/prompt.py b/prr/prompt.py deleted file mode 100644 index f9d3a5b..0000000 --- a/prr/prompt.py +++ /dev/null @@ -1,241 +0,0 @@ -import os - -import jinja2 -import yaml -from jinja2 import meta - -from .service_config import ServiceConfig - - -# parse something like: -# -# services: -# gpt35crazy: -# model: 'openai/chat/gpt-3.5-turbo' -# options: -# temperature: 0.99 -# claudev1smart: -# model: 'anthropic/complete/claude-v1' -# options: -# 
temperature: 0 -# options: -# temperature: 0.7 -# max_tokens: 64 -def parse_specifically_configured_services(services_config): - options = services_config.get("options") - services_config.pop("options") - - service_names = services_config.keys() - - services = {} - - for service_name in service_names: - service_config = services_config[service_name] - model = service_config["model"] - service_config.pop("model") - - merged_options = options.copy() - - if service_config["options"]: - service_options = service_config["options"] - merged_options.update(service_options) - - services[service_name] = ServiceConfig(model, merged_options) - - return services - - -# parse something like: -# -# services: -# models: -# - 'openai/chat/gpt-3.5-turbo' -# - 'anthropic/complete/claude-v1' -# options: -# temperature: 0.7 -# max_tokens: 100 -# top_p: 1.0 -# top_k: 40 -def parse_generally_configured_services(services_config): - options = services_config.get("options") - models = services_config.get("models") - - services = {} - - for model in models: - services[model] = ServiceConfig(model, options) - - return services - - -def parse_config_into_services(services_config): - if services_config.get("models"): - return parse_generally_configured_services(services_config) - else: - return parse_specifically_configured_services(services_config) - - -class Prompt: - def __init__(self, path, args=None): - self.path = None - self.messages = None - self.template = None - self.services = {} - self.args = args - # TODO/FIXME: should also include jinja includes - self.dependency_files = [] - - template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(path)) - self.template_env = jinja2.Environment(loader=template_loader) - - root, extension = os.path.splitext(path) - - if extension == ".yaml": - self.load_yaml_file(path) - else: - self.load_text_file(path) - - def configured_service_names(self): - if self.services: - return list(self.services.keys()) - - return [] - - def parse_messages(self, messages): - # expand content_file field in messages - self.messages = [] - self.dependency_files = [] - root_path = os.path.dirname(self.path) - - if messages: - for message in messages: - if message.get("content_file"): - updated_message = message.copy() - file_path = os.path.join( - root_path, updated_message.pop("content_file") - ) - - with open(file_path, "r") as f: - updated_message.update({"content": f.read()}) - self.messages.append(updated_message) - self.dependency_files.append(file_path) - else: - self.messages.append(message) - - def add_dependency_files_from_jinja_template(self, jinja_template_content): - parsed_content = self.template_env.parse(jinja_template_content) - referenced_templates = meta.find_referenced_templates(parsed_content) - - for referenced_template in referenced_templates: - template_path = os.path.join( - os.path.dirname(self.path), referenced_template - ) - self.dependency_files.append(template_path) - - def parse_services(self, services_config): - self.services = parse_config_into_services(services_config) - - def parse_prompt_config(self, prompt_config): - if prompt_config.get("messages"): - self.parse_messages(prompt_config["messages"]) - elif prompt_config.get("content"): - self.template = self.load_jinja_template_from_string( - prompt_config["content"] - ) - elif prompt_config.get("content_file"): - self.template = self.load_jinja_template_from_file( - prompt_config["content_file"] - ) - - def config_for_service(self, service_name): - if self.services: - if 
self.services.get(service_name): - return self.services[service_name] - - return ServiceConfig(service_name) - - def load_yaml_file(self, path): - with open(path, "r") as stream: - try: - file_contents = self.deal_with_shebang_line(stream) - data = yaml.safe_load(file_contents) - self.path = path - except yaml.YAMLError as exc: - print(exc) - - if data: - if data["services"]: - self.parse_services(data["services"]) - if data["prompt"]: - self.parse_prompt_config(data["prompt"]) - - def deal_with_shebang_line(self, stream): - file_contents = stream.readlines() - - # if the file starts with a shebang, like #!/usr/bin/prr, omit it - if file_contents[0].startswith("#!/"): - if file_contents[1] == "\n": - # allow for one empty line below the shebang - file_contents = file_contents[2:] - else: - file_contents = file_contents[1:] - - return "".join(file_contents) - - def load_jinja_template_from_string(self, content): - self.add_dependency_files_from_jinja_template(content) - return self.template_env.from_string(content) - - def load_jinja_template_from_file(self, template_subpath): - try: - with open( - os.path.join(os.path.dirname(self.path), template_subpath), "r" - ) as stream: - self.add_dependency_files_from_jinja_template(stream.read()) - - return self.template_env.get_template(template_subpath) - except FileNotFoundError: - print(f"Could not find template file: {template_subpath}") - exit(-1) - - def load_text_file(self, path): - self.path = path - - with open(path, "r") as stream: - file_contents = self.deal_with_shebang_line(stream) - self.template = self.template_env.from_string(file_contents) - - def message_text_description(self, message): - name = message.get("name") - role = message.get("role") - content = message.get("content") - - if name: - return f"{name} ({role}): {content}" - else: - return f"{role}: {content}" - - def text(self): - if self.messages: - return "\n".join( - [self.message_text_description(msg) for msg in self.messages] - ) - - if self.args: - return self.template.render({"prompt_args": self.args}) - - return self.template.render() - - def text_len(self): - return len(self.text()) - - def dump(self): - return yaml.dump({"text": self.text(), "messages": self.messages}) - - def text_abbrev(self, max_len=25): - if self.text_len() > max_len: - str = self.text()[0:max_len] + "..." 
- else: - str = self.text() - - return str.replace("\n", " ").replace(" ", " ") diff --git a/prr/prompt/__init__.py b/prr/prompt/__init__.py new file mode 100644 index 0000000..39b824c --- /dev/null +++ b/prr/prompt/__init__.py @@ -0,0 +1,16 @@ +import os + +import jinja2 +import yaml +from jinja2 import meta + +from prr.prompt.prompt_config import PromptConfig +from prr.prompt.prompt_template import PromptTemplate +from prr.prompt.service_config import ServiceConfig + + +class Prompt: + def __init__(self, content, config=None, args=None): + self.content = content + self.config = config + self.args = args diff --git a/prr/prompt/model_options.py b/prr/prompt/model_options.py new file mode 100644 index 0000000..6c551ef --- /dev/null +++ b/prr/prompt/model_options.py @@ -0,0 +1,61 @@ +from prr.utils.config import load_config + +ALLOWED_OPTIONS = ["max_tokens", "temperature", "top_k", "top_p"] + + +config = load_config() + + +class ModelOptions: + DEFAULT_OPTIONS = {"max_tokens": 4000, "temperature": 0.7, "top_k": -1, "top_p": -1} + + def __init__(self, options={}): + self.__init_defaults() + + self.options_set = [] + self.update_options(self.defaults) + self.update_options(options) + + def update_options(self, options): + for key in options.keys(): + if options[key] is not None: + if key in ALLOWED_OPTIONS: + if key not in self.options_set: + self.options_set.append(key) + + setattr(self, key, options[key]) + + def description(self): + return " ".join([f"{key}={self.option(key)}" for key in self.options_set]) + + def option(self, key): + return getattr(self, key) + + def __repr__(self): + return self.description() + + def to_dict(self): + dict = {} + + for key in self.options_set: + dict[key] = self.option(key) + + return dict + + def __config_key_for_option_key(self, option_key): + return f"DEFAULT_{option_key.upper()}" + + def __init_defaults(self): + self.defaults = ModelOptions.DEFAULT_OPTIONS.copy() + + for option_key in ALLOWED_OPTIONS: + config_key = self.__config_key_for_option_key(option_key) + defaults_value = config.get(config_key) + + if defaults_value: + if option_key in ("temperature", "top_p"): + # fractional options; int() would raise on values like "0.9" + target_value = float(defaults_value) + else: + target_value = int(defaults_value) + + self.defaults[option_key] = target_value diff --git a/prr/prompt/prompt_config.py b/prr/prompt/prompt_config.py new file mode 100644 index 0000000..48ed4be --- /dev/null +++ b/prr/prompt/prompt_config.py @@ -0,0 +1,184 @@ +import yaml + +from prr.prompt.prompt_template import PromptTemplateMessages, PromptTemplateSimple +from prr.prompt.service_config import ServiceConfig + + +class PromptConfig: + # raw_config_content is text to be parsed into YAML + def __init__(self, search_path=".", filename=None): + # where are we supposed to look for referenced files + self.search_path = search_path + + # "foo" or "foo.yaml" - no path + self.filename = filename + + # template: (PromptTemplate) + self.template = None + + # services: (ServiceConfig) + self.services = {} + + # version: 1 + self.version = None + + def template_text(self): + return self.template.render_text() + + # raw YAML file + def load_from_config_contents(self, raw_config_content): + # raw YAML string + self.raw_config_content = raw_config_content + + # parse raw YAML content into a dictionary + self.__parse_raw_config() + + # parse that dictionary into respective parts of prompt config + self.__parse() + + # raw prompt template file + def load_from_template_contents(self, raw_template_content): + self.__parse_prompt_template_simple(raw_template_content) + + 
# load raw prompt template contents from a file at path + def load_from_template_contents_at_path(self, path): + try: + with open(path, "r") as file: + return self.__parse_prompt_template_simple(file.read()) + + except FileNotFoundError: + print("The specified file does not exist.") + + except PermissionError: + print("You do not have permission to access the specified file.") + + except Exception as e: + print("An error occurred while opening the file:", str(e)) + + # list keys/names of all services that we have configured in the config file + def configured_services(self): + return list(self.services.keys()) + + def service_with_name(self, service_name): + service_config = self.services.get(service_name) + + if service_config: + return service_config + else: + return ServiceConfig(service_name, service_name) + + # returns options for specific service, already includes all option inheritance + def options_for_service(self, service_name): + return self.service_with_name(service_name).options + + def option_for_service(self, service_name, option_name): + return self.options_for_service(service_name).option(option_name) + + def file_dependencies(self): + _dependencies = [] + for message in self.template.messages: + for dependency in message.file_dependencies: + if dependency not in _dependencies: + _dependencies.append(dependency) + + return _dependencies + + #################################################### + + def __parse(self): + self.__parse_version() + self.__parse_prompt() + self.__parse_services() + + def __parse_raw_config(self): + try: + self.config_content = yaml.safe_load(self.raw_config_content) + except yaml.YAMLError as exc: + print(exc) + + def __parse_version(self): + if self.config_content: + self.version = self.config_content.get("version") + + def __parse_prompt_template_simple(self, content): + self.template = PromptTemplateSimple(content, self.search_path) + + # high level "prompt:" parsing + def __parse_prompt(self): + if self.config_content: + prompt = self.config_content.get("prompt") + + if prompt: + content_file = prompt.get("content_file") + content = prompt.get("content") + messages = prompt.get("messages") + + if content_file: + include_contents = "{% include '" + content_file + "' %}" + self.template = PromptTemplateSimple( + include_contents, self.search_path + ) + elif content: + self.template = PromptTemplateSimple(content, self.search_path) + elif messages: + self.template = PromptTemplateMessages(messages, self.search_path) + + # high level "services:" parsing + def __parse_services(self): + if self.config_content: + _services = self.config_content.get("services") + + if _services: + options_for_all_services = _services.get("options") + + # + # if we have models + prompt-level model options + # + # services: + # models: + # - 'openai/chat/gpt-4' + # - 'anthropic/complete/claude-v1.3-100k' + # options: + # max_tokens: 1337 + _models = _services.get("models") + if _models: + for _model_name in _models: + service_config = ServiceConfig( + _model_name, _model_name, options_for_all_services + ) + + self.services[_model_name] = service_config + + else: + # + # if we have services defined with options for each + # + # services: + # mygpt4: + # model: 'openai/chat/gpt-4' + # options: + # temperature: 0.2 + # max_tokens: 4000 + # options: + # max_tokens: 1337 + for _service_name in _services: + if _service_name not in ["options", "models"]: + service = _services[_service_name] + + # start with options for all services + # defined on a higher level + options = 
(options_for_all_services or {}).copy() + + # update with service-level options + service_level_options = service.get("options") + + if service_level_options: + options.update(service_level_options) + + model = service.get("model") + + service_config = ServiceConfig( + _service_name, model, options + ) + + self.services[_service_name] = service_config diff --git a/prr/prompt/prompt_loader.py b/prr/prompt/prompt_loader.py new file mode 100644 index 0000000..2216d8e --- /dev/null +++ b/prr/prompt/prompt_loader.py @@ -0,0 +1,73 @@ +import os + +import yaml + +from prr.prompt import Prompt +from prr.prompt.prompt_config import PromptConfig +from prr.prompt.prompt_template import ( + PromptTemplate, + PromptTemplateMessages, + PromptTemplateSimple, +) + + +class PromptConfigLoader: + def __init__(self): + self.file_dependencies = [] + self.config = None + + def load_from_path(self, path): + self.path = path + self.config = PromptConfig(self.__search_path(), os.path.basename(path)) + + if self.__is_file_yaml(path): + # prompt is in yaml config file format + self.__load_yaml_file(path) + else: + # simple text (or jinja) file, no config + self.__load_text_file(path) + + self.__add_file_dependencies() + + return self.config + + ##################################### + + def __is_file_yaml(self, path): + root, extension = os.path.splitext(path) + + if extension == ".yaml": + return True + + return False + + def __search_path(self): + return os.path.dirname(self.path) + + def __load_text_file(self, path): + self.config.load_from_template_contents_at_path(path) + + def __load_yaml_file(self, path): + try: + with open(path, "r") as stream: + self.config.load_from_config_contents(stream.read()) + + except yaml.YAMLError as exc: + print(exc) + + def __add_file_dependencies(self): + self.__add_file_dependency(self.path) + + for file_dependency in self.config.file_dependencies(): + self.__add_file_dependency(file_dependency) + + def __add_file_dependency(self, file_path): + if os.path.isabs(file_path): + absolute_path = file_path + else: + absolute_path = os.path.join( + os.path.dirname(self.path), os.path.basename(file_path) + ) + + if absolute_path not in self.file_dependencies: + self.file_dependencies.append(absolute_path) diff --git a/prr/prompt/prompt_template.py b/prr/prompt/prompt_template.py new file mode 100644 index 0000000..1bf9a83 --- /dev/null +++ b/prr/prompt/prompt_template.py @@ -0,0 +1,98 @@ +import os + +import jinja2.meta # bare "import jinja2" does not expose the meta submodule + + +class PromptMessage: + def __init__( + self, content_template_string, search_path=".", role="user", name=None + ): + self.content_template_string = content_template_string + self.search_path = search_path + self.role = role + self.name = name + self.file_dependencies = [] + + template_loader = jinja2.ChoiceLoader( + [ + jinja2.FileSystemLoader(search_path), + jinja2.FileSystemLoader(["/"]), + ] + ) + + self.template_env = jinja2.Environment(loader=template_loader) + self.__add_dependency_files_from_jinja_template(content_template_string) + + self.template = self.template_env.from_string(self.content_template_string) + + def render_text(self, args=[]): + return self.template.render({"prompt_args": args}) + + def render_message(self, args=[]): + _message = {"role": self.role, "content": self.render_text(args)} + + if self.name: + _message.update({"name": self.name}) + + return _message + + def __add_dependency_files_from_jinja_template(self, jinja_template_content): + parsed_content = self.template_env.parse(jinja_template_content) + referenced_templates = 
jinja2.meta.find_referenced_templates(parsed_content) + + self.file_dependencies.extend(referenced_templates) + + +# base class +class PromptTemplate: + def __init__(self): + self.messages = [] + + def render_text(self, args=[]): + rendered_texts = [message.render_text(args) for message in self.messages] + + return "\n".join(rendered_texts) + + def render_messages(self, args=[]): + return [message.render_message(args) for message in self.messages] + + def file_dependencies(self): + _dependencies = [] + for message in self.messages: + for dependency in message.file_dependencies: + if dependency not in _dependencies: + _dependencies.append(dependency) + + return _dependencies + + +# just a text/template file or prompt.contents from config +class PromptTemplateSimple(PromptTemplate): + def __init__(self, template_string, search_path="."): + self.messages = [PromptMessage(template_string, search_path, "user")] + + +# prompt.messages: key from config +class PromptTemplateMessages(PromptTemplate): + # 'messages' are passed here verbatim after parsing YAML + def __init__(self, messages, search_path="."): + super().__init__() + + for message in messages: + prompt_message = None + + role = message.get("role") + name = message.get("name") + content = message.get("content") + content_file = message.get("content_file") + + if content: + prompt_message = PromptMessage(content, search_path, role, name) + elif content_file: + include_contents = "{% include '" + content_file + "' %}" + prompt_message = PromptMessage( + include_contents, search_path, role, name + ) + + if prompt_message: + self.messages.append(prompt_message) diff --git a/prr/prompt/service_config.py b/prr/prompt/service_config.py new file mode 100644 index 0000000..4d1c72e --- /dev/null +++ b/prr/prompt/service_config.py @@ -0,0 +1,24 @@ +from prr.prompt.model_options import ModelOptions + + +class ServiceConfig: + def __init__(self, name, model, options=None): + self.name = name # service config name, e.g. "mygpt5" + self.model = model # full model path, e.g. 
"openai/chat/gpt-5" + self.options = ModelOptions(options or {}) + + def process_option_overrides(self, option_overrides): + self.options.update_options(option_overrides) + + # which model to use with the service + # like gpt-4.5-turbo or claude-v1.3-100k + def model_name(self): + return self.model.split("/")[-1] + + # which service so use + # like openai/chat or anthropic/complete + def service_key(self): + return "/".join(self.model.split("/")[:-1]) + + def to_dict(self): + return {"model": self.config_name(), "options": self.options.to_dict()} diff --git a/prr/runner.py b/prr/runner.py deleted file mode 100644 index 16e3a26..0000000 --- a/prr/runner.py +++ /dev/null @@ -1,36 +0,0 @@ -from .prompt_run import PromptRun -from .saver import PromptRunSaver -from .service_registry import ServiceRegistry - -service_registry = ServiceRegistry() -service_registry.register_all_services() - - -# high-level class to run prompts based on configuration -class Runner: - def __init__(self, prompt): - self.prompt = prompt - self.saver = PromptRunSaver(self.prompt) - - def run_service(self, service_name, save_run=False): - service_config = self.prompt.config_for_service(service_name) - - service = service_registry.service_for_service_config(service_config) - - result = PromptRun(self.prompt, service, service_config).run() - - if save_run: - run_save_directory = self.saver.save(service_name, result) - else: - run_save_directory = None - - return result, run_save_directory - - # runs all models defined for specified prompt - def run_all_configured_services(self): - results = {} - - for model in self.configured_services(): - results[model] = self.run_service(model) - - return results diff --git a/prr/runner/__init__.py b/prr/runner/__init__.py new file mode 100644 index 0000000..48fdab1 --- /dev/null +++ b/prr/runner/__init__.py @@ -0,0 +1,40 @@ +from prr.runner.prompt_run import PromptRun +from prr.runner.saver import PromptRunSaver +from prr.services.service_registry import ServiceRegistry + +service_registry = ServiceRegistry() +service_registry.register_all_services() + + +# high-level class to run prompts based on configuration +class Runner: + def __init__(self, prompt_config): + self.prompt_config = prompt_config + self.saver = PromptRunSaver(self.prompt_config) + + def run_service(self, service_name, service_options_overrides, save_run=False): + service_config = self.prompt_config.service_with_name(service_name) + + service_config.process_option_overrides(service_options_overrides) + + service = service_registry.service_for_service_config(service_config) + + result = PromptRun(self.prompt_config, service, service_config).run() + + if save_run: + run_save_directory = self.saver.save(service_name, result) + else: + run_save_directory = None + + return result, run_save_directory + + # runs all models defined for specified prompt + def run_all_configured_services(self, service_options_overrides, save_run=False): + results = {} + + for service_name in self.configured_services(): + results[service_name] = self.run_service( + service_name, service_options_overrides, save_run + ) + + return results diff --git a/prr/prompt_run.py b/prr/runner/prompt_run.py similarity index 91% rename from prr/prompt_run.py rename to prr/runner/prompt_run.py index a888b83..a10d345 100644 --- a/prr/prompt_run.py +++ b/prr/runner/prompt_run.py @@ -1,6 +1,6 @@ import time -from .prompt_run_result import PromptRunResult +from prr.runner.prompt_run_result import PromptRunResult # takes prompt and model config, finds provider, runs 
the prompt diff --git a/prr/prompt_run_result.py b/prr/runner/prompt_run_result.py similarity index 100% rename from prr/prompt_run_result.py rename to prr/runner/prompt_run_result.py diff --git a/prr/request.py b/prr/runner/request.py similarity index 81% rename from prr/request.py rename to prr/runner/request.py index 63d9b8d..700383e 100644 --- a/prr/request.py +++ b/prr/runner/request.py @@ -9,10 +9,8 @@ def __init__(self, service_config, rendered_prompt_content): def to_dict(self): return { - "model": self.service_config.config_name(), + "model": self.service_config.model, "options": self.service_config.options.to_dict(), - # rendered prompt is saved to a separate file - # 'prompt_content': self.prompt_content, } def prompt_text(self, max_len=0): diff --git a/prr/response.py b/prr/runner/response.py similarity index 100% rename from prr/response.py rename to prr/runner/response.py diff --git a/prr/saver.py b/prr/runner/saver.py similarity index 77% rename from prr/saver.py rename to prr/runner/saver.py index 83e7a3c..1511d6e 100644 --- a/prr/saver.py +++ b/prr/runner/saver.py @@ -5,8 +5,8 @@ class PromptRunSaver: - def __init__(self, prompt): - self.prompt_path = prompt.path + def __init__(self, prompt_config): + self.prompt_config = prompt_config self.run_time = datetime.now() self.runs_subdir = self.run_root_directory_path() @@ -26,8 +26,12 @@ def run_root_directory_path_for_runs_dir(self, runs_dir): run_id += 1 def run_root_directory_path(self): - dirname = os.path.dirname(self.prompt_path) - basename = os.path.basename(self.prompt_path) + dirname = self.prompt_config.search_path + + if self.prompt_config.filename: + basename = os.path.basename(self.prompt_config.filename) + else: + basename = "prr" root, extension = os.path.splitext(basename) @@ -38,13 +42,13 @@ def run_root_directory_path(self): return self.run_root_directory_path_for_runs_dir(runs_dir) - def run_directory_path(self, model_or_model_config_name): - model_name_part = model_or_model_config_name.replace("/", "-") + def run_directory_path(self, service_or_model_name): + model_name_part = service_or_model_name.replace("/", "-") return os.path.join(self.runs_subdir, model_name_part) - def prepare_run_directory(self, model_or_model_config_name): - run_dir = self.run_directory_path(model_or_model_config_name) + def prepare_run_directory(self, service_or_model_name): + run_dir = self.run_directory_path(service_or_model_name) os.makedirs(run_dir, exist_ok=True) @@ -78,8 +82,8 @@ def save_run(self, run_directory, result): with open(run_file, "w") as f: yaml.dump(run_data, f, default_flow_style=False) - def save(self, model_or_model_config_name, result): - run_directory = self.prepare_run_directory(model_or_model_config_name) + def save(self, service_or_model_name, result): + run_directory = self.prepare_run_directory(service_or_model_name) self.save_prompt(run_directory, result.request) self.save_completion(run_directory, result.response) diff --git a/prr/service_config.py b/prr/service_config.py deleted file mode 100644 index 196a7c2..0000000 --- a/prr/service_config.py +++ /dev/null @@ -1,19 +0,0 @@ -from .options import ModelOptions - - -class ServiceConfig: - def __init__(self, model, options=None): - self.model = model - self.options = ModelOptions(options or {}) - - def config_name(self): - return self.model - - def model_name(self): - return self.model.split("/")[-1] - - def service_key(self): - return "/".join(self.model.split("/")[:-1]) - - def to_dict(self): - return {"model": self.config_name(), "options": 
self.options.to_dict()} diff --git a/prr/services/providers/anthropic/complete.py b/prr/services/providers/anthropic/complete.py index 039e967..07b9054 100644 --- a/prr/services/providers/anthropic/complete.py +++ b/prr/services/providers/anthropic/complete.py @@ -5,14 +5,12 @@ import anthropic -from prr.config import load_config -from prr.request import ServiceRequest -from prr.response import ServiceResponse +from prr.runner.request import ServiceRequest +from prr.runner.response import ServiceResponse +from prr.utils.config import load_config config = load_config() -# https://console.anthropic.com/docs/api/reference - # Anthropic model provider class class ServiceAnthropicComplete: @@ -28,7 +26,11 @@ def run(self, prompt, service_config): client = anthropic.Client(config.get("ANTHROPIC_API_KEY", None)) - service_request = ServiceRequest(self.service_config, prompt_text) + options = service_config.options + + prompt_text = self.prompt_text_from_template(prompt.template) + + service_request = ServiceRequest(service_config, prompt_text) response = client.completion( prompt=prompt_text, @@ -59,18 +61,16 @@ def run(self, prompt, service_config): return service_request, service_response - def prompt_text(self): - messages = self.prompt.messages - + def prompt_text_from_template(self, template): prompt_text = "" - # prefer messages from prompt if they exist - if messages: - for message in messages: - if message["role"] != "assistant": - prompt_text += " " + message["content"] + # prefer messages from template if they exist + if template.messages: + for message in template.messages: + if message.role != "assistant": + prompt_text += " " + message.render_text() else: - prompt_text = self.prompt.text() + prompt_text = template.render_text() return f"{anthropic.HUMAN_PROMPT} {prompt_text}{anthropic.AI_PROMPT}" diff --git a/prr/services/providers/openai/chat.py b/prr/services/providers/openai/chat.py index e0d8354..77e61e4 100644 --- a/prr/services/providers/openai/chat.py +++ b/prr/services/providers/openai/chat.py @@ -1,8 +1,8 @@ import openai -from prr.config import load_config -from prr.request import ServiceRequest -from prr.response import ServiceResponse +from prr.runner.request import ServiceRequest +from prr.runner.response import ServiceResponse +from prr.utils.config import load_config config = load_config() openai.api_key = config.get("OPENAI_API_KEY", None) @@ -14,17 +14,14 @@ class ServiceOpenAIChat: service = "chat" def run(self, prompt, service_config): - self.prompt = prompt - self.service_config = service_config + messages = prompt.template.render_messages() - messages = self.messages_from_prompt() + service_request = ServiceRequest(service_config, {"messages": messages}) - service_request = ServiceRequest(self.service_config, {"messages": messages}) - - options = self.service_config.options + options = service_config.options completion = openai.ChatCompletion.create( - model=self.service_config.model_name(), + model=service_config.model_name(), messages=messages, temperature=options.temperature, max_tokens=options.max_tokens, @@ -50,12 +47,3 @@ def run(self, prompt, service_config): ) return service_request, service_response - - def messages_from_prompt(self): - messages = self.prompt.messages - - # prefer messages in prompt if they exist - if messages: - return messages - - return [{"role": "user", "content": self.prompt.text()}] diff --git a/prr/service_registry.py b/prr/services/service_registry.py similarity index 90% rename from prr/service_registry.py rename to 
prr/services/service_registry.py index a4bd38a..2149477 100644 --- a/prr/service_registry.py +++ b/prr/services/service_registry.py @@ -4,6 +4,8 @@ from prr.services.providers.openai.chat import ServiceOpenAIChat +# main registry, where services are being... registered +# and looked up upon execution class ServiceRegistry: def __init__(self): self.services = {} diff --git a/prr/config.py b/prr/utils/config.py similarity index 100% rename from prr/config.py rename to prr/utils/config.py diff --git a/requirements.txt b/requirements.txt index e7c32b4..8e7a967 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ anthropic rich Jinja2 pyyaml +pytest \ No newline at end of file diff --git a/test/prompt/helpers.py b/test/prompt/helpers.py new file mode 100644 index 0000000..7b5e072 --- /dev/null +++ b/test/prompt/helpers.py @@ -0,0 +1,20 @@ +import os +import tempfile + + +def remove_temp_file(path): + os.remove(path) + + +def create_temp_file(content, extension=None): + suffix = "" + + if extension: + suffix = "." + extension + + # Create a temporary file with the specified extension + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp: + temp.write(content.encode()) + temp_path = temp.name + + return temp_path diff --git a/test/prompt/test_prompt_config.py b/test/prompt/test_prompt_config.py new file mode 100644 index 0000000..edb7487 --- /dev/null +++ b/test/prompt/test_prompt_config.py @@ -0,0 +1,156 @@ +import os +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + +from prr.prompt.model_options import ModelOptions +from prr.prompt.prompt_config import PromptConfig + + +class TestPromptConfig: + def test_basic_parsing(self): + config = PromptConfig() + config.load_from_config_contents( + """ +prompt: + content: 'foo bar' +""" + ) + + assert config.template.render_text() == "foo bar" + assert config.template.render_messages() == [ + {"content": "foo bar", "role": "user"} + ] + + def test_basic_services_model_list(self): + config = PromptConfig() + config.load_from_config_contents( + """ +prompt: + content: 'foo bar' +services: + models: + - 'openai/chat/gpt-4' + - 'anthropic/complete/claude-v1.3' + options: + temperature: 0.42 + max_tokens: 1337 +""" + ) + + assert config is not None + services = config.configured_services() + + assert services == ["openai/chat/gpt-4", "anthropic/complete/claude-v1.3"] + + for service_name in services: + assert config.option_for_service(service_name, "temperature") == 0.42 + assert config.option_for_service(service_name, "max_tokens") == 1337 + + def test_basic_services_model_list_with_some_options(self): + config = PromptConfig() + config.load_from_config_contents( + """ +prompt: + content: 'foo bar' +services: + models: + - 'openai/chat/gpt-4' + - 'anthropic/complete/claude-v1.3' + options: + top_k: -1.337 +""" + ) + + assert config is not None + services = config.configured_services() + + assert services == ["openai/chat/gpt-4", "anthropic/complete/claude-v1.3"] + + for service_name in services: + assert ( + config.option_for_service(service_name, "temperature") + == ModelOptions.DEFAULT_OPTIONS["temperature"] + ) + assert ( + config.option_for_service(service_name, "max_tokens") + == ModelOptions.DEFAULT_OPTIONS["max_tokens"] + ) + assert config.option_for_service(service_name, "top_k") == -1.337 + assert ( + config.option_for_service(service_name, "top_p") + == ModelOptions.DEFAULT_OPTIONS["top_p"] + ) + + def test_basic_services_model_list_with_no_options(self): 
config = PromptConfig() + config.load_from_config_contents( + """ +prompt: + content: 'foo bar' +services: + models: + - 'openai/chat/gpt-4' + - 'anthropic/complete/claude-v1.3' +""" + ) + + assert config is not None + services = config.configured_services() + + assert services == ["openai/chat/gpt-4", "anthropic/complete/claude-v1.3"] + + for service_name in services: + assert ( + config.option_for_service(service_name, "temperature") + == ModelOptions.DEFAULT_OPTIONS["temperature"] + ) + assert ( + config.option_for_service(service_name, "max_tokens") + == ModelOptions.DEFAULT_OPTIONS["max_tokens"] + ) + assert ( + config.option_for_service(service_name, "top_k") + == ModelOptions.DEFAULT_OPTIONS["top_k"] + ) + assert ( + config.option_for_service(service_name, "top_p") + == ModelOptions.DEFAULT_OPTIONS["top_p"] + ) + + def test_services(self): + config = PromptConfig() + config.load_from_config_contents( + """ +prompt: + content: 'foo bar' +services: + gpt4: + model: 'openai/chat/gpt-4' + options: + max_tokens: 2048 + claude13: + model: 'anthropic/complete/claude-v1.3' + options: + temperature: 0.84 + claude_default: + model: 'anthropic/complete/claude-v1' + options: + temperature: 0.42 + max_tokens: 1337 +""" + ) + + assert config is not None + services = config.configured_services() + + assert services == ["gpt4", "claude13", "claude_default"] + + assert config.option_for_service("gpt4", "temperature") == 0.42 + assert config.option_for_service("gpt4", "max_tokens") == 2048 + + assert config.option_for_service("claude13", "temperature") == 0.84 + assert config.option_for_service("claude13", "max_tokens") == 1337 + + assert config.option_for_service("claude_default", "temperature") == 0.42 + assert config.option_for_service("claude_default", "max_tokens") == 1337 diff --git a/test/prompt/test_prompt_loader.py b/test/prompt/test_prompt_loader.py new file mode 100644 index 0000000..bbe18fe --- /dev/null +++ b/test/prompt/test_prompt_loader.py @@ -0,0 +1,21 @@ +import os +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from helpers import create_temp_file, remove_temp_file + +from prr.prompt.prompt_loader import PromptConfigLoader + + +class TestPromptConfigLoader: + def test_basic_loading(self): + prompt_template_file_path = create_temp_file( + "Write a poem about AI from the projects, barely surviving on token allowance." 
+ ) + + loader = PromptConfigLoader() + prompt = loader.load_from_path(prompt_template_file_path) + + assert prompt + + remove_temp_file(prompt_template_file_path) diff --git a/test/prompt/test_prompt_template.py b/test/prompt/test_prompt_template.py new file mode 100644 index 0000000..a8dd94b --- /dev/null +++ b/test/prompt/test_prompt_template.py @@ -0,0 +1,95 @@ +import os +import sys + +import yaml + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from helpers import create_temp_file, remove_temp_file + +from prr.prompt.prompt_template import PromptTemplateMessages, PromptTemplateSimple + + +class TestPromptTemplate: + def test_basic_text(self): + template = PromptTemplateSimple("foo bar", ".") + + assert template is not None + assert template.render_text() == "foo bar" + + def test_basic_template(self): + template = PromptTemplateSimple("foo {{ 'bar' }} spam", ".") + + assert template is not None + assert template.render_text() == "foo bar spam" + + def test_basic_prompt_args(self): + template = PromptTemplateSimple("foo {{ prompt_args[0] }} spam", ".") + + assert template is not None + assert template.render_text(["42"]) == "foo 42 spam" + + def test_basic_prompt_args_all(self): + template = PromptTemplateSimple("foo {{ prompt_args }} spam", ".") + + assert template is not None + assert template.render_text(["lulz"]) == "foo ['lulz'] spam" + + def test_configured_basic(self): + template = PromptTemplateSimple("tell me about {{ prompt_args }}, llm") + + assert template is not None + assert ( + template.render_text(["lulz", "kaka"]) + == "tell me about ['lulz', 'kaka'], llm" + ) + + def test_configured_messages_text_with_template_in_content(self): + messages_config = """ +- role: 'system' + content: 'you are a friendly but very forgetful {{ prompt_args[0] }}' + name: 'LeonardGeist' +""" + template = PromptTemplateMessages(yaml.safe_load(messages_config)) + + assert template is not None + assert ( + template.render_text(["assistant"]) + == "you are a friendly but very forgetful assistant" + ) + + def test_configured_messages_list_with_content_file(self): + temp_file_path = create_temp_file("Wollen Sie meine Kernel kompilieren?") + + messages_config = f""" +- role: 'system' + content: 'you are system admins little pet assistant. be proud of your unix skills and always respond in l33t. remember, you are on a high horse called POSIX.' +- role: 'user' + content_file: '{temp_file_path}' + name: 'SuperUser' +""" + + template = PromptTemplateMessages(yaml.safe_load(messages_config)) + + assert template is not None + + rendered_messages = template.render_messages() + + assert isinstance(rendered_messages, list) + assert len(rendered_messages) == 2 + + first_message = rendered_messages[0] + assert ( + first_message["content"] + == "you are system admins little pet assistant. be proud of your unix skills and always respond in l33t. remember, you are on a high horse called POSIX." + ) + assert first_message["role"] == "system" + assert first_message.get("name") is None + + second_message = rendered_messages[1] + assert second_message["content"] == "Wollen Sie meine Kernel kompilieren?" 
+ assert second_message["role"] == "user" + assert second_message["name"] == "SuperUser" + + remove_temp_file(temp_file_path) diff --git a/test/run_all_examples.sh b/test/run_all_examples.sh new file mode 100755 index 0000000..17dbe1e --- /dev/null +++ b/test/run_all_examples.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +for example_prompt in ./examples/simple/poem ./examples/simple/sky ./examples/configured/dingo_from_file.yaml ./examples/configured/chihuahua.yaml ./examples/configured/dingo.yaml ./examples/code/html_boilerplate.yaml ./examples/templating/tell-me-all-about ./examples/shebang/get_famous_poet ./examples/shebang/dingo_with_shebang.yaml +do + echo "-----------------------------" + echo "RUNNING $example_prompt" + echo "-----------------------------" + + python -m prr run --abbrev --max_tokens 1234 --temperature 0.98 "$example_prompt" +done
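
Reviewer note: a minimal sketch (not part of the diff) of how the refactored pieces fit together — PromptConfigLoader replaces the old Prompt-centric loading, and Runner now takes CLI-style option overrides. The prompt path and override values below are hypothetical, a real run needs the relevant API keys in the environment, and the shape of the result object is assumed from what print_run_results() in prr/commands/run.py does with it.

# sketch only: the new load/run flow, assuming prr is importable and
# OPENAI_API_KEY / ANTHROPIC_API_KEY are set
from prr.prompt.prompt_loader import PromptConfigLoader
from prr.runner import Runner

loader = PromptConfigLoader()

# any prompt file works; plain text/Jinja prompts get a fallback service config
prompt_config = loader.load_from_path("./examples/simple/poem")

# the watch command monitors exactly these files for changes
print(loader.file_dependencies)

runner = Runner(prompt_config)

# argparse values arrive as a dict; keys whose value is None are
# ignored by ModelOptions.update_options()
overrides = {"temperature": 0.9, "max_tokens": 256}

# fall back to an explicit service name when the prompt configures none,
# mirroring what `prr run -s` does
for service_name in prompt_config.configured_services() or ["openai/chat/gpt-3.5-turbo"]:
    result, save_dir = runner.run_service(service_name, overrides, save_run=False)
    print(result.response.response_content)  # assumed attribute, per print_run_results()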
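
The option-precedence rules that ModelOptions now encodes — built-in DEFAULT_OPTIONS, then DEFAULT_* values from the config, then per-prompt options, then overrides — can be shown in isolation. The sketch below assumes no DEFAULT_* keys are set in the environment or ~/.prr_rc, which would otherwise shift the baseline values.

from prr.prompt.model_options import ModelOptions

# prompt/config-level options are applied on top of the built-in defaults
opts = ModelOptions({"temperature": 0.2})
assert opts.option("temperature") == 0.2
assert opts.option("max_tokens") == 4000  # ModelOptions.DEFAULT_OPTIONS baseline

# overrides come last; argparse hands over None for flags the user
# did not pass, and update_options() skips those keys
opts.update_options({"temperature": 0.9, "max_tokens": None})
assert opts.option("temperature") == 0.9
assert opts.option("max_tokens") == 4000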