diff --git a/plugins/flytekit-kf-mpi/README.md b/plugins/flytekit-kf-mpi/README.md index 35c9444c421..f083f2f0a26 100644 --- a/plugins/flytekit-kf-mpi/README.md +++ b/plugins/flytekit-kf-mpi/README.md @@ -8,4 +8,67 @@ To install the plugin, run the following command: pip install flytekitplugins-kfmpi ``` -_Example coming soon!_ +## Code Example +MPI usage: +```python + @task( + task_config=MPIJob( + launcher=Launcher( + replicas=1, + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + ), + slots=2, + ), + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x +``` + + +Horovod Usage: +You can override the command of a replica group by: +```python + @task( + task_config=HorovodJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + worker=Worker( + replicas=1, + command=["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + restart_policy=RestartPolicy.NEVER, + ), + slots=2, + verbose=False, + log_level="INFO", + ), + ) + def my_horovod_task(): + ... +``` + + + + +## Upgrade MPI Plugin from V0 to V1 +MPI plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. 
Update your code from: +``` + task_config=MPIJob(num_workers=10), +``` +to +``` + task_config=MPIJob(worker=Worker(replicas=10)), +``` \ No newline at end of file diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index df5c74288ef..873d3d47bd2 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import HorovodJob, MPIJob +from .task import HorovodJob, MPIJob, Worker, Launcher, CleanPodPolicy, RunPolicy, RestartPolicy diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index e1c1be0a032..b399adace66 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -2,62 +2,82 @@ This Plugin adds the capability of running distributed MPI training to Flyte using backend plugins, natively on Kubernetes. It leverages `MPI Job `_ Plugin from kubeflow. 
""" -from dataclasses import dataclass -from typing import Any, Callable, Dict, List +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, Optional, Union, List -from flyteidl.plugins import mpi_pb2 as _mpi_task +from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.configuration import SerializationSettings from flytekit.extend import TaskPlugins -from flytekit.models import common as _common +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common - -class MPIJobModel(_common.FlyteIdlEntity): - """Model definition for MPI the plugin - - Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). - - num_launcher_replicas: Number of launcher server replicas to use - - slots: Number of slots per worker used in hostfile - .. note:: - - Please use resources=Resources(cpu="1"...) to specify per worker resource +@dataclass +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted """ - def __init__(self, num_workers, num_launcher_replicas, slots): - self._num_workers = num_workers - self._num_launcher_replicas = num_launcher_replicas - self._slots = slots - - @property - def num_workers(self): - return self._num_workers + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER - @property - def num_launcher_replicas(self): - return self._num_launcher_replicas +@dataclass +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. 
+ """ - @property - def slots(self): - return self._slots + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING - def to_flyte_idl(self): - return _mpi_task.DistributedMPITrainingTask( - num_workers=self.num_workers, num_launcher_replicas=self.num_launcher_replicas, slots=self.slots - ) +@dataclass +class RunPolicy: + """ + RunPolicy describes some policy to apply to the execution of a kubeflow job. + Args: + clean_pod_policy: Defines the policy for cleaning up pods after the PyTorchJob completes. Default to None. + ttl_seconds_after_finished (int): Defines the TTL for cleaning up finished PyTorchJobs. + active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job. + can remain active before it is terminated. Must be a positive integer. This setting applies only to pods. + where restartPolicy is OnFailure or Always. + backoff_limit (int): Number of retries before marking this job as failed. + """ + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - num_workers=pb2_object.num_workers, - num_launcher_replicas=pb2_object.num_launcher_replicas, - slots=pb2_object.slots, - ) +@dataclass +class Worker: + """ + Worker replica configuration. Worker command can be customized. If not specified, the worker will use + default command generated by the mpi operator. + """ + command : Optional[List[str]] = None + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = 1 + restart_policy: Optional[RestartPolicy] = None + + +@dataclass +class Launcher: + """ + Launcher replica configuration. Launcher command can be customized. 
If not specified, the launcher will use + the command specified in the task signature. + """ + command : Optional[List[str]] = None + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = 1 + restart_policy: Optional[RestartPolicy] = None @dataclass @@ -67,18 +87,20 @@ class MPIJob(object): to run distributed training on k8s with MPI Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). - - num_launcher_replicas: Number of launcher server replicas to use - - slots: Number of slots per worker used in hostfile - + launcher: Configuration for the launcher replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + slots: The number of slots per worker used in the hostfile. + num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. 
+    num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job
     """
-
-    slots: int
-    num_launcher_replicas: int = 1
-    num_workers: int = 1
+    launcher: Launcher = field(default_factory=lambda: Launcher())
+    worker: Worker = field(default_factory=lambda: Worker())
+    run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
+    slots: int = 1
+    # Support v0 config for backwards compatibility
+    num_launcher_replicas: Optional[int] = None
+    num_workers: Optional[int] = None
 
 
 class MPIFunctionTask(PythonFunctionTask[MPIJob]):
@@ -116,29 +138,86 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs):
             task_type=self._MPI_JOB_TASK_TYPE,
             **kwargs,
         )
+
+    def _convert_replica_spec(self, replica_config: Union[Launcher, Worker]) -> mpi_task.DistributedMPITrainingReplicaSpec:
+        resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
+        return mpi_task.DistributedMPITrainingReplicaSpec(
+            command=replica_config.command,
+            replicas=replica_config.replicas,
+            image=replica_config.image,
+            resources=resources.to_flyte_idl() if resources else None,
+            restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
+        )
+
+    def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
+        return kubeflow_common.RunPolicy(
+            clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
+            ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
+            active_deadline_seconds=run_policy.active_deadline_seconds,
+            backoff_limit=run_policy.backoff_limit,
+        )
+
+    def _get_base_command(self, settings: SerializationSettings) -> List[str]:
+        return super().get_command(settings)
 
     def get_command(self, settings: SerializationSettings) -> List[str]:
-        cmd = super().get_command(settings)
-        num_procs = self.task_config.num_workers * self.task_config.slots
+        cmd = 
self._get_base_command(settings) + if self.task_config.num_workers: + num_workers = self.task_config.num_workers + else: + num_workers = self.task_config.worker.replicas + num_procs = num_workers * self.task_config.slots mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", settings.entrypoint_settings.path] + cmd # the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile return mpi_cmd def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = MPIJobModel( - num_workers=self.task_config.num_workers, - num_launcher_replicas=self.task_config.num_launcher_replicas, + worker = self._convert_replica_spec(self.task_config.worker) + if (self.task_config.num_workers): + worker.replicas = self.task_config.num_workers + + launcher = self._convert_replica_spec(self.task_config.launcher) + if (self.task_config.num_launcher_replicas): + launcher.replicas = self.task_config.num_launcher_replicas + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + mpi_job = mpi_task.DistributedMPITrainingTask( + worker_replicas=worker, + launcher_replicas=launcher, slots=self.task_config.slots, + run_policy=run_policy, ) - return MessageToDict(job.to_flyte_idl()) + return MessageToDict(mpi_job) @dataclass class HorovodJob(object): - slots: int - num_launcher_replicas: int = 1 - num_workers: int = 1 + """ + Configuration for an executable `Horovod Job using MPI operator`_. Use this + to run distributed training on k8s with MPI. For more info, check out Running Horovod`_. + Args: + worker: Worker configuration for the job. + launcher: Launcher configuration for the job. + run_policy: Configuration for the run policy. + slots: Number of slots per worker used in hostfile (default: 1). + verbose: Optional flag indicating whether to enable verbose logging (default: False). + log_level: Optional string specifying the log level (default: "INFO"). 
+ discovery_script_path: Path to the discovery script used for host discovery (default: "/etc/mpi/discover_hosts.sh"). + num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. Please use launcher.replicas instead. + num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job. Please use worker.replicas instead. + """ + + worker: Worker = field(default_factory=lambda: Worker()) + launcher: Launcher = field(default_factory=lambda: Launcher()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + slots: int = 1 + verbose: Optional[bool] = False + log_level: Optional[str] = "INFO" + discovery_script_path: Optional[str] = "/etc/mpi/discover_hosts.sh" + # Support v0 config for backwards compatibility + num_launcher_replicas: Optional[int] = None + num_workers: Optional[int] = None class HorovodFunctionTask(MPIFunctionTask): """ @@ -146,11 +225,8 @@ class HorovodFunctionTask(MPIFunctionTask): """ # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. 
- ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" - discovery_script_path = "/etc/mpi/discover_hosts.sh" def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): - super().__init__( task_config=task_config, task_function=task_function, @@ -158,23 +234,21 @@ def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): ) def get_command(self, settings: SerializationSettings) -> List[str]: - cmd = super().get_command(settings) + cmd = self._get_base_command(settings) mpi_cmd = self._get_horovod_prefix() + cmd return mpi_cmd - def get_config(self, settings: SerializationSettings) -> Dict[str, str]: - config = super().get_config(settings) - return {**config, "worker_spec_command": self.ssh_command} - def _get_horovod_prefix(self) -> List[str]: - np = self.task_config.num_workers * self.task_config.slots + np = self.task_config.worker.replicas * self.task_config.slots + verbose = "--verbose" if self.task_config.verbose is True else "" + log_level = self.task_config.log_level base_cmd = [ "horovodrun", "-np", f"{np}", - "--verbose", + f"{verbose}", "--log-level", - "INFO", + f"{log_level}", "--network-interface", "eth0", "--min-np", @@ -184,7 +258,7 @@ def _get_horovod_prefix(self) -> List[str]: "--slots-per-host", f"{self.task_config.slots}", "--host-discovery-script", - self.discovery_script_path, + self.task_config.discovery_script_path, ] return base_cmd diff --git a/plugins/flytekit-kf-mpi/requirements.txt b/plugins/flytekit-kf-mpi/requirements.txt index f96f963645e..664842e1c14 100644 --- a/plugins/flytekit-kf-mpi/requirements.txt +++ b/plugins/flytekit-kf-mpi/requirements.txt @@ -1,27 +1,67 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfmpi # via -r requirements.in +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs 
+aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot charset-normalizer==3.1.0 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 @@ -29,11 +69,16 @@ cookiecutter==2.1.1 croniter==1.3.8 # via flytekit cryptography==39.0.2 - # via pyopenssl + # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -44,20 +89,53 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.14 +flyteidl==1.5.5 + # via flytekit +flytekit==1.6.1 + # via flytekitplugins-kfmpi +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.5.0 # via + # adlfs # flytekit - # flytekitplugins-kfmpi -flytekit==1.3.1 - # via flytekitplugins-kfmpi + # gcsfs + # s3fs +gcsfs==2023.5.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.18.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes 
+google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.9.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage googleapis-common-protos==1.59.0 # via # flyteidl # flytekit + # google-api-core # grpcio-status grpcio==1.51.3 # via @@ -66,11 +144,15 @@ grpcio==1.51.3 grpcio-status==1.51.3 # via flytekit idna==3.4 - # via requests + # via + # requests + # yarl importlib-metadata==6.1.0 # via # flytekit # keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -79,10 +161,16 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit +markdown-it-py==2.2.0 + # via rich markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -94,8 +182,21 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mdurl==0.1.2 + # via markdown-it-py more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.3.1 @@ -105,33 +206,48 @@ numpy==1.23.5 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==23.0 # via # docker # marshmallow pandas==1.5.3 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==4.22.1 # via # flyteidl + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi +pygments==2.15.1 + # via rich 
+pyjwt[crypto]==2.7.0 + # via msal pyopenssl==23.0.0 # via flytekit python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -147,21 +263,48 @@ pyyaml==6.0 # via # cookiecutter # flytekit + # kubernetes # responses regex==2022.10.31 # via docker-image-py requests==2.28.2 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.3.5 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.5.0 # via flytekit six==1.16.0 - # via python-dateutil + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil smmap==5.0.0 # via gitdb sortedcontainers==2.4.0 @@ -174,23 +317,37 @@ types-pyyaml==6.0.12.8 # via responses typing-extensions==4.5.0 # via + # aioitertools + # azure-core + # azure-storage-blob # flytekit # typing-inspect typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.15 # via + # botocore # docker # flytekit + # google-auth + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker + # via + # docker + # kubernetes wheel==0.40.0 # via flytekit wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit +yarl==1.9.2 + # via aiohttp zipp==3.15.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index 566506069c9..5909adb91b3 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "flyteidl>=0.21.4"] 
+plugin_requires = ["flytekit>=1.6.1,<2.0.0"] __version__ = "0.0.0+develop" @@ -18,7 +18,7 @@ packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, license="apache2", - python_requires=">=3.6", + python_requires=">=3.9", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 7732d520c20..d0976bf7af2 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,22 +1,24 @@ -from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel +import pytest +from flytekitplugins.kfmpi import HorovodJob, MPIJob, Worker, Launcher, RunPolicy, RestartPolicy, CleanPodPolicy +from flytekitplugins.kfmpi.task import MPIFunctionTask from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings - -def test_mpi_model_task(): - job = MPIJobModel( - num_workers=1, - num_launcher_replicas=1, - slots=1, +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), ) - assert job.num_workers == 1 - assert job.num_launcher_replicas == 1 - assert job.slots == 1 - assert job.from_flyte_idl(job.to_flyte_idl()) + return settings -def test_mpi_task(): +def test_mpi_task(serialization_settings: SerializationSettings): @task( task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1), requests=Resources(cpu="1"), @@ -31,36 +33,155 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_config is not None default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - 
domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), + + assert my_mpi_task.get_custom(serialization_settings) == {'launcherReplicas': {'replicas': 10, 'resources': {}}, 'workerReplicas': {'replicas': 10, 'resources': {}}, "slots": 1} + assert my_mpi_task.task_type == "mpi" + + +def test_pytorch_task_with_default_config(serialization_settings: SerializationSettings): + task_config = MPIJob() + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x + + assert my_mpi_task(x=10, y="hello") == 10 + + assert my_mpi_task.task_config is not None + assert my_mpi_task.task_type == "mpi" + assert my_mpi_task.resources.limits == Resources() + assert my_mpi_task.resources.requests == Resources(cpu="1") + assert ' '.join(my_mpi_task.get_command(serialization_settings)).startswith(' '.join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"])) + + expected_dict = { + "launcherReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "slots": 1, + } + assert my_mpi_task.get_custom(serialization_settings) == expected_dict + + +def test_pytorch_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = MPIJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + image="launcher:latest", + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.NEVER, + ), + run_policy=RunPolicy( + clean_pod_policy=CleanPodPolicy.ALL, + ), + slots=2, ) - assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def 
my_mpi_task(x: int, y: str) -> int: + return x + + assert my_mpi_task(x=10, y="hello") == 10 + + assert my_mpi_task.task_config is not None assert my_mpi_task.task_type == "mpi" + assert my_mpi_task.resources.limits == Resources() + assert my_mpi_task.resources.requests == Resources(cpu="1") + assert ' '.join(my_mpi_task.get_command(serialization_settings)).startswith(' '.join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"])) + expected_custom_dict = { + "launcherReplicas": { + "replicas": 1, + "image": "launcher:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + }, + "slots": 2, + "runPolicy": {'cleanPodPolicy': 'CLEANPOD_POLICY_ALL'}, + } + assert my_mpi_task.get_custom(serialization_settings) == expected_custom_dict -def test_horovod_task(): + +def test_horovod_task(serialization_settings): @task( - task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), + task_config=HorovodJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + worker=Worker( + replicas=1, + command=["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + restart_policy=RestartPolicy.NEVER, + ), + slots=2, + verbose=False, + log_level="INFO", + ), ) def my_horovod_task(): ... 
- default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) - cmd = my_horovod_task.get_command(settings) - assert "horovodrun" in cmd - config = my_horovod_task.get_config(settings) - assert "/usr/sbin/sshd" in config["worker_spec_command"] - custom = my_horovod_task.get_custom(settings) - assert isinstance(custom, dict) is True + cmd = my_horovod_task.get_command(serialization_settings) + assert "horovodrun" in cmd + assert "--verbose" not in cmd + assert "--log-level" in cmd + assert "INFO" in cmd + expected_dict = { + "launcherReplicas": { + "replicas": 1, + "resources": { + "requests": [ + {"name": "CPU", "value": "1"}, + ], + "limits": [ + {"name": "CPU", "value": "2"}, + ], + }, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + }, + "slots": 2, + } + assert my_horovod_task.get_custom(serialization_settings) == expected_dict diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 7de27502bf9..c1516d32483 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -14,3 +14,42 @@ pip install flytekitplugins-kfpytorch To set up PyTorch operator in the Flyte deployment's backend, follow the [PyTorch Operator Setup](https://docs.flyte.org/en/latest/deployment/plugin_setup/pytorch_operator.html) guide. An [example](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/kfpytorch/pytorch_mnist.html#sphx-glr-auto-integrations-kubernetes-kfpytorch-pytorch-mnist-py) showcasing PyTorch operator can be found in the documentation. 
+ +## Code Example +```python +from flytekitplugins.kfpytorch import PyTorch, Worker, Master, RestartPolicy, RunPolicy, CleanPodPolicy + +@task( + task_config = PyTorch( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + master=Master( + restart_policy=RestartPolicy.ALWAYS, + ), + ) + image="test_image", + resources=Resources(cpu="1", mem="1Gi"), +) +def pytorch_job(): + ... +``` + + +## Upgrade Pytorch Plugin from V0 to V1 +Pytorch plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. Update your code from: + ``` + task_config=Pytorch(num_workers=10), + ``` + to: + ``` + task_config=PyTorch(worker=Worker(replicas=10)), + ``` diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index cb9add7302f..66e8a5c2fd2 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -11,4 +11,4 @@ Elastic """ -from .task import Elastic, PyTorch +from .task import Elastic, PyTorch, Worker, Master, CleanPodPolicy, RunPolicy, RestartPolicy diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index a870227d7da..206a5fbd322 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -81,15 +81,14 @@ class Master: @dataclass class PyTorch(object): """ - Configuration for an executable `Pytorch Job `_. Use this - to run distributed pytorch training on k8s + Configuration for an executable `PyTorch Job `_. 
Use this + to run distributed PyTorch training on Kubernetes. Args: - master: Configuration for master replica group. - worker: Configuration for worker replica group. - run_policy: Configuration for run policy. - num_workers: This is deprecated. Use worker.replicas instead. - + master: Configuration for the master replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. """ master: Master = field(default_factory=lambda: Master()) worker: Worker = field(default_factory=lambda: Worker()) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 46973017acf..a9501d0ee6d 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -63,6 +63,7 @@ click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via # flytekit @@ -93,11 +94,9 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.19 - # via - # flytekit - # flytekitplugins-kfpytorch -flytekit==1.5.0 +flyteidl==1.5.5 + # via flytekit +flytekit==1.6.1 # via flytekitplugins-kfpytorch frozenlist==1.3.3 # via @@ -175,6 +174,8 @@ keyring==23.13.1 # via flytekit kubernetes==26.1.0 # via flytekit +markdown-it-py==2.2.0 + # via rich markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -186,6 +187,8 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mdurl==0.1.2 + # via markdown-it-py more-itertools==9.1.0 # via jaraco-classes msal==1.22.0 @@ -236,6 +239,8 @@ pyasn1-modules==0.3.0 # via google-auth pycparser==2.21 # via cffi +pygments==2.15.1 + # via rich pyjwt[crypto]==2.6.0 # via # adal @@ -290,6 +295,12 @@ requests-oauthlib==1.3.1 # kubernetes responses==0.23.1 # via flytekit +rich==13.3.5 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit 
rsa==4.9 # via google-auth s3fs==2023.4.0 diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index b078b6ce6c1..443314aa697 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["cloudpickle", "flytekit>=1.6.1", "flyteidl>=1.5.5"] +plugin_requires = ["cloudpickle", "flytekit>=1.6.1"] __version__ = "0.0.0+develop" @@ -21,7 +21,7 @@ "elastic": ["torch>=1.9.0"], }, license="apache2", - python_requires=">=3.8", + python_requires=">=3.9", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 920bb8e9fb9..78661963a10 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -1,6 +1,6 @@ import pytest -from flytekitplugins.kfpytorch.task import PyTorch, Worker, Master, RestartPolicy, RunPolicy, CleanPodPolicy +from flytekitplugins.kfpytorch.task import PyTorch, Worker, Master, RestartPolicy from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings diff --git a/plugins/flytekit-kf-tensorflow/README.md b/plugins/flytekit-kf-tensorflow/README.md index d2568e22740..bc9d886659b 100644 --- a/plugins/flytekit-kf-tensorflow/README.md +++ b/plugins/flytekit-kf-tensorflow/README.md @@ -39,11 +39,11 @@ def tf_job(): ``` -## Upgrade TensorFlow Plugin from to V1 +## Upgrade TensorFlow Plugin from V0 to V1 Tensorflow plugin is now updated from v0 to v1 to enable more configuration options. To migrate from v0 to v1, change the following: 1. Update flytepropeller to v1.6.0 -2. Update flytekit version to v1.6.1 +2. Update flytekit version to v1.6.2 3. 
Update your code from: ``` task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 57064c6e518..782b3955d42 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -38,14 +38,15 @@ class CleanPodPolicy(Enum): @dataclass class RunPolicy: """ - RunPolicy describes some policy to apply to the execution of a kubeflow job. + RunPolicy describes a set of policies to apply to the execution of a Kubeflow job. + Args: - clean_pod_policy (int): Defines the policy for cleaning up pods after the PyTorchJob completes. Default to None. - ttl_seconds_after_finished (int): Defines the TTL for cleaning up finished PyTorchJobs. - active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job. - can remain active before it is terminated. Must be a positive integer. This setting applies only to pods. - where restartPolicy is OnFailure or Always. - backoff_limit (int): Number of retries before marking this job as failed. + clean_pod_policy: The policy for cleaning up pods after the job completes. Defaults to None. + ttl_seconds_after_finished: The time-to-live (TTL) in seconds for cleaning up finished jobs. + active_deadline_seconds: The duration (in seconds) since startTime during which the job can remain + active before it is terminated. Must be a positive integer. This setting applies only to pods + where restartPolicy is OnFailure or Always. + backoff_limit: The number of retries before marking this job as failed. """ clean_pod_policy: CleanPodPolicy = None ttl_seconds_after_finished: Optional[int] = None @@ -82,6 +83,19 @@ class Worker: @dataclass class TfJob: + """ + Configuration for an executable `TensorFlow Job `_. 
Use this + to run distributed TensorFlow training on Kubernetes. + + Args: + chief: Configuration for the chief replica group. + ps: Configuration for the parameter server (PS) replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. + num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead. + num_chief_replicas: [DEPRECATED] This argument is deprecated. Use `chief.replicas` instead. + """ chief: Chief = field(default_factory=lambda: Chief()) ps: PS = field(default_factory=lambda: PS()) worker: Worker = field(default_factory=lambda: Worker()) @@ -120,7 +134,7 @@ def _convert_replica_spec(self, replica_config: Union[Chief, PS, Worker]) -> ten def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: return kubeflow_common.RunPolicy( - clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, + clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, active_deadline_seconds=run_policy.active_deadline_seconds, backoff_limit=run_policy.active_deadline_seconds, diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index 064ffd73d59..91d261f06ae 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -1,27 +1,67 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kftensorflow # via -r requirements.in +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs 
+aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot charset-normalizer==3.1.0 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 @@ -29,11 +69,16 @@ cookiecutter==2.1.1 croniter==1.3.8 # via flytekit cryptography==39.0.2 - # via pyopenssl + # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -44,18 +89,53 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.14 +flyteidl==1.5.5 # via flytekit -flytekit==1.3.1 +flytekit==1.6.1 # via flytekitplugins-kftensorflow +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.5.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.5.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.18.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.9.0 + # via 
gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage googleapis-common-protos==1.59.0 # via # flyteidl # flytekit + # google-api-core # grpcio-status grpcio==1.51.3 # via @@ -64,11 +144,15 @@ grpcio==1.51.3 grpcio-status==1.51.3 # via flytekit idna==3.4 - # via requests + # via + # requests + # yarl importlib-metadata==6.1.0 # via # flytekit # keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -77,10 +161,16 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit +markdown-it-py==2.2.0 + # via rich markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -92,8 +182,21 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mdurl==0.1.2 + # via markdown-it-py more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.3.1 @@ -103,33 +206,48 @@ numpy==1.23.5 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==23.0 # via # docker # marshmallow pandas==1.5.3 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==4.22.1 # via # flyteidl + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.7.0 + # via msal pyopenssl==23.0.0 # via flytekit python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes 
# pandas python-json-logger==2.0.7 # via flytekit @@ -145,21 +263,48 @@ pyyaml==6.0 # via # cookiecutter # flytekit + # kubernetes # responses regex==2022.10.31 # via docker-image-py requests==2.28.2 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.3.5 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.5.0 # via flytekit six==1.16.0 - # via python-dateutil + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil smmap==5.0.0 # via gitdb sortedcontainers==2.4.0 @@ -172,23 +317,37 @@ types-pyyaml==6.0.12.8 # via responses typing-extensions==4.5.0 # via + # aioitertools + # azure-core + # azure-storage-blob # flytekit # typing-inspect typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.15 # via + # botocore # docker # flytekit + # google-auth + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker + # via + # docker + # kubernetes wheel==0.40.0 # via flytekit wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit +yarl==1.9.2 + # via aiohttp zipp==3.15.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index ad19872e62a..bbfb01145bd 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -4,8 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -# TODO: Requirements are missing, add them back in later. 
-plugin_requires = ["flytekit>=1.6.1", "flyteidl>=1.5.5"] +plugin_requires = ["flytekit>=1.6.1"] __version__ = "0.0.0+develop" @@ -19,7 +18,7 @@ packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, license="apache2", - python_requires=">=3.8", + python_requires=">=3.9", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers",