Skip to content

Commit

Permalink
Revisiting local interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
SebS94 committed Dec 22, 2023
1 parent 1124fd4 commit ce7cb3a
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 71 deletions.
1 change: 1 addition & 0 deletions squirrel/artifact_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import squirrel.artifact_manager.drivers # noqa
61 changes: 52 additions & 9 deletions squirrel/artifact_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,50 @@
from types import TracebackType
from typing import Optional, Any, Iterable, Type

from squirrel.catalog import Catalog, Source
from squirrel.catalog import Catalog
from squirrel.catalog.catalog import CatalogSource

logger = logging.getLogger(__name__)


class TmpArtifact:
    """
    Context manager that temporarily downloads an artifact and deletes it again on exit.

    When entering the scope, the artifact is downloaded into a temporary directory and the
    source returned by the artifact manager is handed to the caller together with the local
    path of the artifact folder. When leaving the scope, the temporary directory is removed.
    """

    def __init__(self, artifact_manager: "ArtifactManager", collection: str, artifact: str, version: str) -> None:
        """
        Initializes the TmpArtifact.

        Args:
            artifact_manager: An artifact manager instance to use for downloading the artifact.
            collection: Name of the collection passed to the manager's ``download_artifact``.
            artifact: Name of the artifact to download.
            version: Version of the artifact to download.
        """
        self.artifact_manager = artifact_manager
        self.tempdir = tempfile.TemporaryDirectory()
        self.collection = collection
        self.artifact = artifact
        self.version = version

    def __enter__(self) -> "tuple[CatalogSource, Path]":
        """
        Called when entering the context. Downloads the artifact into the temporary directory.

        Returns:
            A tuple of the source returned by the artifact manager and the path
            ``<tempdir>/<artifact>`` of the downloaded artifact folder.
        """
        target = Path(self.tempdir.name)
        try:
            source, _ = self.artifact_manager.download_artifact(
                self.artifact, self.collection, self.version, target
            )
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so remove the temporary
            # directory here instead of leaving it behind until garbage collection.
            self.tempdir.cleanup()
            raise
        return source, target / self.artifact

    def __exit__(
        self,
        exctype: Optional[Type[BaseException]] = None,
        excinst: Optional[BaseException] = None,
        exctb: Optional[TracebackType] = None,
    ) -> None:
        """Called when exiting the context. Deletes the temporary directory holding the artifact."""
        if self.tempdir is not None:
            self.tempdir.cleanup()


class DirectoryLogger:
"""
Class to be used as a context for logging a directory as an artifact.
Expand Down Expand Up @@ -61,9 +100,9 @@ def __exit__(


class ArtifactManager(ABC):
def __init__(self):
def __init__(self, active_collection: str = "default"):
"""Artifact manager interface for various backends."""
self._active_collection = "default"
self._active_collection = active_collection

@property
def active_collection(self) -> str:
Expand Down Expand Up @@ -121,7 +160,7 @@ def log_files(
local_path: Path,
collection: Optional[str] = None,
artifact_path: Optional[Path] = None,
) -> Source:
) -> CatalogSource:
"""
Upload a file or folder into (current) collection, increment version automatically
Expand All @@ -134,7 +173,7 @@ def log_files(
raise NotImplementedError

@abstractmethod
def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) -> Source:
def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) -> CatalogSource:
"""
Log an arbitrary python object
Expand All @@ -145,9 +184,13 @@ def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) ->

@abstractmethod
def download_artifact(
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Path = "./"
) -> Source:
"""Retrieve file (from current collection) to specific location. Retrieve latest version unless specified."""
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Optional[Path] = None
) -> tuple[CatalogSource, Path]:
"""
Download artifact contents (from current collection) to specific location and return a source listing them.
If no target location is specified, a temporary directory is created and the path to it is returned.
Retrieve latest version unless specified."""
raise NotImplementedError

def exists(self, artifact: str) -> bool:
Expand All @@ -172,5 +215,5 @@ def download_collection(self, collection: Optional[str] = None, to: Path = "./")
catalog = self.collection_to_catalog(collection)
for _, artifact in catalog:
artifact_name = artifact.metadata["artifact"]
self.download_artifact(artifact_name, collection, to=to / artifact_name)
self.download_artifact(artifact_name, collection, to=to)
return catalog
31 changes: 31 additions & 0 deletions squirrel/artifact_manager/drivers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Any, Optional

from squirrel.driver.store import StoreDriver
from squirrel.framework.plugins.plugin_manager import register_driver
from squirrel.store import AbstractStore, FilesystemStore


class DirectoryDriver(StoreDriver):
    """Store driver registered under the name "directory" that always uses a FilesystemStore."""

    name = "directory"

    def __init__(self, url: str, storage_options: Optional[dict[str, Any]] = None, **kwargs) -> None:
        """Initializes DirectoryDriver.

        Args:
            url (str): the url of the store
            storage_options (Optional[dict[str, Any]]): forwarded to the super class initializer.
            **kwargs: Keyword arguments to pass to the super class initializer.

        Raises:
            ValueError: If a `store` keyword argument is supplied.
        """
        # The store is built by this driver itself, so callers must not inject one.
        if "store" in kwargs:
            raise ValueError("Store of DirectoryDriver is fixed, `store` cannot be provided.")

        super().__init__(url=url, serializer=None, storage_options=storage_options, **kwargs)

    @property
    def store(self) -> AbstractStore:
        """Store that is used by the driver; created on first access and reused afterwards."""
        if self._store is not None:
            return self._store
        self._store = FilesystemStore(url=self.url, serializer=self.serializer, **self.storage_options)
        return self._store


# Register DirectoryDriver with the plugin manager when this module is imported.
register_driver(DirectoryDriver)
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path
from typing import Optional, Any, List, Iterable
from typing import Optional, Any, List, Iterable, Union

from squirrel.artifact_manager.base import ArtifactManager
from squirrel.catalog import Catalog, Source
from squirrel.artifact_manager.base import ArtifactManager, TmpArtifact
from squirrel.catalog import Catalog
from squirrel.catalog.catalog import CatalogSource, Source
from squirrel.serialization import MessagepackSerializer, JsonSerializer, SquirrelSerializer
from squirrel.store import FilesystemStore
from squirrel.store.filesystem import get_random_key
Expand All @@ -19,7 +20,7 @@ class ArtifactFileStore(FilesystemStore):
The get and set methods are altered to allow for storing serialized data as well as raw files.
If the final path component is a serializer name, the data is stored as a serialized file.
If the final path component is "file", the data is stored as a raw file.
If the final path component is "files", the data is stored as a raw file.
"""

