diff --git a/requirements/base.txt b/requirements/base.txt index 9bb7677..dba861e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,6 +2,7 @@ backports.cached-property==1.0.0.post2 certifi>=2021.10.8 click>=8.1.3,<9 docker==7.1.0 +inflection==0.5.1 kubernetes>=28.1.0 opentelemetry-distro>=0.44b0 opentelemetry-exporter-prometheus>=0.44b0 @@ -10,6 +11,7 @@ pydantic>=2 python_dateutil>=2.8.0 pyyaml==6.0.1 requests>=2.25,<3 +rich==13.7.1 smart_open>=2.1.0,<6.0 tabulate==0.8.9 tenacity==8.2.2 diff --git a/src/gretel_client/inference_api/base.py b/src/gretel_client/inference_api/base.py index 365da5c..be70f21 100644 --- a/src/gretel_client/inference_api/base.py +++ b/src/gretel_client/inference_api/base.py @@ -8,7 +8,6 @@ from gretel_client.config import ClientConfig, configure_session, get_session_config from gretel_client.rest.api_client import ApiClient -from gretel_client.rest.configuration import Configuration MODELS_API_PATH = "/v1/inference/models" @@ -161,7 +160,7 @@ def __init__( elif len(session_kwargs) > 0: raise ValueError("cannot specify session arguments when passing a session") - if session.default_runner != "cloud" and not ".serverless." in session.endpoint: + if session.default_runner != "cloud" and ".serverless." not in session.endpoint: raise GretelInferenceAPIError( "Gretel's Inference API is currently only " "available within Gretel Cloud. Your current runner " diff --git a/src/gretel_client/navigator/__init__.py b/src/gretel_client/navigator/__init__.py new file mode 100644 index 0000000..14698ae --- /dev/null +++ b/src/gretel_client/navigator/__init__.py @@ -0,0 +1,2 @@ +from gretel_client.navigator.data_designer.interface import DataDesigner +from gretel_client.navigator.workflow import NavigatorWorkflow diff --git a/src/gretel_client/navigator/client/__init__.py b/src/gretel_client/navigator/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/client/interface.py b/src/gretel_client/navigator/client/interface.py new file mode 100644 index 0000000..588a6c1 --- /dev/null +++ b/src/gretel_client/navigator/client/interface.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Generic, Iterator, Optional, Type, TypeVar, Union + +import pandas as pd + +from gretel_client.projects import Project + + +def get_client(adapter: Union[Type[ClientAdapter], ClientAdapter]) -> Client: + if not isinstance(adapter, ClientAdapter): + adapter = adapter() + return Client(adapter) + + +@dataclass +class SubmitBatchWorkflowResponse: + project: Project + workflow_id: str + workflow_run_id: str + + +class Client: + + _adapter: ClientAdapter + + def __init__(self, adapter: ClientAdapter): + self._adapter = adapter + + def run_task( + self, + name: str, + config: dict, + inputs: Optional[list[TaskInput]] = None, + globals: Optional[dict] = None, + verbose: bool = False, + ) -> TaskOutput: + if inputs is None: + inputs = [] + if globals is None: + globals = {} + return self._adapter.run_task(name, config, inputs, globals, verbose) + + def get_workflow_preview(self, workflow_config: dict) -> Iterator: + return self._adapter.stream_workflow_outputs(workflow_config) + + def submit_batch_workflow( + self, + workflow_config: dict, + num_records: int, + project_name: Optional[str] = None, + ) -> SubmitBatchWorkflowResponse: + return self._adapter.submit_batch_workflow( + workflow_config, num_records, project_name 
+ ) + + def get_step_output( + self, + workflow_run_id: str, + step_name: str, + format: Optional[str] = None, + ) -> TaskOutput: + return self._adapter.get_step_output(workflow_run_id, step_name, format) + + def download_step_output( + self, + workflow_run_id: str, + step_name: str, + output_dir: Path, + format: Optional[str] = None, + ) -> Path: + return self._adapter.download_step_output( + workflow_run_id, step_name, output_dir, format + ) + + def registry(self) -> list[dict]: + return self._adapter.registry() + + +TaskInput = TypeVar("TaskInput") +TaskOutput = Union[pd.DataFrame, dict] + + +class ClientAdapter(ABC, Generic[TaskInput]): + + @abstractmethod + def run_task( + self, + name: str, + config: dict, + inputs: list[TaskInput], + globals: dict, + verbose: bool = False, + ) -> TaskOutput: ... + + @abstractmethod + def stream_workflow_outputs( + self, workflow: dict, verbose: bool = False + ) -> Iterator[dict]: ... + + @abstractmethod + def registry(self) -> list[dict]: ... + + def submit_batch_workflow( + self, + workflow_config: dict, + num_records: int, + project_name: Optional[str] = None, + ) -> SubmitBatchWorkflowResponse: + raise NotImplementedError("Cannot submit batch Workflows") + + def get_step_output( + self, + workflow_run_id: str, + step_name: str, + format: Optional[str] = None, + ) -> TaskOutput: + raise NotImplementedError("Cannot get batch step outputs") + + def download_step_output( + self, + workflow_run_id: str, + step_name: str, + output_dir: Path, + format: Optional[str] = None, + ) -> Path: + raise NotImplementedError("Cannot download batch artifacts") diff --git a/src/gretel_client/navigator/client/remote.py b/src/gretel_client/navigator/client/remote.py new file mode 100644 index 0000000..6bed7ad --- /dev/null +++ b/src/gretel_client/navigator/client/remote.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import json +import logging + +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime +from io import BytesIO +from pathlib import Path +from typing import Iterator, Optional, Union + +import pandas as pd +import pydantic +import requests +import smart_open + +from inflection import underscore +from requests import HTTPError +from rich import print as rich_print + +from gretel_client import Gretel +from gretel_client.config import get_session_config +from gretel_client.navigator.client.interface import ( + ClientAdapter, + SubmitBatchWorkflowResponse, + TaskInput, + TaskOutput, +) +from gretel_client.navigator.log import get_logger + +gretel_interface_logger = logging.getLogger("gretel_client.gretel.interface") +gretel_interface_logger.setLevel(logging.WARNING) + +logger = get_logger(__name__, level=logging.INFO) + + +Serializable = Union[pydantic.BaseModel, pd.DataFrame, dict] + + +@dataclass +class Message: + + step: str + """The name of the step""" + + stream: str + """ + The stream the message should be associated with. + + We use multiple streams so that we can differentiate between different types of outputs. 
+ """ + + payload: dict + """The actual value of the output""" + + type: str + """The type of message""" + + ts: datetime + """The date and time the message was created""" + + @classmethod + def from_dict(cls, message: dict) -> Message: + message["ts"] = datetime.fromisoformat(message["ts"]) + return cls(**message) + + +def workflow_preview(workflow_outputs: Iterator, verbose: bool = False) -> TaskOutput: + terminal_output = None + for message_dict in workflow_outputs: + message = Message.from_dict(message_dict) + if message.stream == "step_outputs": + if message.type == "dataset": + terminal_output = pd.DataFrame.from_records( + message.payload.get("dataset") + ) + else: + terminal_output = message.payload + return terminal_output + + +class RemoteClient(ClientAdapter[Serializable]): + + def __init__( + self, + jarvis_endpoint: str = "https://jarvis.dev.gretel.cloud", + ): + self._session = get_session_config() + self._req_headers = {"Authorization": self._session.api_key} + self._jarvis_endpoint = jarvis_endpoint + + logger.debug(f"šŸŒŽ Connecting to {self._jarvis_endpoint}") + + # todo: pass an event handler and log non task outputs + def run_task( + self, + name: str, + config: dict, + inputs: list[TaskInput], + globals: dict, + verbose: bool = False, + ) -> TaskOutput: + if config is None: + config = {} + if inputs is None: + inputs = [] + if globals is None: + globals = {} + + inputs = serialize_inputs(inputs) + + response = requests.post( + f"{self._jarvis_endpoint}/tasks/exec", + json={"name": name, "config": config, "inputs": inputs, "globals": globals}, + headers=self._req_headers, + stream=True, + ) + + try: + response.raise_for_status() + except HTTPError as e: + rich_print(f"Got error: {str(e)}") + rich_print(response.json()) + raise e + + with response as messages: + try: + for o in messages.iter_lines(): + message = json.loads(o) + if message["stream"] == "step_outputs": + if message["type"] == "dataset": + return pd.DataFrame.from_records( + message["payload"]["dataset"] + ) + return message["payload"] + except Exception as e: + rich_print(e) + + raise Exception("Did not receive output for task") + + def stream_workflow_outputs( + self, workflow: dict, verbose: bool = False + ) -> Iterator[Message]: + with requests.post( + f"{self._jarvis_endpoint}/workflows/exec_streaming", + json=workflow, + headers=self._req_headers, + stream=True, + ) as outputs: + outputs.raise_for_status() + + for output in outputs.iter_lines(): + yield Message.from_dict(json.loads(output)) + + def submit_batch_workflow( + self, + workflow_config: dict, + num_records: int, + project_name: Optional[str] = None, + ) -> SubmitBatchWorkflowResponse: + + for step in workflow_config["steps"]: + if "num_records" in step["config"]: + step["config"]["num_records"] = num_records + + gretel = Gretel(session=self._session) + gretel.set_project(name=project_name) + project = gretel.get_project() + + logger.info("šŸ›œ Connecting to your Gretel Project:") + logger.info(f"šŸ”— -> {project.get_console_url()}") + + response = requests.post( + f"{self._jarvis_endpoint}/workflows/exec_batch", + json={ + "workflow_config": workflow_config, + "project_id": project.project_guid, + }, + headers=self._req_headers, + ) + response.raise_for_status() + response_body = response.json() + batch_response = SubmitBatchWorkflowResponse( + project=project, + workflow_id=response_body["workflow_id"], + workflow_run_id=response_body["workflow_run_id"], + ) + workflow_run_url = ( + 
f"{project.get_console_url().replace(project.project_guid, '')}workflows/" + f"{batch_response.workflow_id}/runs/{batch_response.workflow_run_id}" + ) + + logger.info(f"ā–¶ļø Starting your workflow run to generate {num_records} records:") + logger.info(f"šŸ”— -> {workflow_run_url}") + + return batch_response + + def get_step_output( + self, + workflow_run_id: str, + step_name: str, + format: Optional[str] = None, + ) -> TaskOutput: + with self._request_artifact(workflow_run_id, step_name, format) as response: + content_type = response.headers.get("content-type") + if content_type == "application/json": + return json.load(BytesIO(response.content)) + elif content_type == "application/vnd.apache.parquet": + return pd.read_parquet(BytesIO(response.content)) + else: + raise Exception( + f"Cannot get output format {format!r} as TaskOutput. Try downloading instead." + ) + + def download_step_output( + self, + workflow_run_id: str, + step_name: str, + output_dir: Path, + format: Optional[str] = None, + ) -> Path: + with self._request_artifact(workflow_run_id, step_name, format) as response: + filename = response.headers.get("content-disposition").split("filename=")[1] + out_file = output_dir / filename + with smart_open.open(out_file, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + return out_file + + @contextmanager + def _request_artifact( + self, + workflow_run_id: str, + step_name: str, + format: Optional[str] = None, + ) -> Iterator[requests.models.Response]: + endpoint = f"{self._jarvis_endpoint}/workflows/{workflow_run_id}/{step_name}" + params = {"format": format} + with requests.get( + endpoint, + headers=self._req_headers, + params=params, + stream=True, + ) as response: + response.raise_for_status() + yield response + + def registry(self) -> list[dict]: + response = requests.get( + f"{self._jarvis_endpoint}/registry", headers=self._req_headers + ) + response.raise_for_status() + + return response.json()["tasks"] + + +def serialize_inputs(inputs: list[TaskInput]) -> list[dict]: + inputs_as_json = [] + for _input in inputs: + if isinstance(_input, dict): + inputs_as_json.append(_input) + if isinstance(_input, pydantic.BaseModel): + inputs_as_json.append(_serialize_pydantic(_input)) + if isinstance(_input, pd.DataFrame): + inputs_as_json.append(_serialize_df(_input)) + return inputs_as_json + + +def _serialize_df(df: pd.DataFrame) -> dict: + return {"type": "dataset", "obj": {"dataset": df.to_dict(orient="records")}} + + +def _serialize_pydantic(pydantic_model: pydantic.BaseModel) -> dict: + return { + "type": underscore(pydantic_model.__name__), + "obj": pydantic_model.model_dump(), + } diff --git a/src/gretel_client/navigator/client/utils.py b/src/gretel_client/navigator/client/utils.py new file mode 100644 index 0000000..77481f7 --- /dev/null +++ b/src/gretel_client/navigator/client/utils.py @@ -0,0 +1,17 @@ +from typing import Optional, Type, Union + +from gretel_client.config import configure_session +from gretel_client.navigator.client.interface import Client, ClientAdapter +from gretel_client.navigator.client.remote import RemoteClient + + +def get_navigator_client( + client_adapter: Optional[Union[Type[ClientAdapter], ClientAdapter]] = None, + **session_kwargs, +) -> Client: + configure_session(**session_kwargs) + if client_adapter is None: + client_adapter = RemoteClient() + if not isinstance(client_adapter, ClientAdapter): + client_adapter = client_adapter() + return Client(client_adapter) diff --git 
a/src/gretel_client/navigator/data_designer/__init__.py b/src/gretel_client/navigator/data_designer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/data_designer/data_column.py b/src/gretel_client/navigator/data_designer/data_column.py new file mode 100644 index 0000000..cade4cf --- /dev/null +++ b/src/gretel_client/navigator/data_designer/data_column.py @@ -0,0 +1,112 @@ +from typing import Optional + +from pydantic import BaseModel, Field + +from gretel_client.navigator.client.interface import Client +from gretel_client.navigator.data_designer.prompt_templates import ( + COLUMN_GENERATION_PROMPT, + get_prompt_template_keywords, + system_prompt_dict, +) +from gretel_client.navigator.tasks.generate.generate_column_from_template import ( + GenerateColumnFromTemplate, + TextParserType, +) +from gretel_client.navigator.tasks.types import LLMType, OutputColumnType + +parser_instructions_map = { + TextParserType.PASS_THROUGH: ( + "Respond only with the requested text, " + "without any additional comments or instructions." + ), + TextParserType.JSON_ARRAY: "Respond only with a list as a valid JSON array.", + TextParserType.EXTRACT_CODE: ( + "Respond only with the requested code, " + "without any preamble or additional text." + ), +} + +output_parser_type_map = { + "str": TextParserType.PASS_THROUGH, + "string": TextParserType.PASS_THROUGH, + "text": TextParserType.PASS_THROUGH, + "nl": TextParserType.PASS_THROUGH, + "json": TextParserType.JSON, + "dict": TextParserType.JSON, + "list": TextParserType.JSON_ARRAY, + "json_array": TextParserType.JSON_ARRAY, + "code": TextParserType.EXTRACT_CODE, +} + + +class DataColumn(BaseModel): + name: str + description: str + output_type: OutputColumnType = OutputColumnType.TEXT + relevant_columns: list[str] = Field(default_factory=list) + specific_instructions: str = "" + llm_type: LLMType = LLMType.NL + + def get_context_list_string(self, exclude: Optional[set[str]] = None) -> str: + exclude = exclude or set() + + if len(set(self.relevant_columns) - exclude) == 0: + return "" + + section_title = "\n### Other Relevant Data ###\n" + return ( + section_title + + "\n".join( + [ + f" * {c.replace('_', ' ').capitalize()}: {{{c}}}" + for c in self.relevant_columns + if c not in exclude + ] + ) + + "\n" + ) + + def to_generation_task( + self, + special_system_instructions: Optional[str] = None, + client: Optional[Client] = None, + ) -> GenerateColumnFromTemplate: + + extra = "" + specific = "" + if len(self.specific_instructions) > 0: + specific = ( + f"\n### Specific Instructions ###\n{self.specific_instructions}\n" + ) + extra = "\n * Pay particularly close attention to the above Specific Instructions." + + output_parser = output_parser_type_map[self.output_type] + + # Exclude relevant_columns that are present in the specific instructions. 
+ exclude = set() + for key in get_prompt_template_keywords(specific): + if {key}.issubset(self.relevant_columns): + exclude |= {key} + + return GenerateColumnFromTemplate( + prompt_template=COLUMN_GENERATION_PROMPT.format( + name=self.name, + description=self.description, + specific_instructions=specific, + context=self.get_context_list_string(exclude), + parser_instructions=parser_instructions_map[output_parser], + extra_instructions=extra, + ), + output_parser=output_parser, + response_column_name=self.name, + workflow_label=f"generating {self.name}", + llm_type=self.llm_type, + system_prompt=system_prompt_dict[self.llm_type].format( + special_instructions=( + "" + if special_system_instructions is None + else f"\n{special_system_instructions}\n" + ) + ), + client=client, + ) diff --git a/src/gretel_client/navigator/data_designer/interface.py b/src/gretel_client/navigator/data_designer/interface.py new file mode 100644 index 0000000..9b96192 --- /dev/null +++ b/src/gretel_client/navigator/data_designer/interface.py @@ -0,0 +1,547 @@ +import logging + +from collections import defaultdict +from pathlib import Path +from typing import Optional, Union + +import pandas as pd + +from rich import print as rich_print +from rich.pretty import pprint as rich_pprint +from typing_extensions import Self + +from gretel_client.gretel.config_setup import smart_load_yaml +from gretel_client.navigator.data_designer.data_column import DataColumn +from gretel_client.navigator.data_designer.prompt_templates import ( + get_prompt_template_keywords, +) +from gretel_client.navigator.data_designer.viz_tools import display_sample_record +from gretel_client.navigator.log import get_logger +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.generate.generate_seed_category_values import ( + GenerateSeedCategoryValues, +) +from gretel_client.navigator.tasks.load_data_seeds import LoadDataSeeds +from gretel_client.navigator.tasks.seed.sample_data_seeds import SampleDataSeeds +from gretel_client.navigator.tasks.types import ( + CodeLang, + DataSeedStatus, + DEFAULT_MODEL_SUITE, + LLMType, + ModelSuite, + OutputColumnType, + SeedCategory, + SeedSubcategory, + SQL_DIALECTS, + ValidatorType, +) +from gretel_client.navigator.tasks.validate.validate_code import ( + VALIDATE_PYTHON_ADDED_COLUMN_PREFIXES, + VALIDATE_SQL_ADDED_COLUMN_PREFIXES, + ValidateCode, +) +from gretel_client.navigator.workflow import ( + get_output_seeds_with_generation, + NavigatorWorkflow, + NavigatorWorkflowJobResults, + PreviewResults, +) + +logger = get_logger(__name__, level=logging.INFO) + + +class DataDesigner: + """High-level interface for designing synthetic data generation workflows with Gretel Navigator. + + The DataDesigner class streamlines the process of building synthetic datasets using Gretel + Navigator's Task Workflow execution framework. It provides a declarative config framework for + defining categorical data seeds, data columns, and data validators, which are used to assemble + a scalable synthetic data generation workflow. + + Args: + dataset_description: Optional description of the dataset to be generated. This description will + be used in prompts to provide high-level context about the dataset. + special_system_instructions: Optional instructions for the system to follow when generating + the dataset. These instructions will be added to the system prompts. + model_suite: The model suite to use for generating synthetic data. Defaults to the + Apache-2.0 licensed model suite. 
+ **session_kwargs: kwargs for your Gretel session. See options below. + + Keyword Args: + api_key (str): Your Gretel API key. If set to "prompt" and no API key + is found on the system, you will be prompted for the key. + endpoint (str): Specifies the Gretel API endpoint. This must be a fully + qualified URL. The default is "https://api.gretel.cloud". + default_runner (str): Specifies the runner mode. Must be one of "cloud", + "local", "manual", or "hybrid". The default is "cloud". + artifact_endpoint (str): Specifies the endpoint for project and model + artifacts. Defaults to "cloud" for running in Gretel Cloud. If + working in hybrid mode, set to the URL of your artifact storage bucket. + cache (str): Valid options are "yes" or "no". If set to "no", the session + configuration will not be written to disk. If set to "yes", the + session configuration will be written to disk only if one doesn't + already exist. The default is "no". + validate (bool): If `True`, will validate the login credentials at + instantiation. The default is `False`. + clear (bool): If `True`, existing Gretel credentials will be removed. + The default is `False.` + """ + + def __init__( + self, + *, + dataset_description: Optional[str] = None, + special_system_instructions: Optional[str] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + **session_kwargs, + ): + self.workflow = NavigatorWorkflow(model_suite=model_suite, **session_kwargs) + self.dataset_description = dataset_description + self.special_system_instructions = special_system_instructions + + self._seed_categories: dict[str, list[SeedCategory]] = {} + self._data_columns: dict[str, list[DataColumn]] = {} + self._data_validators: dict[str, list[Task]] = {} + self._seed_subcategory_names = defaultdict(list) + self._seed_categories_with_generation: Optional[ + dict[str, list[SeedCategory]] + ] = None + self._seed_status: DataSeedStatus = DataSeedStatus.NO_GENERATION + self._reset_workflow() + + @property + def _data_column_names(self) -> list[str]: + """Return a list of the names of the data columns (note order matters).""" + return list(self._data_columns.keys()) + + @property + def _seed_category_names(self) -> list[str]: + """Return a list of the names of the seed categories.""" + return list(self._seed_categories.keys()) + + @property + def column_names(self) -> list[str]: + """Return a list of all seed (including seed subcategories) and data column names.""" + return ( + self._seed_category_names + + [s for ss in self._seed_subcategory_names.values() for s in ss] + + self._data_column_names + ) + + @classmethod + def from_config(cls, config: Union[dict, str, Path], **session_kwargs) -> Self: + """Instantiate a DataDesigner instance from a YAML configuration str, dict, or file. + + Args: + config: A YAML configuration file, dict, or string. + **session_kwargs: kwargs for your Gretel session. + + Returns: + An instance of DataDesigner configured with the settings from the provided YAML config. + """ + config = smart_load_yaml(config) + + designer = cls( + dataset_description=config.get("dataset_description"), + special_system_instructions=config.get("special_system_instructions"), + model_suite=config.get("model_suite", DEFAULT_MODEL_SUITE), + **session_kwargs, + ) + + if "seed_categories" not in config: + raise ValueError( + "No seed categories were defined in the config. At least one seed " + "category must be defined in the seed_categories field." 
+ ) + + for seed_category in config.get("seed_categories", []): + designer.add_seed_category(**seed_category) + logger.debug(f"šŸŒ± Adding seed category: {seed_category['name']}") + + for data_column in config.get("data_columns", []): + designer.add_data_column(**data_column) + logger.debug(f"šŸ’½ Adding data column: {data_column['name']}") + + if len(designer.column_names) == 0: + raise ValueError("No seed or data columns were defined in the config.") + + for validator_settings in config.get("data_validators", []): + validator_type = ValidatorType(validator_settings.pop("validator")) + designer.add_data_validator(validator_type, **validator_settings) + + designer._config = config + + return designer + + def _get_final_seed_categories(self) -> Optional[dict]: + """Return the final seed categories to be used in the workflow.""" + if self._seed_status == DataSeedStatus.NO_GENERATION: + return {"seed_categories": list(self._seed_categories.values())} + elif self._seed_status == DataSeedStatus.GENERATED: + return self._seed_categories_with_generation + elif self._seed_status == DataSeedStatus.NEEDS_GENERATION: + return None + else: + raise ValueError(f"Unknown seed status: {self._seed_status}") + + def _create_sequential_task_list(self) -> list[Task]: + if len(self._seed_categories) == 0: + raise ValueError("No seed columns have been defined.") + + task_list = [] + final_seed_categories = self._get_final_seed_categories() + + if final_seed_categories is None: + task_list.append( + GenerateSeedCategoryValues( + seed_categories=list(self._seed_categories.values()), + dataset_context=self.dataset_description, + client=self.workflow._client, + ) + ) + else: + task_list.append( + LoadDataSeeds( + categorical_data_seeds=final_seed_categories, + client=self.workflow._client, + ) + ) + + task_list.append(SampleDataSeeds(client=self.workflow._client)) + + for column in self._data_columns.values(): + task = column.to_generation_task( + self.special_system_instructions, client=self.workflow._client + ) + task_list.append(task) + + for validator in self._data_validators.values(): + task_list.append(validator) + + return task_list + + def _reset_workflow(self) -> None: + self.workflow.reset_steps() + self.workflow_step_names = [] + self.task_to_step_map = {} + + def _prepare_workflow(self) -> str: + self._reset_workflow() + task_list = self._create_sequential_task_list() + steps = self.workflow.create_steps_from_sequential_tasks(task_list) + self.workflow_step_names = [s.name for s in steps] + self.task_to_step_map = {t.name: s.name for t, s in zip(task_list, steps)} + self.workflow.add_steps(steps) + # We need context from within DataDesigner to identify the last dataset step. + # We assume that the user will generally want the final dataset step + # from the workflow, since this will contain the final synthesized dataset. 
dataset_steps = { + self.task_to_step_map[k] + for k, v in self.workflow._task_io.items() + if k in self.task_to_step_map + if v["output"] == "dataset" + } + dataset_step_list = [s.name for s in steps if s.name in dataset_steps] + final_dataset_step_name = ( + dataset_step_list[-1] if len(dataset_step_list) > 0 else None + ) + return final_dataset_step_name + + def _validate_data_column_inputs( + self, + name: str, + description: str, + relevant_columns: Optional[list[str]] = None, + specific_instructions: Optional[str] = None, + ) -> tuple[str, str, list[str], str]: + if name in self._data_columns: + raise ValueError(f"Column name `{name}` already exists.") + if name in self._seed_categories: + raise ValueError(f"Column name `{name}` already exists as a seed category.") + + specific_instructions = specific_instructions or "" + for n, v in zip( + ["description", "specific_instructions"], + [description, specific_instructions], + ): + template_kwargs = get_prompt_template_keywords(v) + if not template_kwargs.issubset(self.column_names): + raise ValueError( + f"The `{n}` field of `{name}` contains template keywords that " + "are not available as columns.\n" + f"* Template keywords found in `{n}`: {template_kwargs}\n" + f"* Available seed columns: {self._seed_category_names}\n" + f"* Available data columns: {self._data_column_names}" + ) + + relevant_columns = relevant_columns or [] + if any(col not in self.column_names for col in relevant_columns): + raise ValueError( + f"The `relevant_columns` field of `{name}` is not configured correctly. " + "Relevant columns must be added *before* the column that references them.\n" + f"* Available seed columns: {self._seed_category_names}\n" + f"* Available data columns: {self._data_column_names}\n" + ) + + return name, description, relevant_columns, specific_instructions + + def generate_seed_category_values( + self, force_generation: bool = False, verbose_logging: bool = False + ) -> None: + if len(self._seed_categories) == 0: + raise ValueError("No seed categories have been defined.") + + if self._seed_status == DataSeedStatus.NO_GENERATION: + logger.warning("Your seed categories do not require any generation.") + return + elif self._seed_status == DataSeedStatus.GENERATED and not force_generation: + logger.warning( + "Your seed category values have already been generated. To force " + "generation, set `force_generation=True`." + ) + return + + task = GenerateSeedCategoryValues( + seed_categories=list(self._seed_categories.values()), + dataset_context=self.dataset_description, + client=self.workflow._client, + ) + + self._reset_workflow() + steps = self.workflow.create_steps_from_sequential_tasks([task]) + self.workflow.add_steps(steps) + preview = self.workflow._execute_preview(verbose_logging) + self._reset_workflow() + self._seed_status = DataSeedStatus.GENERATED + self._seed_categories_with_generation = preview.output + + def inspect_data_seeds(self) -> None: + if self._seed_status == DataSeedStatus.NEEDS_GENERATION: + logger.warning( + "You have seed categories with values that require generation. " + "Run `generate_seed_category_values()` to generate seed values." 
+ ) + else: + columns_to_print = [ + "name", + "description", + "values", + "generated_values", + "num_values_to_generate", + "subcategories", + ] + rich_print("-" * 80 + "\nšŸŒ± Categorical Data Seeds \n" + "-" * 80) + for seed in self._get_final_seed_categories()["seed_categories"][::-1]: + if isinstance(seed, SeedCategory): + seed = seed.model_dump() + rich_pprint( + {k: v for k, v in seed.items() if k in columns_to_print}, + indent_guides=False, + ) + print("-" * 80) + + def reset_generated_seed_values(self) -> None: + if self._seed_status == DataSeedStatus.GENERATED: + logger.info("šŸ”„ Resetting generated seed values") + self._seed_categories_with_generation = None + self._seed_status = DataSeedStatus.NEEDS_GENERATION + + def add_seed_category( + self, + name: str, + *, + description: Optional[str] = None, + values: Optional[list[Union[str, int, float]]] = None, + weights: Optional[list[float]] = None, + num_values_to_generate: Optional[int] = None, + subcategories: Optional[Union[list[SeedSubcategory], list[dict]]] = None, + ) -> None: + if len(self._data_columns) > 0: + raise ValueError( + "Seed categories must be added *before* data columns.\n" + f"-> Current data columns: {self._data_column_names}" + ) + if num_values_to_generate is None and values is None: + raise ValueError( + "You must provide *at least* one of `values` or `num_values_to_generate`." + ) + if name in self._seed_category_names: + raise ValueError(f"Seed category `{name}` already exists.") + + if len(subcategories or []) > 0: + for seed in subcategories: + if isinstance(seed, dict): + seed = SeedSubcategory(**seed) + if seed.name in self._seed_category_names: + raise ValueError(f"Seed category `{seed.name}` already exists.") + self._seed_subcategory_names[name].append(seed.name) + self._seed_status = DataSeedStatus.NEEDS_GENERATION + + if num_values_to_generate is not None and num_values_to_generate > 0: + self._seed_status = DataSeedStatus.NEEDS_GENERATION + + self._seed_categories[name] = SeedCategory( + name=name, + description=description, + values=values or [], + weights=weights or [], + num_values_to_generate=num_values_to_generate, + subcategories=subcategories or [], + ) + + def add_data_column( + self, + name: str, + *, + description: str, + output_type: OutputColumnType = OutputColumnType.TEXT, + relevant_columns: Optional[list[str]] = None, + specific_instructions: Optional[str] = None, + llm_type: LLMType = LLMType.NL, + ) -> None: + name, description, relevant_columns, specific_instructions = ( + self._validate_data_column_inputs( + name, description, relevant_columns, specific_instructions + ) + ) + relevant_columns = relevant_columns or self._seed_category_names + self._data_columns[name] = DataColumn( + name=name, + description=description, + output_type=output_type, + relevant_columns=relevant_columns or [], + specific_instructions=specific_instructions or "", + llm_type=llm_type, + ) + + def add_data_validator(self, validator: ValidatorType, **settings) -> None: + if validator == ValidatorType.CODE: + if "code_lang" not in settings: + raise ValueError("You must provide `code_lang` for code validation.") + CodeLang.validate(settings["code_lang"]) + if "code_columns" not in settings: + raise ValueError("You must provide `code_columns` for code validation.") + if not isinstance(settings["code_columns"], list): + raise ValueError( + "`code_columns` must be a list of column names. 
" + f"You provided: {settings['code_columns']}" + ) + if not set(settings["code_columns"]).issubset(self.column_names): + raise ValueError( + "`code_columns` contains columns that have not been defined." + f"\n* Available columns: {self.column_names}" + ) + self._data_validators[ValidatorType(validator).value] = ValidateCode( + client=self.workflow._client, **settings + ) + else: + raise ValueError(f"Unknown validator type: {validator}") + + def get_seed_category(self, name: str) -> SeedCategory: + return self._seed_categories[name] + + def get_data_column(self, name: str) -> DataColumn: + return self._data_columns[name] + + def generate_dataset_preview( + self, *, verbose_logging: bool = False, use_last_generated_seeds: bool = True + ) -> PreviewResults: + if ( + not use_last_generated_seeds + and self._seed_status == DataSeedStatus.GENERATED + ): + self._seed_status = DataSeedStatus.NEEDS_GENERATION + + self._prepare_workflow() + preview = self.workflow.generate_dataset_preview( + verbose_logging=verbose_logging + ) + + if ( + seeds := get_output_seeds_with_generation(preview.outputs_by_step) + ) and self._seed_status in [ + DataSeedStatus.GENERATED, + DataSeedStatus.NEEDS_GENERATION, + ]: + logger.debug("Saving generated seed values") + self._seed_status = DataSeedStatus.GENERATED + self._seed_categories_with_generation = seeds + return preview + + def submit_batch_workflow( + self, num_records: int, project_name: Optional[str] = None + ) -> NavigatorWorkflowJobResults: + final_dataset_step_name = self._prepare_workflow() + batch_workflow_results = self.workflow.submit_batch_job( + num_records=num_records, project_name=project_name + ) + return NavigatorWorkflowJobResults( + batch_workflow_results=batch_workflow_results, + workflow_step_names=self.workflow_step_names, + final_dataset_step_name=final_dataset_step_name, + ) + + def display_sample_record( + self, + record: Union[dict, pd.Series], + background_color: Optional[str] = None, + theme: str = "dracula", + ) -> None: + code_lang = None + code_columns = [] + validation_columns = [] + for validator in self._data_validators.values(): + if validator.name == "validate_code": + column_prefix = ( + VALIDATE_SQL_ADDED_COLUMN_PREFIXES + if validator.config.code_lang in SQL_DIALECTS + else VALIDATE_PYTHON_ADDED_COLUMN_PREFIXES + ) + code_lang = validator.config.code_lang + code_columns.extend(validator.config.code_columns) + for col in code_columns: + for prefix in column_prefix: + validation_columns.append(f"{col}{prefix}") + break + + display_sample_record( + record=record, + seed_categories=self._seed_category_names, + data_columns=self._data_column_names, + seed_subcategories=self._seed_subcategory_names, + background_color=background_color, + code_columns=code_columns, + validation_columns=validation_columns, + code_lang=code_lang, + theme=theme, + ) + + def __repr__(self): + seed_categories = [ + ( + name + if len(s.subcategories) == 0 + else f"{name}:{','.join([n.name for n in s.subcategories])}" + ) + for name, s in self._seed_categories.items() + ] + + validators = ( + f" validators: {list(self._data_validators)}\n" + if len(self._data_validators) > 0 + else "" + ) + + seed_generated_label = "" + if self._seed_status == DataSeedStatus.NEEDS_GENERATION: + seed_generated_label = " (needs generation)" + elif self._seed_status == DataSeedStatus.GENERATED: + seed_generated_label = " (has generated values)" + + return ( + f"{self.__class__.__name__}(\n" + f" seed_categories{seed_generated_label}: {seed_categories}\n" + f" data_columns: 
{self._data_column_names}\n" + f"{validators}" + ")" + ) diff --git a/src/gretel_client/navigator/data_designer/prompt_templates.py b/src/gretel_client/navigator/data_designer/prompt_templates.py new file mode 100644 index 0000000..680de4b --- /dev/null +++ b/src/gretel_client/navigator/data_designer/prompt_templates.py @@ -0,0 +1,60 @@ +from string import Formatter + + +def get_prompt_template_keywords(template: str) -> set[str]: + return { + k[1] for k in Formatter().parse(template) if len(k) > 1 and k[1] is not None + } + + +DATA_DESIGNER_BASE_SYSTEM_PROMPT = """\ +You are an expert data practitioner, skilled in the creation of diverse, high-quality datasets, \ +leveraging deep expertise across domains in both the academic and industry sectors. \ +You always carefully consider all information provided to you, and you always follow all \ +instructions. \ +{llm_type_specific_instructions} \ +{{special_instructions}} +You always provide your response after the '### Response ###' section header. + +YOU MUST GENERATE ALL OUTPUT IN ENGLISH ONLY. +""" + +DATA_DESIGNER_NL_SYSTEM_PROMPT = DATA_DESIGNER_BASE_SYSTEM_PROMPT.format( + llm_type_specific_instructions="""\ +You are particularly adept at writing natural language, strictly adhering to \ +all formatting constraints and instructions provided to you. +""" +) + + +DATA_DESIGNER_CODE_SYSTEM_PROMPT = DATA_DESIGNER_BASE_SYSTEM_PROMPT.format( + llm_type_specific_instructions="""\ +You are obsessed with writing excellent software that strictly adheres to all formatting \ +constraints and instructions provided to you. You always use markdown code blocks to format your code, \ +and you always respond with only the requested code, without any preamble or additional text. \ +Importantly, you ALWAYS write self-contained code. +""" +) + +COLUMN_GENERATION_PROMPT = """\ +Your task is to generate the `{name}` column in a dataset based on the information given below. \ +It is VERY IMPORTANT that you follow ALL instructions carefully. + +### Column Description ### +{description} +{context}\ +{specific_instructions} +### General Instructions ### + * Generate the `{name}` column as described above. + * Remember to generate all output in English. + * Remember to base your response on the above information. 
\ + {extra_instructions} + * {parser_instructions} + +### Response ### +""" + +system_prompt_dict = { + "nl": DATA_DESIGNER_NL_SYSTEM_PROMPT, + "code": DATA_DESIGNER_CODE_SYSTEM_PROMPT, +} diff --git a/src/gretel_client/navigator/data_designer/viz_tools.py b/src/gretel_client/navigator/data_designer/viz_tools.py new file mode 100644 index 0000000..c389263 --- /dev/null +++ b/src/gretel_client/navigator/data_designer/viz_tools.py @@ -0,0 +1,103 @@ +import numbers + +from typing import Optional, Union + +import pandas as pd + +from rich.console import Console +from rich.panel import Panel +from rich.syntax import Syntax +from rich.table import Table + +from gretel_client.navigator.tasks.types import CodeLang + +console = Console() + + +def display_sample_record( + record: Union[dict, pd.Series, pd.DataFrame], + seed_categories: list[str], + data_columns: list[str], + seed_subcategories: Optional[dict[str, list[str]]] = None, + code_lang: Optional[CodeLang] = None, + code_columns: Optional[list[str]] = None, + validation_columns: Optional[list[str]] = None, + background_color: Optional[str] = None, + theme: str = "dracula", +): + if isinstance(record, (dict, pd.Series)): + record = pd.DataFrame([record]).iloc[0] + elif isinstance(record, pd.DataFrame): + if record.shape[0] > 1: + raise ValueError( + "The record must be a single record. You provided a " + f"DataFrame with {record.shape[0]} records." + ) + record = record.iloc[0] + else: + raise ValueError( + "The record must be a single record in a dictionary, pandas Series, " + f"or pandas DataFrame. You provided: {type(record)}." + ) + + code_columns = code_columns or [] + seed_subcategories = seed_subcategories or {} + validation_columns = validation_columns or [] + code_lang = None if code_lang is None else CodeLang.validate(code_lang) + + table_kws = dict(show_lines=True, expand=True) + + if len(seed_categories) > 0: + table = Table(title="Seed Columns", **table_kws) + table.add_column("Column Name") + for col in [c for c in seed_categories if c not in code_columns]: + table.add_row(col, str(record[col])) + if col in seed_subcategories: + for nested_col in seed_subcategories[col]: + table.add_row(f" |- {nested_col}", str(record[nested_col])) + console.print(table, end="\n\n") + + if len(data_columns) > 0: + table = Table(title="Data Columns", **table_kws) + table.add_column("Column Name") + table.add_column("Value") + for col in [c for c in data_columns if c not in code_columns]: + table.add_row(col, str(record[col])) + console.print(table, end="\n\n") + + if len(validation_columns) > 0: + table = Table(title="Validation", **table_kws) + table.add_column("Column Name") + table.add_column("Value") + for col in validation_columns: + value = record[col] + if isinstance(value, numbers.Number): + name, value = col, f"{value:.2f}" + elif isinstance(value, list) and len(value) > 0: + length = len(value) + label = "" if length == 1 else f" (first of {length} messages)" + name = f"{col}{label}" + value = str(value[0]) + else: + name, value = col, str(value) + table.add_row(name, value) + console.print(table, end="\n\n") + + for col in code_columns: + if code_lang is None: + raise ValueError( + "`code_lang` must be provided when code_columns are specified." 
+ f"Valid options are: {', '.join([c.value for c in CodeLang])}" + ) + panel = Panel( + Syntax( + record[col], + lexer=code_lang.to_syntax_lexer(), + theme=theme, + word_wrap=True, + background_color=background_color, + ), + title=col, + expand=True, + ) + console.print(panel) diff --git a/src/gretel_client/navigator/log.py b/src/gretel_client/navigator/log.py new file mode 100644 index 0000000..aed585a --- /dev/null +++ b/src/gretel_client/navigator/log.py @@ -0,0 +1,14 @@ +import logging +import sys + + +def get_logger(name: str, *, level: int = logging.INFO) -> logging.Logger: + logger = logging.getLogger(name) + logger.propagate = False + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s", "%H:%M:%S") + ) + logger.addHandler(handler) + logger.setLevel(level) + return logger diff --git a/src/gretel_client/navigator/tasks/__init__.py b/src/gretel_client/navigator/tasks/__init__.py new file mode 100644 index 0000000..75ce3a5 --- /dev/null +++ b/src/gretel_client/navigator/tasks/__init__.py @@ -0,0 +1,10 @@ +from gretel_client.navigator.tasks.generate.generate_column_from_template import ( + GenerateColumnFromTemplate, +) +from gretel_client.navigator.tasks.generate.generate_seed_category_values import ( + GenerateSeedCategoryValues, +) +from gretel_client.navigator.tasks.load_data_seeds import LoadDataSeeds +from gretel_client.navigator.tasks.seed.sample_data_seeds import SampleDataSeeds +from gretel_client.navigator.tasks.seed.seed_from_records import SeedFromRecords +from gretel_client.navigator.tasks.validate.validate_code import ValidateCode diff --git a/src/gretel_client/navigator/tasks/base.py b/src/gretel_client/navigator/tasks/base.py new file mode 100644 index 0000000..e11cff4 --- /dev/null +++ b/src/gretel_client/navigator/tasks/base.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + +from pydantic import BaseModel + +from gretel_client.navigator.client.interface import Client, ClientAdapter, TaskOutput +from gretel_client.navigator.client.utils import get_navigator_client +from gretel_client.navigator.tasks.io import Dataset +from gretel_client.navigator.tasks.types import ( + check_model_suite, + DEFAULT_MODEL_SUITE, + ModelSuite, + RecordsT, +) + + +class Task(ABC): + + def __init__( + self, + config: BaseModel, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + ): + self.config = config + self.workflow_label = workflow_label + self._client = client or get_navigator_client() + self._globals = {"model_suite": check_model_suite(model_suite)} + + def _records_to_dataset_if_needed( + self, dataset: Union[Dataset, RecordsT] + ) -> Dataset: + if isinstance(dataset, Dataset): + return dataset + return Dataset.from_records(dataset) + + def _set_client(self, adapter: ClientAdapter): + """Set client adapter for task execution. + + This is an internal method that is not useable by end users. + """ + self._client = get_navigator_client(adapter) + + def _run(self, *inputs) -> TaskOutput: + try: + return self._client.run_task( + name=self.name, + config=self.config.model_dump(), + inputs=list(inputs), + globals=self._globals, + ) + except Exception as e: + print(e) + + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def run(self, *args, **kwargs) -> TaskOutput: ... 
diff --git a/src/gretel_client/navigator/tasks/generate/__init__.py b/src/gretel_client/navigator/tasks/generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/tasks/generate/generate_column_from_template.py b/src/gretel_client/navigator/tasks/generate/generate_column_from_template.py new file mode 100644 index 0000000..b39f051 --- /dev/null +++ b/src/gretel_client/navigator/tasks/generate/generate_column_from_template.py @@ -0,0 +1,57 @@ +from typing import Optional, Union + +from pydantic import BaseModel + +from gretel_client.navigator.client.interface import Client, TaskOutput +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.io import Dataset +from gretel_client.navigator.tasks.types import ( + DEFAULT_MODEL_SUITE, + LLMType, + ModelSuite, + TextParserType, +) + +DEFAULT_RESPONSE_COLUMN_NAME = "response" + + +class GenerateColumnFromTemplateConfig(BaseModel): + prompt_template: str + response_column_name: str = DEFAULT_RESPONSE_COLUMN_NAME + output_parser: TextParserType = TextParserType.PASS_THROUGH + llm_type: LLMType = LLMType.NL + system_prompt: Optional[str] = None + + +class GenerateColumnFromTemplate(Task): + + def __init__( + self, + prompt_template: str, + response_column_name: str = DEFAULT_RESPONSE_COLUMN_NAME, + output_parser: TextParserType = TextParserType.PASS_THROUGH, + llm_type: LLMType = LLMType.NL, + system_prompt: Optional[str] = None, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + ): + super().__init__( + config=GenerateColumnFromTemplateConfig( + prompt_template=prompt_template, + response_column_name=response_column_name, + output_parser=output_parser, + llm_type=llm_type, + system_prompt=system_prompt, + ), + workflow_label=workflow_label, + client=client, + model_suite=model_suite, + ) + + @property + def name(self) -> str: + return "generate_column_from_template" + + def run(self, dataset: Union[Dataset, list[dict]]) -> TaskOutput: + return self._run(self._records_to_dataset_if_needed(dataset)) diff --git a/src/gretel_client/navigator/tasks/generate/generate_seed_category_values.py b/src/gretel_client/navigator/tasks/generate/generate_seed_category_values.py new file mode 100644 index 0000000..059b30a --- /dev/null +++ b/src/gretel_client/navigator/tasks/generate/generate_seed_category_values.py @@ -0,0 +1,69 @@ +from pathlib import Path +from typing import Optional, Union + +from pydantic import BaseModel + +from gretel_client.gretel.config_setup import smart_load_yaml +from gretel_client.navigator.client.interface import Client, TaskOutput +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.types import ( + DEFAULT_MODEL_SUITE, + ModelSuite, + SeedCategory, +) + + +class GenerateSeedCategoryValuesConfig(BaseModel): + seed_categories: list[SeedCategory] + dataset_context: str = "" + + +class GenerateSeedCategoryValues(Task): + + def __init__( + self, + seed_categories: Union[str, Path, list[dict], list[SeedCategory]], + dataset_context: Optional[str] = None, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + ): + super().__init__( + config=GenerateSeedCategoryValuesConfig( + seed_categories=self._check_and_get_seed_categories(seed_categories), + dataset_context=dataset_context or "", + ), + workflow_label=workflow_label, + client=client, + model_suite=model_suite, + ) + + 
@staticmethod + def _check_and_get_seed_categories( + categories: Union[str, Path, list[dict], list[SeedCategory]] + ) -> list[SeedCategory]: + if isinstance(categories, (str, Path)): + categories = smart_load_yaml(categories).get("seed_categories") + + if not isinstance(categories, list): + raise ValueError( + "`seed_categories` must be a list of dicts or SeedCategory objects" + ) + + # Convert dicts to DataSeedColumn objects to ensure they are valid. + if all(isinstance(seed, dict) for seed in categories): + categories = SeedCategory.from_dicts(categories) + + if not all(isinstance(seed, SeedCategory) for seed in categories): + raise ValueError( + "`seed_categories` must be a list of dicts or SeedCategory objects" + ) + + return categories + + @property + def name(self) -> str: + return "generate_seed_category_values" + + def run(self) -> TaskOutput: + return self._run() diff --git a/src/gretel_client/navigator/tasks/io.py b/src/gretel_client/navigator/tasks/io.py new file mode 100644 index 0000000..3662f1b --- /dev/null +++ b/src/gretel_client/navigator/tasks/io.py @@ -0,0 +1,3 @@ +import pandas as pd + +Dataset = pd.DataFrame diff --git a/src/gretel_client/navigator/tasks/judge_with_llm.py b/src/gretel_client/navigator/tasks/judge_with_llm.py new file mode 100644 index 0000000..04a3bf5 --- /dev/null +++ b/src/gretel_client/navigator/tasks/judge_with_llm.py @@ -0,0 +1,45 @@ +from typing import Optional, Union + +from pydantic import BaseModel + +from gretel_client.navigator.client.interface import Client, TaskOutput +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.io import Dataset +from gretel_client.navigator.tasks.types import LLMJudgePromptTemplateType, RecordsT + + +class JudgeWithLLMConfig(BaseModel): + judge_template_type: LLMJudgePromptTemplateType + instruction_column_name: str + response_column_name: str + context_column_name: Optional[str] = None + + +class JudgeWithLLM(Task): + + def __init__( + self, + judge_template_type: LLMJudgePromptTemplateType, + instruction_column_name: str, + response_column_name: str, + context_column_name: Optional[str] = None, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + ): + super().__init__( + config=JudgeWithLLMConfig( + judge_template_type=judge_template_type, + instruction_column_name=instruction_column_name, + response_column_name=response_column_name, + context_column_name=context_column_name, + ), + workflow_label=workflow_label, + client=client, + ) + + @property + def name(self) -> str: + return "judge_with_llm" + + def run(self, dataset: Union[Dataset, RecordsT]) -> TaskOutput: + return self._run(self._records_to_dataset_if_needed(dataset)) diff --git a/src/gretel_client/navigator/tasks/load_data_seeds.py b/src/gretel_client/navigator/tasks/load_data_seeds.py new file mode 100644 index 0000000..e14cb85 --- /dev/null +++ b/src/gretel_client/navigator/tasks/load_data_seeds.py @@ -0,0 +1,27 @@ +from typing import Optional, Union + +from gretel_client.navigator.client.interface import Client, TaskOutput +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.types import CategoricalDataSeeds + + +class LoadDataSeeds(Task): + + def __init__( + self, + categorical_data_seeds: Union[dict, CategoricalDataSeeds], + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + ): + if categorical_data_seeds and isinstance(categorical_data_seeds, dict): + categorical_data_seeds = CategoricalDataSeeds(**categorical_data_seeds) + 
super().__init__( + config=categorical_data_seeds, workflow_label=workflow_label, client=client + ) + + @property + def name(self) -> str: + return "load_data_seeds" + + def run(self) -> TaskOutput: + return self._run() diff --git a/src/gretel_client/navigator/tasks/seed/__init__.py b/src/gretel_client/navigator/tasks/seed/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/tasks/seed/sample_data_seeds.py b/src/gretel_client/navigator/tasks/seed/sample_data_seeds.py new file mode 100644 index 0000000..3be91a0 --- /dev/null +++ b/src/gretel_client/navigator/tasks/seed/sample_data_seeds.py @@ -0,0 +1,36 @@ +from typing import Optional + +from pydantic import BaseModel + +from gretel_client.navigator.client.interface import Client, TaskOutput +from gretel_client.navigator.tasks.base import Task + + +class SampleDataSeedsConfig(BaseModel): + num_records: int = 10 + + +class SampleDataSeeds(Task): + + def __init__( + self, + num_records: int = 10, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + ): + super().__init__( + config=SampleDataSeedsConfig(num_records=num_records), + workflow_label=workflow_label, + client=client, + ) + + @property + def name(self): + return "sample_data_seeds" + + def run(self, categorical_data_seeds: dict) -> TaskOutput: + if self.config.num_records > 10: + raise ValueError("You can only preview up to 10 records at a time.") + return self._run( + {"type": "categorical_data_seeds", "obj": categorical_data_seeds} + ) diff --git a/src/gretel_client/navigator/tasks/seed/seed_from_records.py b/src/gretel_client/navigator/tasks/seed/seed_from_records.py new file mode 100644 index 0000000..1b0a9ab --- /dev/null +++ b/src/gretel_client/navigator/tasks/seed/seed_from_records.py @@ -0,0 +1,32 @@ +from typing import Optional + +from pydantic import BaseModel + +from gretel_client.navigator.client.interface import Client, TaskOutput +from gretel_client.navigator.tasks.base import Task + + +class SeedFromRecordsConfig(BaseModel): + records: list[dict] + + +class SeedFromRecords(Task): + + def __init__( + self, + records: list[dict], + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + ): + super().__init__( + config=SeedFromRecordsConfig(records=records), + workflow_label=workflow_label, + client=client, + ) + + @property + def name(self) -> str: + return "seed_from_records" + + def run(self) -> TaskOutput: + return self._run(self.config.records) diff --git a/src/gretel_client/navigator/tasks/types.py b/src/gretel_client/navigator/tasks/types.py new file mode 100644 index 0000000..054f4d0 --- /dev/null +++ b/src/gretel_client/navigator/tasks/types.py @@ -0,0 +1,151 @@ +from enum import Enum +from typing import Any, Optional, Union + +from annotated_types import Len +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Self + +from gretel_client.config import get_session_config + +MAX_NUM_DATA_SEED_VALUES = 25 +MAX_NUM_NESTED_DATA_SEEDS = 5 + +RecordsT = list[dict[str, Any]] +SeedValueT = Union[str, int, bool] + + +class ModelSuite(str, Enum): + APACHE_2_0 = "Apache-2.0" + LLAMA_3_x = "Llama 3.x" + + +DEFAULT_MODEL_SUITE = ModelSuite.APACHE_2_0 + + +def check_model_suite(model_suite: Union[ModelSuite, str]) -> str: + is_gretel_dev = get_session_config().stage == "dev" + + if not is_gretel_dev: + # Make sure that the model_suite is a valid ModelSuite enum. + # Why? Faster feedback for users who are using the wrong model suite.
+        return ModelSuite(model_suite).value
+
+    # Allow for more flexibility in dev mode.
+    if isinstance(model_suite, ModelSuite):
+        return model_suite.value
+    return model_suite
+
+
+class OutputColumnType(str, Enum):
+    TEXT = "text"
+    DICT = "dict"
+    LIST = "list"
+    CODE = "code"
+
+
+class LLMType(str, Enum):
+    NL = "nl"
+    CODE = "code"
+    JUDGE = "judge"
+
+
+class TextParserType(str, Enum):
+    EXTRACT_CODE = "extract_code"
+    JSON = "json"
+    JSON_ARRAY = "json_array"
+    PASS_THROUGH = "pass_through"
+
+
+class LLMJudgePromptTemplateType(str, Enum):
+    NL2PYTHON = "nl2python"
+    NL2SQL = "nl2sql"
+
+
+class CodeLang(str, Enum):
+    PYTHON = "python"
+
+    # SQL dialects match the SQLFluff naming conventions.
+    ANSI = "ansi"
+    TSQL = "tsql"
+    BIGQUERY = "bigquery"
+    MYSQL = "mysql"
+    POSTGRES = "postgres"
+
+    @classmethod
+    def validate(cls, value: Union[str, Self]) -> Self:
+        try:
+            return cls(value)
+        except ValueError:
+            raise ValueError(
+                f"Unsupported code language: {value}\n"
+                f"Supported code languages: {', '.join([x.value for x in cls])}"
+            )
+
+    def is_sql_dialect(self) -> bool:
+        return self in SQL_DIALECTS
+
+    def to_syntax_lexer(self) -> str:
+        if self == CodeLang.PYTHON:
+            return "python"
+        elif self == CodeLang.ANSI:
+            return "sql"
+        elif self == CodeLang.TSQL:
+            return "tsql"
+        elif self == CodeLang.BIGQUERY:
+            return "sql"
+        elif self == CodeLang.MYSQL:
+            return "mysql"
+        elif self == CodeLang.POSTGRES:
+            return "postgres"
+        else:
+            raise ValueError(f"Unsupported code language: {self}")
+
+
+SQL_DIALECTS = {
+    CodeLang.ANSI,
+    CodeLang.TSQL,
+    CodeLang.BIGQUERY,
+    CodeLang.MYSQL,
+    CodeLang.POSTGRES,
+}
+
+
+class DataSeedStatus(str, Enum):
+    NO_GENERATION = "no_generation"
+    NEEDS_GENERATION = "needs_generation"
+    GENERATED = "generated"
+
+
+class ValidatorType(str, Enum):
+    CODE = "code"
+
+
+class SeedColumn(BaseModel):
+    name: str
+    description: Optional[str] = None
+
+    @classmethod
+    def from_dicts(cls, seeds: list[dict]) -> list[Self]:
+        return [cls(**seed) for seed in seeds]
+
+
+class SeedSubcategory(SeedColumn):
+    num_values_to_generate: int = Field(default=1, gt=0, le=MAX_NUM_DATA_SEED_VALUES)
+    generated_values: dict[str, list[SeedValueT]] = {}
+
+
+class SeedCategory(SeedColumn):
+    values: list[SeedValueT] = Field(default=[])
+    weights: list[float] = Field(default=[])
+    num_values_to_generate: Optional[int] = Field(
+        default=None, gt=0, le=MAX_NUM_DATA_SEED_VALUES
+    )
+    subcategories: Annotated[
+        list[SeedSubcategory], Len(max_length=MAX_NUM_NESTED_DATA_SEEDS)
+    ] = []
+    quality_rank: Optional[int] = None
+    generated_values: list[SeedValueT] = []
+
+
+class CategoricalDataSeeds(BaseModel):
+    seed_categories: list[SeedCategory]
diff --git a/src/gretel_client/navigator/tasks/validate/__init__.py b/src/gretel_client/navigator/tasks/validate/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/gretel_client/navigator/tasks/validate/validate_code.py b/src/gretel_client/navigator/tasks/validate/validate_code.py
new file mode 100644
index 0000000..48ae0b2
--- /dev/null
+++ b/src/gretel_client/navigator/tasks/validate/validate_code.py
@@ -0,0 +1,48 @@
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from gretel_client.navigator.client.interface import Client, TaskOutput
+from gretel_client.navigator.tasks.base import Task
+from gretel_client.navigator.tasks.io import Dataset
+from gretel_client.navigator.tasks.types import CodeLang, RecordsT
+
+VALIDATE_PYTHON_ADDED_COLUMN_PREFIXES = [
+    "_is_valid",
+    "_score",
+    "_severity",
+    "_messages",
+]
+
+VALIDATE_SQL_ADDED_COLUMN_PREFIXES = [
+    "_is_valid",
+    "_messages",
+]
+
+
+class ValidateCodeConfig(BaseModel):
+    code_lang: CodeLang
+    code_columns: list[str] = ["code"]
+
+
+class ValidateCode(Task):
+
+    def __init__(
+        self,
+        code_lang: CodeLang,
+        code_columns: list[str] = ["code"],
+        workflow_label: Optional[str] = None,
+        client: Optional[Client] = None,
+    ):
+        super().__init__(
+            config=ValidateCodeConfig(code_lang=code_lang, code_columns=code_columns),
+            workflow_label=workflow_label,
+            client=client,
+        )
+
+    @property
+    def name(self) -> str:
+        return "validate_code"
+
+    def run(self, dataset: Union[Dataset, RecordsT]) -> TaskOutput:
+        return self._run(self._records_to_dataset_if_needed(dataset))
diff --git a/src/gretel_client/navigator/workflow.py b/src/gretel_client/navigator/workflow.py
new file mode 100644
index 0000000..d943c47
--- /dev/null
+++ b/src/gretel_client/navigator/workflow.py
@@ -0,0 +1,383 @@
+import json
+import logging
+import time
+
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Union
+
+import pandas as pd
+import yaml
+
+from pydantic import BaseModel
+from typing_extensions import Self
+
+from gretel_client.analysis_utils import display_dataframe_in_notebook
+from gretel_client.navigator.client.interface import Client, TaskOutput
+from gretel_client.navigator.client.remote import Message
+from gretel_client.navigator.client.utils import get_navigator_client
+from gretel_client.navigator.log import get_logger
+from gretel_client.navigator.tasks.base import Task
+from gretel_client.navigator.tasks.io import Dataset
+from gretel_client.navigator.tasks.types import (
+    check_model_suite,
+    DEFAULT_MODEL_SUITE,
+    ModelSuite,
+)
+from gretel_client.projects import Project
+from gretel_client.rest_v1.api.workflows_api import WorkflowsApi
+from gretel_client.rest_v1.api_client import ApiClient
+from gretel_client.workflows.logs import print_logs_for_workflow_run
+
+logger = get_logger(__name__, level=logging.INFO)
+
+DEFAULT_WORKFLOW_NAME = "navigator-workflow"
+
+TASK_TYPE_EMOJI_MAP = {
+    "generate": "šŸ¦œ",
+    "validate": "šŸ”",
+    "sample": "šŸŒ±",
+    "seed": "šŸŒ±",
+    "load": "šŸ“„",
+}
+
+
+def _get_task_log_emoji(task_name: str) -> str:
+    log_emoji = ""
+    for task_type, emoji in TASK_TYPE_EMOJI_MAP.items():
+        if task_name.startswith(task_type):
+            log_emoji = emoji + " "
+    return log_emoji
+
+
+def get_output_seeds_with_generation(
+    outputs_by_step: dict[str, TaskOutput]
+) -> Optional[dict]:
+    seeds = [
+        v for k, v in outputs_by_step.items() if "generate-seed-category-values" in k
+    ]
+    return None if len(seeds) == 0 else seeds[0]
+
+
+@dataclass
+class PreviewResults:
+    output: Dataset
+    outputs_by_step: dict[str, TaskOutput]
+
+    def display_dataframe_in_notebook(
+        self, num_records: int = 10, settings: Optional[dict] = None
+    ) -> None:
+        """Display preview as pandas DataFrame in notebook with better settings for readability.
+
+        This function is intended to be used in a Jupyter notebook.
+
+        Args:
+            num_records: The number of records to display.
+            settings: Optional properties to set on the DataFrame's style.
+                If None, default settings with text wrapping are used.
+        """
+        if isinstance(self.output, pd.DataFrame):
+            display_dataframe_in_notebook(
+                self.output.head(num_records), settings=settings
+            )
+        else:
+            raise ValueError("Workflow output is not a DataFrame.")
+
+
+_TERMINAL_STATUSES = [
+    "RUN_STATUS_COMPLETED",
+    "RUN_STATUS_ERROR",
+    "RUN_STATUS_CANCELLED",
+]
+
+
+class BatchWorkflowRun:
+    workflow_id: str
+    workflow_run_id: str
+    _client: Client
+    _project: Project
+    _workflow_api: ApiClient
+
+    def __init__(
+        self, project: Project, client: Client, workflow_id: str, workflow_run_id: str
+    ):
+        self.workflow_id = workflow_id
+        self.workflow_run_id = workflow_run_id
+        self._client = client
+        self._project = project
+        self._workflow_api = project.session.get_v1_api(WorkflowsApi)
+
+    @property
+    def console_url(self) -> str:
+        return (
+            f"{self._project.get_console_url().replace(self._project.project_guid, '')}workflows/"
+            f"{self.workflow_id}/runs/{self.workflow_run_id}"
+        )
+
+    def wait_for_completion(self) -> None:
+        logger.info(f"šŸ‘€ Follow along -> {self.console_url}")
+        while True:
+            if self._reached_terminal_status():
+                break
+            time.sleep(10)
+
+    def run_status(self) -> str:
+        run = self._workflow_api.get_workflow_run(workflow_run_id=self.workflow_run_id)
+        return run.status
+
+    def _reached_terminal_status(self) -> bool:
+        status = self.run_status()
+        return status in _TERMINAL_STATUSES
+
+    def poll_logs(self) -> None:
+        print_logs_for_workflow_run(self.workflow_run_id, self._project.session)
+
+    def get_step_output(
+        self, step_name: str, format: Optional[str] = None
+    ) -> TaskOutput:
+        return self._client.get_step_output(
+            workflow_run_id=self.workflow_run_id,
+            step_name=step_name,
+            format=format,
+        )
+
+    def download_step_output(
+        self,
+        step_name: str,
+        format: Optional[str] = None,
+        output_dir: Union[str, Path] = ".",
+    ) -> Path:
+        return self._client.download_step_output(
+            workflow_run_id=self.workflow_run_id,
+            step_name=step_name,
+            output_dir=Path(output_dir),
+            format=format,
+        )
+
+
+class NavigatorWorkflowJobResults:
+
+    def __init__(
+        self,
+        *,
+        workflow_step_names: list[str],
+        batch_workflow_results: BatchWorkflowRun,
+        final_dataset_step_name: Optional[str] = None,
+    ):
+        self.workflow_step_names = workflow_step_names
+        self.final_dataset_step_name = final_dataset_step_name
+        self._batch_workflow_results = batch_workflow_results
+
+    @property
+    def workflow_id(self) -> str:
+        return self._batch_workflow_results.workflow_id
+
+    @property
+    def workflow_run_id(self) -> str:
+        return self._batch_workflow_results.workflow_run_id
+
+    @property
+    def console_url(self) -> str:
+        return self._batch_workflow_results.console_url
+
+    def fetch_dataset(self, wait_for_completion: bool = False) -> Optional[Dataset]:
+        if self.final_dataset_step_name is None:
+            raise ValueError("No dataset step was found in the workflow results")
+        status = self._batch_workflow_results.run_status()
+        if status == "RUN_STATUS_COMPLETED":
+            logger.info("āœ… Fetching dataset from completed workflow run")
+            return self._batch_workflow_results.get_step_output(
+                self.final_dataset_step_name
+            )
+        elif status in {"RUN_STATUS_ERROR", "RUN_STATUS_LOST"}:
+            logger.error("šŸ›‘ Workflow run failed. Cannot fetch dataset.")
+        elif status in {
+            "RUN_STATUS_PENDING",
+            "RUN_STATUS_CREATED",
+            "RUN_STATUS_ACTIVE",
+        }:
+            logger.info(
+                "šŸ—ļø We are still building your dataset. "
+                f"Workflow status: {status.split('_')[-1]}."
+            )
+            if wait_for_completion:
+                logger.info("ā³ Waiting for workflow run to complete...")
+                self._batch_workflow_results.wait_for_completion()
+                return self.fetch_dataset()
+        elif status in {"RUN_STATUS_CANCELLING", "RUN_STATUS_CANCELLED"}:
+            logger.warning("Workflow run was cancelled.")
+        else:
+            logger.warning(f"Unknown workflow status: {status}")
+
+
+class Step(BaseModel):
+    name: Optional[str] = None
+    task: str
+    config: dict
+    inputs: Optional[list[str]] = []
+
+
+class NavigatorWorkflow:
+    def __init__(
+        self,
+        *,
+        steps: Optional[list[Step]] = None,
+        model_suite: ModelSuite = DEFAULT_MODEL_SUITE,
+        workflow_name: Optional[str] = None,
+        **session_kwargs,
+    ):
+        self._workflow_name = (
+            workflow_name
+            or f"{DEFAULT_WORKFLOW_NAME}-{datetime.now().isoformat(timespec='seconds')}"
+        )
+        self._steps = steps or []
+        self._client = get_navigator_client(**session_kwargs)
+        self._model_suite = check_model_suite(model_suite)
+        self._globals = {
+            "num_records": 10,
+            "model_suite": self._model_suite,
+        }
+        self._task_io = {}
+        # Create a mapping of task names to their inputs and output.
+        # This is helpful for finding the last step to emit a dataset.
+        for task in self._client.registry():
+            self._task_io[task["name"]] = {
+                "inputs": task["inputs"],
+                "output": task["output"],
+            }
+
+    @staticmethod
+    def create_steps_from_sequential_tasks(task_list: list[Task]) -> list[Step]:
+        steps = []
+        step_names = []
+        for i in range(len(task_list)):
+            inputs = []
+            task = task_list[i]
+            suffix = "" if task.workflow_label is None else f"-{task.workflow_label}"
+            step_names.append(
+                f"{task.name}-{i + 1}{suffix}".replace("_", "-").replace(" ", "-")
+            )
+            if i > 0:
+                prev_name = step_names[i - 1]
+                inputs = [prev_name]
+            steps.append(
+                Step(
+                    name=step_names[i],
+                    task=task.name,
+                    config=task.config.model_dump(),
+                    inputs=inputs,
+                )
+            )
+        return steps
+
+    @classmethod
+    def from_sequential_tasks(
+        cls, task_list: list[Task], workflow_name: Optional[str] = None, **session_kwargs
+    ) -> Self:
+        workflow = cls(workflow_name=workflow_name, **session_kwargs)
+        workflow.add_steps(cls.create_steps_from_sequential_tasks(task_list))
+        return workflow
+
+    @classmethod
+    def from_yaml(cls, yaml_str: str) -> Self:
+        yaml_dict = yaml.safe_load(yaml_str)
+        workflow = cls(workflow_name=yaml_dict["name"])
+        workflow.add_steps([Step(**step) for step in yaml_dict["steps"]])
+        workflow.set_globals(yaml_dict.get("globals", None))
+        return workflow
+
+    def _execute_preview(self, verbose: bool = False) -> PreviewResults:
+        step_idx = 0
+        message: Message
+        current_step = None
+        final_output = None
+        outputs_by_step = {}
+        for message in self._client.get_workflow_preview(self.to_dict()):
+            if current_step != message.step:
+                current_step = message.step
+                task_name = self._steps[step_idx].task.replace("_", "-")
+                step_name = message.step.replace("-" + str(step_idx + 1), "")
+                label = (
+                    ""
+                    if task_name == step_name
+                    else f" >>{step_name.split(task_name)[-1].replace('-', ' ')}"
+                )
+                logger.info(
+                    f"{_get_task_log_emoji(task_name)}Step {step_idx + 1}: "
+                    f"{task_name.replace('-', ' ').capitalize()}{label}"
+                )
+                step_idx += 1
+
+            # todo: make this log level aware
+            if message.stream == "logs":
+                level, msg = message.payload.get("level"), message.payload.get("msg")
+                if (level == "info" and verbose) or level == "error":
+                    logger.info(f" |-- {msg}")
+
+            if message.stream == "step_outputs":
+                logger.debug(f"Step output: {json.dumps(message.payload, indent=4)}")
+
+                output = message.payload
+                if message.type == "dataset":
+                    output = pd.DataFrame.from_records(message.payload.get("dataset"))
+                final_output = output
+                outputs_by_step[message.step] = output
+        # The final output is either the dataset produced by the last
+        # task in the workflow or, if no dataset is produced by the
+        # workflow, the output of the last task.
+        if final_output is None:
+            final_output = outputs_by_step[current_step]
+        return PreviewResults(output=final_output, outputs_by_step=outputs_by_step)
+
+    def add_step(self, step: Step) -> None:
+        self._steps.append(step)
+
+    def add_steps(self, steps: list[Step]) -> None:
+        self._steps.extend(steps)
+
+    def reset_steps(self) -> None:
+        self._steps = []
+
+    def to_dict(self) -> dict:
+        return dict(
+            name=self._workflow_name,
+            steps=list(
+                map(lambda x: x.model_dump() if isinstance(x, Step) else x, self._steps)
+            ),
+            globals=self._globals or {},
+        )
+
+    def to_json(self, file_path: Optional[Union[Path, str]] = None) -> Optional[str]:
+        json_str = json.dumps(self.to_dict(), indent=4)
+        if file_path is None:
+            return json_str
+        with open(file_path, "w") as f:
+            f.write(json_str)
+
+    def to_yaml(self, file_path: Optional[Union[Path, str]] = None) -> Optional[str]:
+        yaml_str = yaml.dump(json.loads(self.to_json()), default_flow_style=False)
+        if file_path is None:
+            return yaml_str
+        with open(file_path, "w") as f:
+            f.write(yaml_str)
+
+    def generate_dataset_preview(
+        self, *, verbose_logging: bool = False
+    ) -> PreviewResults:
+        logger.info("šŸš€ Generating dataset preview")
+        preview = self._execute_preview(verbose=verbose_logging)
+        logger.info("šŸ‘€ Your dataset preview is ready for a peek!")
+        return preview
+
+    def submit_batch_job(self, num_records: int, project_name: Optional[str] = None):
+        self._globals.update({"num_records": num_records})
+        response = self._client.submit_batch_workflow(
+            self.to_dict(), num_records, project_name
+        )
+        return BatchWorkflowRun(
+            workflow_id=response.workflow_id,
+            workflow_run_id=response.workflow_run_id,
+            client=self._client,
+            project=response.project,
+        )
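
Editor's note: the diff above only adds library code; the snippet below is an illustrative usage sketch and is not part of the change set. It assumes the import paths introduced above and chains two of the new tasks purely to show how the pieces compose; a realistic workflow would typically place generation tasks between seeding and validation. The seed record, column name, and step name used here are invented for illustration.

    from gretel_client.navigator.tasks.seed.seed_from_records import SeedFromRecords
    from gretel_client.navigator.tasks.types import CodeLang
    from gretel_client.navigator.tasks.validate.validate_code import ValidateCode
    from gretel_client.navigator.workflow import NavigatorWorkflow

    # Each Task becomes a Step; from_sequential_tasks feeds step i's output
    # into step i+1's inputs and names steps "{task-name}-{index}".
    tasks = [
        SeedFromRecords(records=[{"topic": "string parsing"}]),  # hypothetical seed record
        ValidateCode(code_lang=CodeLang.PYTHON, code_columns=["code"]),
    ]
    workflow = NavigatorWorkflow.from_sequential_tasks(tasks)

    # Stream a small preview locally before committing to a batch run.
    preview = workflow.generate_dataset_preview()

    # Submit a batch run; the returned BatchWorkflowRun exposes the console
    # URL, log polling, and per-step output downloads.
    batch_run = workflow.submit_batch_job(num_records=100)
    batch_run.wait_for_completion()
    dataset = batch_run.get_step_output("validate-code-2")  # name follows the step-naming scheme above

For preview runs, the same per-step outputs are available on PreviewResults.outputs_by_step, keyed by the step names that create_steps_from_sequential_tasks assigns.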