Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into thomasjpfan/neptu…
Browse files Browse the repository at this point in the history
…ne_pr

Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan committed Aug 21, 2024
2 parents 770c586 + e3036f0 commit e36c2c6
Show file tree
Hide file tree
Showing 28 changed files with 593 additions and 108 deletions.
10 changes: 5 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
ARG PYTHON_VERSION
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim-bookworm

MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.authors="Flyte Team <[email protected]>"
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

WORKDIR /root
ENV PYTHONPATH /root
ENV FLYTE_SDK_RICH_TRACEBACKS 0
ENV PYTHONPATH=/root
ENV FLYTE_SDK_RICH_TRACEBACKS=0

ARG VERSION
ARG DOCKER_IMAGE
Expand Down Expand Up @@ -35,4 +35,4 @@ RUN apt-get update && apt-get install build-essential -y \

USER flytekit

ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
ENV FLYTE_INTERNAL_IMAGE="$DOCKER_IMAGE"
8 changes: 4 additions & 4 deletions Dockerfile.agent
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM python:3.10-slim-bookworm as agent-slim
FROM python:3.10-slim-bookworm AS agent-slim

MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.authors="Flyte Team <[email protected]>"
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG VERSION
Expand All @@ -19,9 +19,9 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \
&& rm -rf /var/lib/{apt,dpkg,cache,log}/ \
&& :

CMD pyflyte serve agent --port 8000
CMD ["pyflyte", "serve", "agent", "--port", "8000"]

FROM agent-slim as agent-all
FROM agent-slim AS agent-all
ARG VERSION

RUN pip install --no-cache-dir -U \
Expand Down
11 changes: 4 additions & 7 deletions Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
# From your test user code
# $ pyflyte run --image localhost:30000/flytekittest:someversion

ARG PYTHON_VERSION
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim-bookworm

MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.authors="Flyte Team <[email protected]>"
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

WORKDIR /root
ENV FLYTE_SDK_RICH_TRACEBACKS 0
ENV FLYTE_SDK_RICH_TRACEBACKS=0

# Flytekit version of flytekit to be installed in the image
ARG PSEUDO_VERSION
ARG PSEUDO_VERSION=1.13.3


# Note: Pod tasks should be exposed in the default image
Expand Down Expand Up @@ -50,8 +50,5 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \
&& chown flytekit: /home \
&& :


ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:"

# Switch to the 'flytekit' user for better security.
USER flytekit
7 changes: 7 additions & 0 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import signal
import subprocess
import sys
import tempfile
import traceback
from sys import exit
Expand Down Expand Up @@ -376,6 +377,9 @@ def _execute_task(
dynamic_addl_distro,
dynamic_dest_dir,
) as ctx:
working_dir = os.getcwd()
if all(os.path.realpath(path) != working_dir for path in sys.path):
sys.path.append(working_dir)
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand Down Expand Up @@ -424,6 +428,9 @@ def _execute_map_task(
with setup_execution(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
working_dir = os.getcwd()
if all(os.path.realpath(path) != working_dir for path in sys.path):
sys.path.append(working_dir)
task_index = _compute_array_job_index()
mtr = load_object_from_module(resolver)()
map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency)
Expand Down
42 changes: 30 additions & 12 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import os
import re
import threading
import time
import typing
import urllib.parse as _urlparse
import webbrowser
Expand Down Expand Up @@ -236,6 +238,9 @@ def __init__(
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or requests.Session()
self._lock = threading.Lock()
self._cached_credentials = None
self._cached_credentials_ts = None

self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand Down Expand Up @@ -339,25 +344,38 @@ def _request_access_token(self, auth_code) -> Credentials:

def get_creds_from_remote(self) -> Credentials:
"""
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
retrieve credentials
This is the entrypoint method. It will kickoff the full authentication
flow and trigger a web-browser to retrieve credentials. Because this
needs to open a port on localhost and may be called from a
multithreaded context (e.g. pyflyte register), this call may block
multiple threads and return a cached result for up to 60 seconds.
"""
# In the absence of globally-set token values, initiate the token request flow
q = Queue()
with self._lock:
# Clear cache if it's been more than 60 seconds since the last check
cache_ttl_s = 60
if self._cached_credentials_ts is not None and self._cached_credentials_ts + cache_ttl_s < time.monotonic():
self._cached_credentials = None

# First prepare the callback server in the background
server = self._create_callback_server()
if self._cached_credentials is not None:
return self._cached_credentials
q = Queue()

self._request_authorization_code()
# First prepare the callback server in the background
server = self._create_callback_server()

server.handle_request(q)
server.server_close()
self._request_authorization_code()

# Send the call to request the authorization code in the background
server.handle_request(q)
server.server_close()

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
# Send the call to request the authorization code in the background

# Request the access token once the auth code has been received.
auth_code = q.get()
self._cached_credentials = self._request_access_token(auth_code)
self._cached_credentials_ts = time.monotonic()
return self._cached_credentials

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def get_upload_signed_url(
)
)
except Exception as e:
raise RuntimeError(f"Failed to get signed url for {filename}, reason: {e}")
raise RuntimeError(f"Failed to get signed url for {filename}.") from e

def get_download_signed_url(
self, native_url: str, expires_in: datetime.timedelta = None
Expand Down
4 changes: 3 additions & 1 deletion flytekit/clients/grpc_utils/wrap_exception_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import grpc

from flytekit.exceptions.base import FlyteException
from flytekit.exceptions.system import FlyteSystemException
from flytekit.exceptions.system import FlyteSystemException, FlyteSystemUnavailableException
from flytekit.exceptions.user import (
FlyteAuthenticationException,
FlyteEntityAlreadyExistsException,
Expand All @@ -28,6 +28,8 @@ def _raise_if_exc(request: typing.Any, e: Union[grpc.Call, grpc.Future]):
raise FlyteEntityNotExistException() from e
elif e.code() == grpc.StatusCode.INVALID_ARGUMENT:
raise FlyteInvalidInputException(request) from e
elif e.code() == grpc.StatusCode.UNAVAILABLE:
raise FlyteSystemUnavailableException() from e
raise FlyteSystemException() from e

def intercept_unary_unary(self, continuation, client_call_details, request):
Expand Down
107 changes: 99 additions & 8 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import sys
import tempfile
import typing
import typing as t
from dataclasses import dataclass, field, fields
from typing import Iterator, get_args

import rich_click as click
import yaml
from click import Context
from mashumaro.codecs.json import JSONEncoder
from rich.progress import Progress
from typing_extensions import get_origin
Expand All @@ -25,7 +28,12 @@
pretty_print_exception,
project_option,
)
from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.configuration import (
DefaultImages,
FastSerializationSettings,
ImageConfig,
SerializationSettings,
)
from flytekit.configuration.plugin import get_plugin
from flytekit.core import context_manager
from flytekit.core.artifact import ArtifactQuery
Expand All @@ -34,14 +42,24 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.exceptions.system import FlyteSystemException
from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback, labels_callback
from flytekit.interaction.click_types import (
FlyteLiteralConverter,
key_value_callback,
labels_callback,
)
from flytekit.interaction.string_literals import literal_string_repr
from flytekit.loggers import logger
from flytekit.models import security
from flytekit.models.common import RawOutputDataConfig
from flytekit.models.interface import Parameter, Variable
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow, remote_fs
from flytekit.remote import (
FlyteLaunchPlan,
FlyteRemote,
FlyteTask,
FlyteWorkflow,
remote_fs,
)
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.tools import module_loader
from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules
Expand Down Expand Up @@ -489,7 +507,8 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder:
return ctx.current_context().new_builder()

file_access = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), raw_output_prefix=output_prefix
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
raw_output_prefix=output_prefix,
)