def complete_key(self, partial_key: Path, **open_kwargs) -> List[str]:
Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(self, url: str, serializer: Optional[SquirrelSerializer] = None, **
"""
super().__init__()
if serializer is None:
serializer = MessagepackSerializer()
serializer = JsonSerializer()
self.backend = ArtifactFileStore(url=url, serializer=serializer, **fs_kwargs)

def list_collection_names(self) -> Iterable:
Expand All @@ -104,9 +105,11 @@ def exists_in_collection(self, artifact: str, collection: Optional[str] = None)
return self.backend.key_exists(Path(collection, artifact))

def get_artifact_source(
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None
) -> Source:
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, catalog: Optional[Catalog] = None
) -> CatalogSource:
"""Catalog entry for a specific artifact"""
if catalog is None:
catalog = Catalog()
if collection is None:
collection = self.active_collection
if version is None or version == "latest":
Expand All @@ -115,18 +118,23 @@ def get_artifact_source(
raise ValueError(f"Artifact {artifact} does not exist in collection {collection} with version {version}!")

# TODO: Vary source (driver) description when support for serialised python values is added
return Source(
driver_name="file",
driver_kwargs={
"url": Path(self.backend.url, collection, artifact, version, "file").as_uri(),
"storage_options": self.backend.storage_options,
},
metadata={
"collection": collection,
"artifact": artifact,
"version": version,
"location": Path(self.backend.url, collection, artifact, version).as_uri(),
},
return CatalogSource(
Source(
driver_name="directory",
driver_kwargs={
"url": Path(self.backend.url, collection, artifact, version, "files").as_uri(),
"storage_options": self.backend.storage_options,

},
metadata={
"collection": collection,
"artifact": artifact,
"version": version,
},
),
identifier=str(Path(collection, artifact)),
catalog=catalog,
version=int(version[1:]) + 1, # Squirrel Catalog version is 1-based
)

def collection_to_catalog(self, collection: Optional[str] = None) -> Catalog:
Expand All @@ -136,7 +144,8 @@ def collection_to_catalog(self, collection: Optional[str] = None) -> Catalog:
catalog = Catalog()
for artifact in self.backend.complete_key(Path(collection)):
for version in self.backend.complete_key(Path(collection, artifact)):
catalog[str(Path(collection, artifact))] = self.get_artifact_source(artifact, collection, version)
src = self.get_artifact_source(artifact, collection, version, catalog=catalog)
catalog[str(Path(collection, artifact)), src.version] = src
return catalog

def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) -> Source:
Expand All @@ -161,7 +170,7 @@ def log_files(
local_path: Path,
collection: Optional[str] = None,
artifact_path: Optional[Path] = None,
) -> Source:
) -> CatalogSource:
"""Upload local file or folder to artifact store without serialisation"""
if not isinstance(local_path, (str, Path)):
raise ValueError("Path to file should be passed as a pathlib.Path object!")
Expand All @@ -182,24 +191,35 @@ def log_files(
return self.get_artifact_source(artifact_name, collection)

def download_artifact(
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Path = "./"
) -> Source:
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Optional[Path] = None
) -> Union[tuple[Source, Path], TmpArtifact]:
"""Download artifact to local path."""
if collection is None:
collection = self.active_collection
if version is None or version == "latest":
version = f"v{max(int(vs[1:]) for vs in self.backend.complete_key(Path(collection) / Path(artifact)))}"
if isinstance(to, str):
to = Path(to)
location = Path(collection, artifact, version)
self.backend.get(Path(location), target=to)
return Source(
driver_name="file",
driver_kwargs={"url": str(Path(to, artifact))},
metadata={
"collection": collection,
"artifact": artifact,
"version": version,
"location": str(Path(to, artifact)),
},
)

if to is not None:
if isinstance(to, str):
to = Path(to)
location = Path(collection, artifact, version)
self.backend.get(Path(location), target=to / artifact)
src = CatalogSource(
Source(
driver_name="directory",
driver_kwargs={
"url": (to / artifact).as_uri(),
},
metadata={
"collection": collection,
"artifact": artifact,
"version": version,
},
),
identifier=str(Path(collection, artifact)),
catalog=Catalog(),
version=int(version[1:]) + 1, # Squirrel Catalog version is 1-based
)
return src, to
else:
return TmpArtifact(self, collection, artifact, version)
17 changes: 11 additions & 6 deletions squirrel/artifact_manager/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,23 @@ def __init__(self, entity: Optional[str] = None, project: Optional[str] = None):
self.project = project
elif wandb.run is not None:
self.project = wandb.run.project
logger.info(f"Using project {self.project} from active wandb run.")
elif wandb.Api().settings["project"] is not None:
self.project = wandb.Api().settings["project"]
logger.info(f"Using project {self.project} from wandb settings.")
else:
raise ValueError("No project name was provided and no active project could be identified.")
if entity is not None:
self.entity = entity
elif wandb.run is not None:
elif wandb.run is not None and wandb.run.entity is not None:
self.entity = wandb.run.entity
logger.info(f"Using entity {self.entity} from active wandb run.")
elif wandb.Api().settings["entity"] is not None:
self.entity = wandb.Api().settings["entity"]
logger.info(f"Using entity {self.entity} from wandb settings.")
else:
self.entity = wandb.Api().project(self.project).entity
logger.info(f"Using entity {self.entity} from wandb project.")

def list_collection_names(self) -> Iterable:
"""
Expand Down Expand Up @@ -72,7 +77,7 @@ def get_artifact_source(
if version is None:
versions = [
instance.version
for instance in wandb.Api().artifact_versions(type_name=collection, name=f"{self.project}/{artifact}")
for instance in wandb.Api().artifact_versions(type_name=collection, name=f"{self.entity}/{self.project}/{artifact}")
]
version = f"v{max([int(v[1:]) for v in versions])}"

Expand Down Expand Up @@ -113,7 +118,7 @@ def collection_to_catalog(self, collection: Optional[str] = None) -> Catalog:
]
catalog = Catalog()
for artifact in artifact_names:
for instance in wandb.Api().artifact_versions(type_name=collection, name=f"{self.project}/{artifact}"):
for instance in wandb.Api().artifact_versions(type_name=collection, name=f"{self.entity}/{self.project}/{artifact}"):
catalog[str(Path(collection, artifact))] = self.get_artifact_source(
artifact, collection, instance.version
)
Expand Down Expand Up @@ -168,8 +173,8 @@ def log_files(
return self.get_artifact_source(artifact_name, collection)

def download_artifact(
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Path = "./"
) -> Source:
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Optional[Path] = None
) -> tuple[Source, Path]:
"""
Download a specific artifact to a local path.
Expand All @@ -190,4 +195,4 @@ def download_artifact(
driver_name="file",
driver_kwargs={"url": str(to)},
metadata={"collection": collection, "artifact": artifact, "version": version, "location": str(to)},
)
)
1 change: 1 addition & 0 deletions test/test_artifact_manager/my_collection/folder/bar.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test: Bar
1 change: 1 addition & 0 deletions test/test_artifact_manager/my_collection/folder/baz.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test: Baz
1 change: 1 addition & 0 deletions test/test_artifact_manager/my_collection/folder/foo.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test: Foo
Loading

0 comments on commit ce7cb3a

Please sign in to comment.