Skip to content

Commit

Permalink
First wandb version
Browse files Browse the repository at this point in the history
  • Loading branch information
SebS94 committed Dec 15, 2023
1 parent 336f9b7 commit 5af0566
Show file tree
Hide file tree
Showing 6 changed files with 1,687 additions and 1,378 deletions.
2,692 changes: 1,448 additions & 1,244 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ s3fs = {version = ">=2021.7.0", optional = true}
torch = {version = ">=1.13.1", optional = true}
zarr = {version = "^2.10.3", optional = true}
pyarrow = {version = "^10.0.1", optional = true}
wandb = {version = "*", optional = true}

[tool.poetry.group.dev.dependencies]
twine = "^4.0.2"
Expand All @@ -53,7 +54,6 @@ pytest = "^6.2.1"
pytest-timeout = "^2.1.0"
pytest-cov = "^4.0.0"
pytest-xdist = "^3.2.0"
wandb = "^0.13.10"
mlflow = "^2.1.1"
pre-commit = "^2.16.0"
pip-tools = "^6.6.2"
Expand All @@ -76,12 +76,13 @@ dask = ["dask"]
excel = ["odfpy", "openpyxl", "pyxlsb", "xlrd"]
feather = ["pyarrow"]
gcp = ["gcsfs"]
numba = ["numba"]
parquet = ["pyarrow"]
s3 = ["s3fs"]
torch = ["torch"]
wandb = ["wandb"]
zarr = ["zarr"]
all = ["adlfs", "dask", "odfpy", "openpyxl", "pyxlsb", "xlrd", "pyarrow", "gcsfs", "s3fs", "torch", "zarr"]
numba = ["numba"]
all = ["adlfs", "dask", "odfpy", "openpyxl", "pyxlsb", "xlrd", "pyarrow", "gcsfs", "s3fs", "torch", "zarr", "wandb", "numba"]

[build-system]
requires = ["poetry-core"]
Expand Down
53 changes: 33 additions & 20 deletions squirrel/artifact_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,12 @@ def list_collection_names(self) -> Iterable:
"""Return list of all collections in the artifact store"""
raise NotImplementedError

@abstractmethod
def store_to_catalog(self) -> Catalog:
"""Provide a catalog of all stored artifacts."""
raise NotImplementedError

@abstractmethod
def collection_to_catalog(self, collection: Optional[str] = None) -> Catalog:
    """Catalog of all artifacts within a specific collection.

    Abstract: concrete artifact managers must override this.

    Args:
        collection: Name of the collection to catalog. What ``None`` resolves to
            is up to the subclass.

    Raises:
        NotImplementedError: Always, when the base implementation is reached
            (e.g. through ``super()``), matching the other abstract methods.
    """
    # Fail loudly instead of silently returning None from a stray super() call.
    raise NotImplementedError

@abstractmethod
def get_artifact(self, artifact: str, collection: Optional[str] = None, version: Optional[int] = None) -> Any:
def get_artifact(self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None) -> Any:
"""Retrieve specific artifact value."""
raise NotImplementedError

Expand All @@ -70,16 +65,6 @@ def log_file(self, local_path: Path, name: str, collection: Optional[str] = None
"""Upload file into (current) collection, increment version automatically"""
raise NotImplementedError

@abstractmethod
def log_files(self, local_paths: List[Path], collection: Optional[str] = None) -> Catalog:
"""Upload a collection of files into a (current) collection"""
raise NotImplementedError

@abstractmethod
def log_folder(self, files: Path, collection: Optional[str] = None) -> Catalog:
"""Upload folder as collection of artifacts into store"""
raise NotImplementedError

@abstractmethod
def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) -> Source:
"""
Expand All @@ -92,12 +77,40 @@ def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) ->

@abstractmethod
def download_artifact(
self, artifact: str, collection: Optional[str] = None, version: Optional[int] = None, to: Path = "./"
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Path = "./"
) -> Source:
"""Retrieve file (from current collection) to specific location. Retrieve latest version unless specified."""
raise NotImplementedError

def store_to_catalog(self) -> Catalog:
    """Provide a Catalog of all artifacts stored in the backend.

    Concrete default implementation shared by all artifact managers. It is
    intentionally not decorated with ``@abstractmethod``: subclasses inherit it
    as-is and only have to provide ``list_collection_names`` and
    ``collection_to_catalog``.

    Returns:
        A Catalog updated with the result of ``collection_to_catalog`` for every
        name yielded by ``list_collection_names``.
    """
    catalog = Catalog()
    for collection in self.list_collection_names():
        catalog.update(self.collection_to_catalog(collection))
    return catalog

def log_files(self, local_paths: List[Path], collection: Optional[str] = None) -> Catalog:
    """Upload a collection of files into the given (or currently active) collection.

    Every file is logged through ``log_file`` under its own file name.

    Args:
        local_paths: Files to upload. Each entry is coerced with ``Path`` so
            plain string paths are accepted as well.
        collection: Target collection; ``self.active_collection`` when ``None``.

    Returns:
        The result of ``collection_to_catalog`` for the target collection,
        evaluated after all files have been logged.
    """
    if collection is None:
        collection = self.active_collection
    for local_path in local_paths:
        # ``.name`` needs a Path, and the filesystem manager's log_file asserts
        # it is handed a pathlib.Path, so normalise once here.
        local_path = Path(local_path)
        self.log_file(local_path, local_path.name, collection)
    return self.collection_to_catalog(collection)

def log_folder(self, folder: Path, collection: Optional[str] = None) -> Catalog:
    """Log the files of a folder as a collection of artifacts.

    Only regular files directly inside ``folder`` are uploaded; sub-directories
    are skipped (no recursion).

    Args:
        folder: Directory whose files should be logged. Coerced with ``Path`` so
            a plain string is accepted as well.
        collection: Target collection; defaults to the folder's own name.

    Returns:
        The result of ``log_files`` for the selected files.

    Raises:
        ValueError: If ``folder`` is not an existing directory.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise ValueError(f"Path {folder} is not a directory!")

    if collection is None:
        collection = folder.name

    # iterdir() yields entries in arbitrary order; sort so the upload order
    # is reproducible across runs and platforms.
    files = sorted(entry for entry in folder.iterdir() if entry.is_file())
    return self.log_files(files, collection)

