diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 1a9f74f114..5ba43b78dc 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -190,6 +190,7 @@ from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase from flytekit.models.core.types import BlobType +from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.types import directory, file, numpy, schema diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index cb228b788a..5d65f8c3ca 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -463,7 +463,7 @@ class GCSConfig(object): gsutil_parallelism: bool = False @classmethod - def auto(self, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: config_file = get_config_file(config_file) kwargs = {} kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file)) @@ -647,6 +647,7 @@ class SerializationSettings(object): domain: typing.Optional[str] = None version: typing.Optional[str] = None env: Optional[Dict[str, str]] = None + git_repo: Optional[str] = None python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None @@ -719,6 +720,7 @@ def new_builder(self) -> Builder: version=self.version, image_config=self.image_config, env=self.env.copy() if self.env else None, + git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, @@ -768,6 +770,7 @@ class Builder(object): version: str image_config: ImageConfig env: Optional[Dict[str, str]] = None + git_repo: Optional[str] = None flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None @@ -783,6 +786,7 @@ def build(self) -> SerializationSettings: version=self.version, image_config=self.image_config, env=self.env, + git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index 467f660d42..793917cffe 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -224,7 +224,7 @@ def legacy_config(self) -> _configparser.ConfigParser: return self._legacy_config @property - def yaml_config(self) -> typing.Dict[str, Any]: + def yaml_config(self) -> typing.Dict[str, typing.Any]: return self._yaml_config diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index dccbaec803..491bed4385 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -45,6 +45,7 @@ from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext @@ -156,6 +157,7 @@ def __init__( metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, + docs: Optional[Documentation] = None, **kwargs, ): self._task_type = task_type @@ -164,6 +166,7 @@ def __init__( self._metadata = metadata if metadata else TaskMetadata() self._task_type_version = task_type_version self._security_ctx = security_ctx + self._docs = docs FlyteEntities.entities.append(self) @@ -195,6 +198,10 @@ def task_type_version(self) -> int: def security_context(self) -> SecurityContext: return self._security_ctx + @property + def docs(self) -> Documentation: + return self._docs + def get_type_for_input_var(self, k: str, v: Any) -> type: """ Returns the python native type for the given input variable @@ -390,6 +397,17 @@ def __init__( self._environment = environment if environment else {} self._task_config = task_config self._disable_deck = disable_deck + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + self._docs.short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs.long_description = Description(value=self._python_interface.docstring.long_description) # TODO lets call this interface and the other as flyte_interface? @property diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 63d7c8106f..954c1ae409 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -207,7 +207,6 @@ def transform_interface_to_typed_interface( """ if interface is None: return None - if interface.docstring is None: input_descriptions = output_descriptions = {} else: diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6e5b0a6b6a..bb24181338 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -7,6 +7,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -89,6 +90,7 @@ def task( secret_requests: Optional[List[Secret]] = None, execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + docs: Optional[Documentation] = None, disable_deck: bool = True, ) -> Union[Callable, PythonFunctionTask]: """ @@ -179,6 +181,7 @@ def foo2(): :param execution_mode: This is mainly for internal use. Please ignore. It is filled in automatically. :param task_resolver: Provide a custom task resolver. :param disable_deck: If true, this task will not output deck html file + :param docs: Documentation about this task """ def wrapper(fn) -> PythonFunctionTask: @@ -204,6 +207,7 @@ def wrapper(fn) -> PythonFunctionTask: execution_mode=execution_mode, task_resolver=task_resolver, disable_deck=disable_deck, + docs=docs, ) update_wrapper(task_instance, fn) return task_instance diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 468a5aa7ea..a1a1581e96 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -39,6 +39,7 @@ from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -168,6 +169,7 @@ def __init__( workflow_metadata: WorkflowMetadata, workflow_metadata_defaults: WorkflowMetadataDefaults, python_interface: Interface, + docs: Optional[Documentation] = None, **kwargs, ): self._name = name @@ -179,6 +181,20 @@ def __init__( self._unbound_inputs = set() self._nodes = [] self._output_bindings: List[_literal_models.Binding] = [] + self._docs = docs + + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + self._docs.short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs = Description(value=self._python_interface.docstring.long_description) + FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -186,6 +202,10 @@ def __init__( def name(self) -> str: return self._name + @property + def docs(self): + return self._docs + @property def short_name(self) -> str: return extract_obj_name(self._name) @@ -571,7 +591,8 @@ def __init__( workflow_function: Callable, metadata: Optional[WorkflowMetadata], default_metadata: Optional[WorkflowMetadataDefaults], - docstring: Docstring = None, + docstring: Optional[Docstring] = None, + docs: Optional[Documentation] = None, ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function @@ -586,6 +607,7 @@ def __init__( workflow_metadata=metadata, workflow_metadata_defaults=default_metadata, python_interface=native_interface, + docs=docs, ) @property @@ -690,6 +712,7 @@ def workflow( _workflow_function=None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, + docs: Optional[Documentation] = None, ): """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG @@ -718,6 +741,7 @@ def workflow( :param _workflow_function: This argument is implicitly passed and represents the decorated function. :param failure_policy: Use the options in flytekit.WorkflowFailurePolicy :param interruptible: Whether or not tasks launched from this workflow are by default interruptible + :param docs: Description entity for the workflow """ def wrapper(fn): @@ -730,6 +754,7 @@ def wrapper(fn): metadata=workflow_metadata, default_metadata=workflow_metadata_defaults, docstring=Docstring(callable_=fn), + docs=docs, ) workflow_instance.compile() update_wrapper(workflow_instance, fn) diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py index f34e692123..e40307b6ba 100644 --- a/flytekit/models/admin/workflow.py +++ b/flytekit/models/admin/workflow.py @@ -1,13 +1,21 @@ +import typing + from flyteidl.admin import workflow_pb2 as _admin_workflow from flytekit.models import common as _common from flytekit.models.core import compiler as _compiler_models from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _core_workflow +from flytekit.models.documentation import Documentation class WorkflowSpec(_common.FlyteIdlEntity): - def __init__(self, template, sub_workflows): + def __init__( + self, + template: _core_workflow.WorkflowTemplate, + sub_workflows: typing.List[_core_workflow.WorkflowTemplate], + docs: typing.Optional[Documentation] = None, + ): """ This object fully encapsulates the specification of a workflow :param flytekit.models.core.workflow.WorkflowTemplate template: @@ -15,6 +23,7 @@ def __init__(self, template, sub_workflows): """ self._template = template self._sub_workflows = sub_workflows + self._docs = docs @property def template(self): @@ -30,6 +39,13 @@ def sub_workflows(self): """ return self._sub_workflows + @property + def docs(self): + """ + :rtype: Description entity for the workflow + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec @@ -37,6 +53,7 @@ def to_flyte_idl(self): return _admin_workflow.WorkflowSpec( template=self._template.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], + description=self._docs.to_flyte_idl() if self._docs else None, ) @classmethod @@ -48,6 +65,7 @@ def from_flyte_idl(cls, pb2_object): return cls( _core_workflow.WorkflowTemplate.from_flyte_idl(pb2_object.template), [_core_workflow.WorkflowTemplate.from_flyte_idl(s) for s in pb2_object.sub_workflows], + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, ) diff --git a/flytekit/models/documentation.py b/flytekit/models/documentation.py new file mode 100644 index 0000000000..e1bae8122e --- /dev/null +++ b/flytekit/models/documentation.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from flyteidl.admin import description_entity_pb2 + +from flytekit.models import common as _common_models + + +@dataclass +class Description(_common_models.FlyteIdlEntity): + """ + Full user description with formatting preserved. This can be rendered + by clients, such as the console or command line tools with in-tact + formatting. + """ + + class DescriptionFormat(Enum): + UNKNOWN = 0 + MARKDOWN = 1 + HTML = 2 + RST = 3 + + value: Optional[str] = None + uri: Optional[str] = None + icon_link: Optional[str] = None + format: DescriptionFormat = DescriptionFormat.RST + + def to_flyte_idl(self): + return description_entity_pb2.Description( + value=self.value if self.value else None, + uri=self.uri if self.uri else None, + format=self.format.value, + icon_link=self.icon_link, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.Description) -> "Description": + return cls( + value=pb2_object.value if pb2_object.value else None, + uri=pb2_object.uri if pb2_object.uri else None, + format=Description.DescriptionFormat(pb2_object.format), + icon_link=pb2_object.icon_link if pb2_object.icon_link else None, + ) + + +@dataclass +class SourceCode(_common_models.FlyteIdlEntity): + """ + Link to source code used to define this task or workflow. + """ + + link: Optional[str] = None + + def to_flyte_idl(self): + return description_entity_pb2.SourceCode(link=self.link) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.SourceCode) -> "SourceCode": + return cls(link=pb2_object.link) if pb2_object.link else None + + +@dataclass +class Documentation(_common_models.FlyteIdlEntity): + """ + DescriptionEntity contains detailed description for the task/workflow/launch plan. + Documentation could provide insight into the algorithms, business use case, etc. + Args: + short_description (str): One-liner overview of the entity. + long_description (Optional[Description]): Full user description with formatting preserved. + source_code (Optional[SourceCode]): link to source code used to define this entity + """ + + short_description: Optional[str] = None + long_description: Optional[Description] = None + source_code: Optional[SourceCode] = None + + def to_flyte_idl(self): + return description_entity_pb2.DescriptionEntity( + short_description=self.short_description, + long_description=self.long_description.to_flyte_idl() if self.long_description else None, + source_code=self.source_code.to_flyte_idl() if self.source_code else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.DescriptionEntity) -> "Documentation": + return cls( + short_description=pb2_object.short_description, + long_description=Description.from_flyte_idl(pb2_object.long_description) + if pb2_object.long_description + else None, + source_code=SourceCode.from_flyte_idl(pb2_object.source_code) if pb2_object.source_code else None, + ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index f2ff5efd89..2129cdd88f 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -13,6 +13,7 @@ from flytekit.models import literals as _literals from flytekit.models import security as _sec from flytekit.models.core import identifier as _identifier +from flytekit.models.documentation import Documentation class Resources(_common.FlyteIdlEntity): @@ -480,11 +481,13 @@ def from_flyte_idl(cls, pb2_object): class TaskSpec(_common.FlyteIdlEntity): - def __init__(self, template): + def __init__(self, template: TaskTemplate, docs: typing.Optional[Documentation] = None): """ :param TaskTemplate template: + :param Documentation docs: """ self._template = template + self._docs = docs @property def template(self): @@ -493,11 +496,20 @@ def template(self): """ return self._template + @property + def docs(self): + """ + :rtype: Description entity for the task + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.tasks_pb2.TaskSpec """ - return _admin_task.TaskSpec(template=self.template.to_flyte_idl()) + return _admin_task.TaskSpec( + template=self.template.to_flyte_idl(), description=self.docs.to_flyte_idl() if self.docs else None + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -505,7 +517,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.tasks_pb2.TaskSpec pb2_object: :rtype: TaskSpec """ - return cls(TaskTemplate.from_flyte_idl(pb2_object.template)) + return cls( + TaskTemplate.from_flyte_idl(pb2_object.template), + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, + ) class Task(_common.FlyteIdlEntity): diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index edd899d081..23c9803b07 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -19,6 +19,7 @@ from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 +from git import Repo from flytekit import Literal from flytekit.clients.friendly import SynchronousFlyteClient @@ -121,6 +122,17 @@ def _get_entity_identifier( ) +def _get_git_repo_url(source_path): + """ + Get git repo URL from remote.origin.url + """ + try: + return "github.com/" + Repo(source_path).remotes.origin.url.split(".git")[0].split(":")[-1] + except Exception: + # If the file isn't in the git repo, we can't get the url from git config + return "" + + class FlyteRemote(object): """Main entrypoint for programmatically accessing a Flyte remote backend. @@ -790,11 +802,11 @@ def register_script( filename="scriptmode.tar.gz", ), ) - serialization_settings = SerializationSettings( project=project, domain=domain, image_config=image_config, + git_repo=_get_git_repo_url(source_path), fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 50bac67844..3c9fe64068 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -12,7 +12,7 @@ from flytekit.models import launch_plan from flytekit.models.core.identifier import Identifier from flytekit.remote import FlyteRemote -from flytekit.remote.remote import RegistrationSkipped +from flytekit.remote.remote import RegistrationSkipped, _get_git_repo_url from flytekit.tools import fast_registration, module_loader from flytekit.tools.script_mode import _find_project_root from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities @@ -162,7 +162,7 @@ def load_packages_and_modules( :param options: :return: The common detected root path, the output of _find_project_root """ - + ss.git_repo = _get_git_repo_url(project_root) pkgs_and_modules = [] for pm in pkgs_or_mods: p = Path(pm).resolve() diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index ec2bdb0cb9..5ec249fa4b 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -1,9 +1,10 @@ +import sys import typing from collections import OrderedDict from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -23,6 +24,7 @@ from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import security from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as workflow_model @@ -211,7 +213,8 @@ def get_serializable_task( ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() - return TaskSpec(template=tt) + + return TaskSpec(template=tt, docs=entity.docs) def get_serializable_workflow( @@ -295,8 +298,9 @@ def get_serializable_workflow( nodes=serialized_nodes, outputs=entity.output_bindings, ) + return admin_workflow_models.WorkflowSpec( - template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()) + template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()), docs=entity.docs ) @@ -683,6 +687,16 @@ def get_serializable( else: raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") + if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec): + # 1. Check if the size of long description exceeds 16KB + # 2. Extract the repo URL from the git config, and assign it to the link of the source code of the description entity + if entity.docs and entity.docs.long_description: + if entity.docs.long_description.value: + if sys.getsizeof(entity.docs.long_description.value) > 16 * 1024 * 1024: + raise ValueError( + "Long Description of the flyte entity exceeds the 16KB size limit. Please specify the uri in the long description instead." + ) + entity.docs.source_code = SourceCode(link=settings.git_repo) # This needs to be at the bottom not the top - i.e. dependent tasks get added before the workflow containing it entity_mapping[entity] = cp_entity return cp_entity diff --git a/requirements.txt b/requirements.txt index 5623078a25..22d31976ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with Python 3.7 -# by the following command: +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: # # make requirements.txt # @@ -66,9 +66,7 @@ idna==3.4 # via requests importlib-metadata==5.1.0 # via - # click # flytekit - # jsonschema # keyring jaraco-classes==3.2.3 # via keyring @@ -108,7 +106,6 @@ natsort==8.2.0 numpy==1.21.6 # via # -r requirements.in - # flytekit # pandas # pyarrow packaging==21.3 @@ -194,10 +191,7 @@ types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 # via - # arrow # flytekit - # importlib-metadata - # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json diff --git a/setup.py b/setup.py index 3d9710004f..b42f8a8490 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ # TODO: We should remove mentions to the deprecated numpy # aliases. More details in https://github.com/flyteorg/flyte/issues/3166 "numpy<1.24.0", + "gitpython", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 98af80638a..6e68c9d4be 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -207,7 +207,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 376 + assert len(tp) == 388 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 442851a8a2..db05de0ddb 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -4,6 +4,7 @@ from typing_extensions import Annotated # type: ignore +from flytekit import task from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.interface import ( @@ -320,3 +321,20 @@ def z(a: Foo) -> Foo: assert params.parameters["a"].default is None assert our_interface.outputs["o0"].__origin__ == FlytePickle assert our_interface.inputs["a"].__origin__ == FlytePickle + + +def test_doc_string(): + @task + def t1(a: int) -> int: + """Set the temperature value. + + The value of the temp parameter is stored as a value in + the class variable temperature. + """ + return a + + assert t1.docs.short_description == "Set the temperature value." + assert ( + t1.docs.long_description.value + == "The value of the temp parameter is stored as a value in\nthe class variable temperature." + ) diff --git a/tests/flytekit/unit/models/test_documentation.py b/tests/flytekit/unit/models/test_documentation.py new file mode 100644 index 0000000000..7702df0452 --- /dev/null +++ b/tests/flytekit/unit/models/test_documentation.py @@ -0,0 +1,29 @@ +from flytekit.models.documentation import Description, Documentation, SourceCode + + +def test_long_description(): + value = "long" + icon_link = "http://icon" + obj = Description(value=value, icon_link=icon_link) + assert Description.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.value == value + assert obj.icon_link == icon_link + assert obj.format == Description.DescriptionFormat.RST + + +def test_source_code(): + link = "https://github.com/flyteorg/flytekit" + obj = SourceCode(link=link) + assert SourceCode.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.link == link + + +def test_documentation(): + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + obj = Documentation(short_description=short_description, long_description=long_description, source_code=source_code) + assert Documentation.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.short_description == short_description + assert obj.long_description == long_description + assert obj.source_code == source_code diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index fcebf465f9..fed32b63aa 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -7,6 +7,7 @@ import flytekit.models.interface as interface_models import flytekit.models.literals as literal_models +from flytekit import Description, Documentation, SourceCode from flytekit.models import literals, task, types from flytekit.models.core import identifier from tests.flytekit.common import parameterizers @@ -123,6 +124,60 @@ def test_task_template(in_tuple): assert obj.config == {"a": "b"} +def test_task_spec(): + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + ) + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + {"a": interface_models.Variable(int_type, "description1")}, + { + "b": interface_models.Variable(int_type, "description2"), + "c": interface_models.Variable(int_type, "description3"), + }, + ) + + resource = [task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1")] + resources = task.Resources(resource, resource) + + template = task.TaskTemplate( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), + "python", + task_metadata, + interfaces, + {"a": 1, "b": {"c": 2, "d": 3}}, + container=task.Container( + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, + ), + config={"a": "b"}, + ) + + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + docs = Documentation( + short_description=short_description, long_description=long_description, source_code=source_code + ) + + obj = task.TaskSpec(template, docs) + assert task.TaskSpec.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.docs == docs + assert obj.template == template + + def test_task_template__k8s_pod_target(): int_type = types.LiteralType(types.SimpleType.INTEGER) obj = task.TaskTemplate( diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 3a42f5af81..d229d0d5c9 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -5,8 +5,10 @@ from flytekit.models import task as _task from flytekit.models import types as _types from flytekit.models import workflow_closure as _workflow_closure +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _workflow +from flytekit.models.documentation import Description, Documentation, SourceCode def test_workflow_closure(): @@ -81,3 +83,16 @@ def test_workflow_closure(): obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 + + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + docs = Documentation( + short_description=short_description, long_description=long_description, source_code=source_code + ) + + workflow_spec = WorkflowSpec(template=template, sub_workflows=[], docs=docs) + assert WorkflowSpec.from_flyte_idl(workflow_spec.to_flyte_idl()) == workflow_spec + assert workflow_spec.docs.short_description == short_description + assert workflow_spec.docs.long_description == long_description + assert workflow_spec.docs.source_code == source_code