# The task might run on a remote machine if raw_output_prefix is a remote path,
Expand Down Expand Up @@ -539,7 +558,10 @@ def _run(*args, **kwargs):
entity_type = "workflow" if isinstance(entity, PythonFunctionWorkflow) else "task"
logger.debug(f"Running {entity_type} {entity.name} with input {kwargs}")

click.secho(f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", fg="cyan")
click.secho(
f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.",
fg="cyan",
)
try:
inputs = {}
for input_name, v in entity.python_interface.inputs_with_defaults.items():
Expand Down Expand Up @@ -576,6 +598,8 @@ def _run(*args, **kwargs):
)
if processed_click_value is not None or optional_v:
inputs[input_name] = processed_click_value
if processed_click_value is None and v[0] == bool:
inputs[input_name] = False

if not run_level_params.is_remote:
with FlyteContextManager.with_context(_update_flyte_context(run_level_params)):
Expand Down Expand Up @@ -755,7 +779,10 @@ def list_commands(self, ctx):
run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
progress = Progress(transient=True)
task = progress.add_task(f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", total=None)
task = progress.add_task(
f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...",
total=None,
)
with progress:
progress.start_task(task)
try:
Expand Down Expand Up @@ -783,6 +810,70 @@ def get_command(self, ctx, name):
)


class YamlFileReadingCommand(click.RichCommand):
def __init__(
self,
name: str,
params: typing.List[click.Option],
help: str,
callback: typing.Callable = None,
):
params.append(
click.Option(
["--inputs-file"],
required=False,
type=click.Path(exists=True, dir_okay=False, resolve_path=True),
help="Path to a YAML | JSON file containing inputs for the workflow.",
)
)
super().__init__(name=name, params=params, callback=callback, help=help)

def parse_args(self, ctx: Context, args: t.List[str]) -> t.List[str]:
def load_inputs(f: str) -> t.Dict[str, str]:
try:
inputs = yaml.safe_load(f)
except yaml.YAMLError as e:
yaml_e = e
try:
inputs = json.loads(f)
except json.JSONDecodeError as e:
raise click.BadParameter(
message=f"Could not load the inputs file. Please make sure it is a valid JSON or YAML file."
f"\n json error: {e},"
f"\n yaml error: {yaml_e}",
param_hint="--inputs-file",
)

return inputs

inputs = {}
if "--inputs-file" in args:
idx = args.index("--inputs-file")
args.pop(idx)
f = args.pop(idx)
with open(f, "r") as f:
inputs = load_inputs(f.read())
elif not sys.stdin.isatty():
f = sys.stdin.read()
if f != "":
inputs = load_inputs(f)

new_args = []
for k, v in inputs.items():
if isinstance(v, str):
new_args.extend([f"--{k}", v])
elif isinstance(v, bool):
if v:
new_args.append(f"--{k}")
else:
v = json.dumps(v)
new_args.extend([f"--{k}", v])
new_args.extend(args)
args = new_args

return super().parse_args(ctx, args)


class WorkflowCommand(click.RichGroup):
"""
click multicommand at the python file layer, subcommands should be all the workflows in the file.
Expand Down Expand Up @@ -837,11 +928,11 @@ def _create_command(
h = f"{click.style(entity_type, bold=True)} ({run_level_params.computed_params.module}.{entity_name})"
if loaded_entity.__doc__:
h = h + click.style(f"{loaded_entity.__doc__}", dim=True)
cmd = click.RichCommand(
cmd = YamlFileReadingCommand(
name=entity_name,
params=params,
callback=run_command(ctx, loaded_entity),
help=h,
callback=run_command(ctx, loaded_entity),
)
return cmd

Expand Down
Loading

0 comments on commit e36c2c6

Please sign in to comment.