Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: aca1015675b80a848e68c639374255f032b4728d
  • Loading branch information
Gretel Team authored and johnnygreco committed Oct 29, 2024
1 parent de61500 commit 479f6d7
Show file tree
Hide file tree
Showing 28 changed files with 2,189 additions and 2 deletions.
2 changes: 2 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ backports.cached-property==1.0.0.post2
certifi>=2021.10.8
click>=8.1.3,<9
docker==7.1.0
inflection==0.5.1
kubernetes>=28.1.0
opentelemetry-distro>=0.44b0
opentelemetry-exporter-prometheus>=0.44b0
Expand All @@ -10,6 +11,7 @@ pydantic>=2
python_dateutil>=2.8.0
pyyaml==6.0.1
requests>=2.25,<3
rich==13.7.1
smart_open>=2.1.0,<6.0
tabulate==0.8.9
tenacity==8.2.2
Expand Down
3 changes: 1 addition & 2 deletions src/gretel_client/inference_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from gretel_client.config import ClientConfig, configure_session, get_session_config
from gretel_client.rest.api_client import ApiClient
from gretel_client.rest.configuration import Configuration

MODELS_API_PATH = "/v1/inference/models"

Expand Down Expand Up @@ -161,7 +160,7 @@ def __init__(
elif len(session_kwargs) > 0:
raise ValueError("cannot specify session arguments when passing a session")

if session.default_runner != "cloud" and not ".serverless." in session.endpoint:
if session.default_runner != "cloud" and ".serverless." not in session.endpoint:
raise GretelInferenceAPIError(
"Gretel's Inference API is currently only "
"available within Gretel Cloud. Your current runner "
Expand Down
2 changes: 2 additions & 0 deletions src/gretel_client/navigator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from gretel_client.navigator.data_designer.interface import DataDesigner
from gretel_client.navigator.workflow import NavigatorWorkflow
Empty file.
130 changes: 130 additions & 0 deletions src/gretel_client/navigator/client/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Iterator, Optional, Type, TypeVar, Union

import pandas as pd

from gretel_client.projects import Project


def get_client(adapter: Union[Type[ClientAdapter], ClientAdapter]) -> Client:
if not isinstance(adapter, ClientAdapter):
adapter = adapter()
return Client(adapter)


@dataclass
class SubmitBatchWorkflowResponse:
project: Project
workflow_id: str
workflow_run_id: str


class Client:

_adapter: ClientAdapter

def __init__(self, adapter: ClientAdapter):
self._adapter = adapter

def run_task(
self,
name: str,
config: dict,
inputs: Optional[list[TaskInput]] = None,
globals: Optional[dict] = None,
verbose: bool = False,
) -> TaskOutput:
if inputs is None:
inputs = []
if globals is None:
globals = {}
return self._adapter.run_task(name, config, inputs, globals, verbose)

def get_workflow_preview(self, workflow_config: dict) -> Iterator:
return self._adapter.stream_workflow_outputs(workflow_config)

def submit_batch_workflow(
self,
workflow_config: dict,
num_records: int,
project_name: Optional[str] = None,
) -> SubmitBatchWorkflowResponse:
return self._adapter.submit_batch_workflow(
workflow_config, num_records, project_name
)

def get_step_output(
self,
workflow_run_id: str,
step_name: str,
format: Optional[str] = None,
) -> TaskOutput:
return self._adapter.get_step_output(workflow_run_id, step_name, format)

def download_step_output(
self,
workflow_run_id: str,
step_name: str,
output_dir: Path,
format: Optional[str] = None,
) -> Path:
return self._adapter.download_step_output(
workflow_run_id, step_name, output_dir, format
)

def registry(self) -> list[dict]:
return self._adapter.registry()


TaskInput = TypeVar("TaskInput")
TaskOutput = Union[pd.DataFrame, dict]


class ClientAdapter(ABC, Generic[TaskInput]):

@abstractmethod
def run_task(
self,
name: str,
config: dict,
inputs: list[TaskInput],
globals: dict,
verbose: bool = False,
) -> TaskOutput: ...

@abstractmethod
def stream_workflow_outputs(
self, workflow: dict, verbose: bool = False
) -> Iterator[dict]: ...

@abstractmethod
def registry(self) -> list[dict]: ...

def submit_batch_workflow(
self,
workflow_config: dict,
num_records: int,
project_name: Optional[str] = None,
) -> SubmitBatchWorkflowResponse:
raise NotImplementedError("Cannot submit batch Workflows")

def get_step_output(
self,
workflow_run_id: str,
step_name: str,
format: Optional[str] = None,
) -> TaskOutput:
raise NotImplementedError("Cannot get batch step outputs")

def download_step_output(
self,
workflow_run_id: str,
step_name: str,
output_dir: Path,
format: Optional[str] = None,
) -> Path:
raise NotImplementedError("Cannot download batch artifacts")
Loading

0 comments on commit 479f6d7

Please sign in to comment.