Skip to content

Commit

Permalink
Add ImageSpec.from_env (#2895)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan authored Dec 2, 2024
1 parent 47fe660 commit 6cdc4ab
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
32 changes: 32 additions & 0 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import pathlib
import re
import sys
import typing
from abc import abstractmethod
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -332,6 +333,37 @@ def force_push(self) -> "ImageSpec":

return copied_image_spec

@classmethod
def from_env(cls, *, pinned_packages: Optional[List[str]] = None, **kwargs) -> "ImageSpec":
"""Create ImageSpec with the environment's Python version and packages pinned to the ones in the environment."""

from importlib.metadata import version

# Invalid kwargs when using `ImageSpec.from_env`
invalid_kwargs = ["python_version"]
for invalid_kwarg in invalid_kwargs:
if invalid_kwarg in kwargs and kwargs[invalid_kwarg] is not None:
msg = (
f"{invalid_kwarg} can not be used with `from_env` because it will be inferred from the environment"
)
raise ValueError(msg)

version_info = sys.version_info
python_version = f"{version_info.major}.{version_info.minor}"

if "packages" in kwargs:
packages = kwargs.pop("packages")
else:
packages = []

pinned_packages = pinned_packages or []

for package_to_pin in pinned_packages:
package_version = version(package_to_pin)
packages.append(f"{package_to_pin}=={package_version}")

return ImageSpec(packages=packages, python_version=python_version, **kwargs)


class ImageSpecBuilder:
@abstractmethod
Expand Down
31 changes: 31 additions & 0 deletions tests/flytekit/unit/core/image_spec/test_image_spec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from unittest.mock import Mock

import mock
Expand Down Expand Up @@ -244,3 +245,33 @@ def test_registry_name():
]
for valid_registry_name in valid_registry_names:
ImageSpec(registry=valid_registry_name)


def test_image_spec_from_env_error():
msg = "python_version can not be used with `from_env`"
with pytest.raises(ValueError, match=msg):
ImageSpec.from_env(pinned_packages=["joblib"], python_version="3.9")


def test_image_spec_from_env_with_pinned_packages():
import joblib
import msgpack
joblib_version = joblib.__version__
msgpack_version = msgpack.__version__

version_info = sys.version_info
python_version = f"{version_info.major}.{version_info.minor}"

image_spec = ImageSpec.from_env(pinned_packages=["joblib", "msgpack"], packages=["scikit-learn"])
assert image_spec.python_version == python_version
assert f"joblib=={joblib_version}" in image_spec.packages
assert f"msgpack=={msgpack_version}" in image_spec.packages
assert 'scikit-learn' in image_spec.packages


def test_image_spec_from_env_empty():
version_info = sys.version_info
python_version = f"{version_info.major}.{version_info.minor}"

image_spec = ImageSpec.from_env()
assert image_spec.python_version == python_version

0 comments on commit 6cdc4ab

Please sign in to comment.