def download_collection(self, collection: Optional[str] = None, to: Path = "./") -> Catalog:
    """Download all artifacts in a collection to a local directory.

    Each artifact is fetched with ``download_artifact`` (no explicit version is
    requested) and placed at ``to / <artifact name>``.

    Args:
        collection: Collection to download; passed unchanged to both
            ``collection_to_catalog`` and ``download_artifact``.
        to: Local target directory. May be a ``Path`` or a string.

    Returns:
        The catalog returned by ``collection_to_catalog(collection)``.
    """
    # The default is the plain string "./"; ``str / str`` raises TypeError,
    # so normalise to a Path before joining.
    target_dir = Path(to)
    catalog = self.collection_to_catalog(collection)
    for artifact in catalog.values():
        artifact_name = artifact.metadata["artifact"]
        # Forward the collection so artifacts come from the collection that was
        # just cataloged, not from whatever collection is currently active.
        self.download_artifact(artifact_name, collection=collection, to=target_dir / artifact_name)
    return catalog
77 changes: 22 additions & 55 deletions squirrel/artifact_manager/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,30 @@ def list_collection_names(self) -> Iterable:
"""List all collections managed by this ArtifactManager."""
return self.backend.keys(nested=False)

def get_artifact(self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None) -> Any:
    """Retrieve specific artifact value.

    Args:
        artifact: Name of the artifact.
        collection: Collection to look in; ``self.active_collection`` when ``None``.
        version: Version key of the form ``"v<N>"``. ``None`` or ``"latest"``
            selects the key with the highest numeric suffix.

    Returns:
        Whatever ``self.backend.get`` returns for ``collection/artifact/version``.

    Raises:
        ValueError: If the artifact has no versions, or the requested version
            does not exist in the backend.
    """
    if collection is None:
        collection = self.active_collection
    if version is None or version == "latest":
        # Compare the numeric suffix, not the string, so "v10" beats "v9".
        indices = [int(vs[1:]) for vs in self.backend.complete_key(Path(collection) / Path(artifact))]
        if not indices:
            # Avoid the opaque "max() arg is an empty sequence" error.
            raise ValueError(f"Artifact {artifact} does not exist in collection {collection}!")
        version = f"v{max(indices)}"
    path = Path(collection, artifact, version)
    if not self.backend.key_exists(path):
        raise ValueError(f"Artifact {artifact} does not exist in collection {collection} with version {version}!")
    return self.backend.get(path)

def get_artifact_source(
self, artifact: str, collection: Optional[str] = None, version: Optional[int] = None
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None
) -> Source:
"""Catalog entry for a specific artifact"""
if collection is None:
collection = self.active_collection
if version is None:
version = max(int(vs) for vs in self.backend.complete_key(Path(collection) / Path(artifact)))
if not self.backend.key_exists(Path(collection, artifact, str(version))):
if version is None or version == "latest":
version = f"v{max(int(vs[1:]) for vs in self.backend.complete_key(Path(collection) / Path(artifact)))}"
if not self.backend.key_exists(Path(collection, artifact, version)):
raise ValueError(f"Artifact {artifact} does not exist in collection {collection} with version {version}!")

if Serializers[self.backend.serializer.__class__] in self.backend.complete_key(
Path(collection, artifact, str(version))
Path(collection, artifact, version)
):
return Source(
driver_name=Serializers[self.backend.serializer.__class__],
Expand All @@ -107,7 +107,7 @@ def get_artifact_source(
self.backend.url,
collection,
artifact,
str(version),
version,
Serializers[self.backend.serializer.__class__],
).as_uri(),
"storage_options": self.backend.storage_options,
Expand All @@ -120,23 +120,23 @@ def get_artifact_source(
self.backend.url,
collection,
artifact,
str(version),
version,
Serializers[self.backend.serializer.__class__],
).as_uri(),
},
)
elif self.backend.key_exists(Path(collection, artifact, str(version), "file")):
elif self.backend.key_exists(Path(collection, artifact, version, "file")):
return Source(
driver_name="file",
driver_kwargs={
"url": Path(self.backend.url, collection, artifact, str(version), "file").as_uri(),
"url": Path(self.backend.url, collection, artifact, version, "file").as_uri(),
"storage_options": self.backend.storage_options,
},
metadata={
"collection": collection,
"artifact": artifact,
"version": version,
"location": Path(self.backend.url, collection, artifact, str(version), "file").as_uri(),
"location": Path(self.backend.url, collection, artifact, version, "file").as_uri(),
},
)
else:
Expand All @@ -149,14 +149,7 @@ def collection_to_catalog(self, collection: Optional[str] = None) -> Catalog:
catalog = Catalog()
for artifact in self.backend.complete_key(Path(collection)):
for version in self.backend.complete_key(Path(collection, artifact)):
catalog[str(Path(collection, artifact))] = self.get_artifact_source(artifact, collection, int(version))
return catalog

def store_to_catalog(self) -> Catalog:
"""Provide Catalog of all artifacts stored in backend"""
catalog = Catalog()
for collection in self.list_collection_names():
catalog.update(self.collection_to_catalog(collection))
catalog[str(Path(collection, artifact))] = self.get_artifact_source(artifact, collection, version)
return catalog

def log_file(self, local_path: Path, name: str, collection: Optional[str] = None) -> Source:
Expand All @@ -166,41 +159,23 @@ def log_file(self, local_path: Path, name: str, collection: Optional[str] = None
assert isinstance(local_path, Path), "Path to file must be passed as a pathlib.Path object!"
if collection is None:
collection = self.active_collection
version = len(self.backend.complete_key(Path(collection, name))) + 1
self.backend.set(local_path, Path(collection, name, str(version)))
version = f"v{len(self.backend.complete_key(Path(collection, name)))}"
self.backend.set(local_path, Path(collection, name, version))
return self.get_artifact_source(name, collection)

def log_files(self, local_paths: List[Path], collection: Optional[str] = None) -> Catalog:
"""Upload a collection of file into a (current) collection"""
if collection is None:
collection = self.active_collection
for local_path in local_paths:
self.log_file(local_path, local_path.name, collection)
return self.collection_to_catalog(collection)

def log_folder(self, file: Path, collection: Optional[str] = None) -> Catalog:
"""Log folder as collection of artifacts into store"""
if not file.is_dir():
raise ValueError(f"Path {file} is not a directory!")

if collection is None:
collection = file.name

return self.log_files([f for f in file.iterdir() if f.is_file()], collection)

def log_artifact(self, obj: Any, name: str, collection: Optional[str] = None) -> Source:
    """Log an arbitrary python object using store serialisation.

    Versions are string keys ``"v0"``, ``"v1"``, ...: the first write of an
    artifact gets ``"v0"``, each later write gets ``"v<number of existing keys>"``.

    Args:
        obj: Object handed to ``self.backend.set``.
        name: Artifact name.
        collection: Target collection; ``self.active_collection`` when ``None``.

    Returns:
        The result of ``get_artifact_source(name, collection)`` after the write.
    """
    if collection is None:
        collection = self.active_collection
    artifact_key = Path(collection, name)
    if self.backend.key_exists(artifact_key):
        # Count-based numbering: with v0..v(k-1) present the next key is v<k>.
        version = f"v{len(self.backend.complete_key(artifact_key))}"
    else:
        version = "v0"
    # Write exactly once, under the string version key.
    self.backend.set(obj, artifact_key / version)
    return self.get_artifact_source(name, collection)

def download_artifact(
self, artifact: str, collection: Optional[str] = None, version: Optional[int] = None, to: Path = "./"
self, artifact: str, collection: Optional[str] = None, version: Optional[str] = None, to: Path = "./"
) -> Source:
"""Download artifact to local path."""
location = self.get_artifact_source(artifact, collection, version).metadata["location"]
Expand All @@ -211,11 +186,3 @@ def download_artifact(
driver_kwargs={"url": str(to), "storage_options": self.backend.storage_options},
metadata={"collection": collection, "artifact": artifact, "version": version, "location": str(to)},
)

def download_collection(self, collection: Optional[str] = None, to: Path = "./") -> Catalog:
"""Download all artifacts in collection to local directory."""
catalog = self.collection_to_catalog(collection)
for artifact in catalog.values():
artifact_name = artifact.metadata["artifact"]
self.download_artifact(artifact_name, to=to / artifact_name)
return catalog
Loading

0 comments on commit 5af0566

Please sign in to comment.