Fix: Remove the artifact_name argument from ds.log_to_mlflow() #563

Merged · 6 commits · Dec 17, 2024
28 changes: 28 additions & 0 deletions dagshub/common/util.py
@@ -1,4 +1,5 @@
import datetime
import functools
import types
import logging
import importlib
@@ -29,6 +30,12 @@ def to_timestamp(ts: Union[float, int, datetime.datetime]) -> int:
    return int(ts)


def removeprefix(val: str, prefix: str) -> str:
    # Backport of str.removeprefix(), which is only available on Python 3.9+
    if val.startswith(prefix):
        return val[len(prefix) :]
    return val


def lazy_load(module_name, source_package=None, callback=None):
    if source_package is None:
        # TODO: need to have a map for commonly used imports here. Also handle dots
@@ -94,3 +101,24 @@ def _import_module(self):
        # Update this object's dict so that attribute references are efficient
        # (__getattr__ is only called on lookups that fail)
        self.__dict__.update(module.__dict__)


def deprecated(additional_message=""):
Contributor commented: Cool!
"""
Decorator to mark functions as deprecated. It will print a warning
message when the function is called.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
additional = "\n" + additional_message if additional_message else ""
logger.warning(
f"DagsHub Deprecation Warning: "
f"{func.__name__} is deprecated and may be removed in future versions.{additional}",
)
return func(*args, **kwargs)

return wrapper

return decorator
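
A minimal usage sketch of the new decorator (the function and message below are hypothetical, not part of this PR):

from dagshub.common.util import deprecated

@deprecated("Use new_fetch() instead.")
def old_fetch():
    return 42

old_fetch()
# Logs: "DagsHub Deprecation Warning: old_fetch is deprecated and may be removed in future versions."
# followed by "Use new_fetch() instead." on the next line.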
14 changes: 9 additions & 5 deletions dagshub/data_engine/datasets.py
@@ -1,10 +1,10 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict

from dagshub.common.analytics import send_analytics_event
from dagshub.data_engine.client.data_client import DataClient
from dagshub.data_engine.client.models import DatasetResult
from dagshub.data_engine import datasources
from dagshub.data_engine.model.datasource import Datasource, DEFAULT_MLFLOW_ARTIFACT_NAME, DatasetState
from dagshub.data_engine.model.datasource import Datasource, DatasetState
from dagshub.data_engine.model.datasource_state import DatasourceState
from dagshub.data_engine.model.errors import DatasetNotFoundError

@@ -61,19 +61,23 @@ def get_dataset_from_file(path: str) -> Datasource:
    return datasources.get_datasource_from_file(path)


def get_from_mlflow(run=None, artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME) -> Datasource:
def get_from_mlflow(run=None, artifact_name: Optional[str] = None) -> Dict[str, Datasource]:
"""
Load a dataset from an MLflow run.
Load datasets from an MLflow run.

To save a datasource to MLflow, use
:func:`Datasource.log_to_mlflow()<dagshub.data_engine.model.datasource.Datasource.log_to_mlflow>`.
:func:`QueryResult.log_to_mlflow()<dagshub.data_engine.model.query_result.QueryResult.log_to_mlflow>`.

This is a copy of :func:`datasources.get_from_mlflow()<dagshub.data_engine.datasources.get_from_mlflow>`

Args:
run: Run or ID of the MLflow run to load the datasource from.
If ``None``, gets it from the current active run.
artifact_name: Name of the artifact in the run.
If specified, will only return the dataset defined in this artifact.

Returns:
Dictionary where keys are the artifacts, and the values are the dataset stored in the artifact.
"""
return datasources.get_from_mlflow(run, artifact_name)

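A usage sketch of the changed return type, assuming an MLflow run that already contains logged datasets (the run ID and artifact name are hypothetical):

from dagshub.data_engine import datasets

# Without artifact_name: every dataset logged to the run, keyed by artifact path
all_datasets = datasets.get_from_mlflow(run="0123abcd")
for artifact_name, ds in all_datasets.items():
    print(artifact_name, ds)

# With artifact_name: still a dict, but with that single entry
one = datasets.get_from_mlflow(
    run="0123abcd",
    artifact_name="log_images_2024-12-17T10-00-00_ab12.dagshub.dataset.json",
)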
47 changes: 35 additions & 12 deletions dagshub/data_engine/datasources.py
@@ -1,12 +1,12 @@
import json
import logging
from typing import Optional, Union, List, TYPE_CHECKING
from typing import Optional, Union, List, TYPE_CHECKING, Dict

from dagshub.common.analytics import send_analytics_event
from dagshub.common.api.repo import RepoAPI
from dagshub.common.util import lazy_load
from dagshub.common.util import lazy_load, removeprefix
from dagshub.data_engine.client.data_client import DataClient
from dagshub.data_engine.model.datasource import Datasource, DEFAULT_MLFLOW_ARTIFACT_NAME
from dagshub.data_engine.model.datasource import Datasource
from dagshub.data_engine.model.datasource_state import DatasourceState, DatasourceType, path_regexes
from dagshub.data_engine.model.errors import DatasourceNotFoundError

@@ -169,18 +169,22 @@ def get_datasources(repo: str) -> List[Datasource]:


def get_from_mlflow(
    run: Optional[Union["mlflow.entities.Run", str]] = None, artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME
) -> Datasource:
    run: Optional[Union["mlflow.entities.Run", str]] = None, artifact_name: Optional[str] = None
) -> Dict[str, Datasource]:
"""
Load a datasource from an MLflow run.
Load datasources from an MLflow run.

To save a datasource to MLflow, use
:func:`Datasource.log_to_mlflow()<dagshub.data_engine.model.datasource.Datasource.log_to_mlflow>`.
:func:`QueryResult.log_to_mlflow()<dagshub.data_engine.model.query_result.QueryResult.log_to_mlflow>`.

Args:
run: MLflow Run or its ID to load the datasource from.
If ``None``, loads datasource from the current active run.
artifact_name: Name of the datasource artifact in the run.
If specified, will only return the datasource defined in this artifact.

Returns:
Dictionary where keys are the artifacts, and the values are the datasource stored in the artifact.
"""
mlflow_run: "mlflow.entities.Run"
if run is None:
@@ -190,11 +194,7 @@ def get_from_mlflow(
    else:
        mlflow_run = run

    artifact_uri: str = mlflow_run.info.artifact_uri
    artifact_path = f"{artifact_uri.rstrip('/')}/{artifact_name.lstrip('/')}"

    ds_state = mlflow_artifacts.load_dict(artifact_path)
    return Datasource.load_from_serialized_state(ds_state)
    return _load_datasources_from_run(mlflow_run, artifact_name)


def get(*args, **kwargs) -> Datasource:
@@ -210,6 +210,29 @@ def _create_datasource_state(repo: str, name: str, source_type: DatasourceType,
    return ds


def _load_datasources_from_run(
    run: "mlflow.entities.Run", artifact_name: Optional[str] = None
) -> Dict[str, Datasource]:
    if artifact_name is not None:
        artifact_uri: str = run.info.artifact_uri
        artifact_path = f"{artifact_uri.rstrip('/')}/{artifact_name.lstrip('/')}"

        ds_states = {artifact_name: mlflow_artifacts.load_dict(artifact_path)}
    else:
        # Load all artifacts ending with ".dagshub.dataset.json"
        artifacts = mlflow_artifacts.list_artifacts(run.info.artifact_uri)
        # MLflow returns paths prefixed with `artifacts/`; strip it, because it's already part of artifact_uri
        artifact_paths = [removeprefix(a.path, "artifacts/") for a in artifacts]

        ds_states = {
            p: mlflow_artifacts.load_dict(f"{run.info.artifact_uri}/{p}")
            for p in artifact_paths
            if p.endswith(".dagshub.dataset.json")
        }

    return {p: Datasource.load_from_serialized_state(ds_state) for p, ds_state in ds_states.items()}
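
For illustration, the path handling above behaves like this (the artifact names are hypothetical):

from dagshub.common.util import removeprefix

# list_artifacts() returns paths like "artifacts/<name>"; the prefix is
# stripped so the path can be joined onto run.info.artifact_uri
raw = ["artifacts/log_images_2024-12-17T10-00-00_ab12.dagshub.dataset.json", "artifacts/model.pkl"]
paths = [removeprefix(p, "artifacts/") for p in raw]
# Only datasource artifacts are loaded back
dataset_paths = [p for p in paths if p.endswith(".dagshub.dataset.json")]
assert dataset_paths == ["log_images_2024-12-17T10-00-00_ab12.dagshub.dataset.json"]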


__all__ = [
    create_datasource.__name__,
    create.__name__,
72 changes: 51 additions & 21 deletions dagshub/data_engine/model/datasource.py
@@ -28,7 +28,7 @@
from dagshub.common.environment import is_mlflow_installed
from dagshub.common.helpers import prompt_user, http_request, log_message
from dagshub.common.rich_util import get_rich_progress
from dagshub.common.util import lazy_load, multi_urljoin, to_timestamp, exclude_if_none
from dagshub.common.util import lazy_load, multi_urljoin, to_timestamp, exclude_if_none, deprecated
from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationType, AnnotationLocation
from dagshub.data_engine.client.models import (
    PreprocessingStatus,
@@ -77,7 +77,6 @@
logger = logging.getLogger(__name__)

LS_ORCHESTRATOR_URL = "http://127.0.0.1"
DEFAULT_MLFLOW_ARTIFACT_NAME = "datasource.dagshub.json"
MLFLOW_DATASOURCE_TAG_NAME = "dagshub.datasets.datasource_id"
MLFLOW_DATASET_TAG_NAME = "dagshub.datasets.dataset_id"

@@ -845,33 +844,21 @@ def save_dataset(self, name: str) -> "Datasource":
        copy_with_ds_assigned.load_from_dataset(dataset_name=name, change_query=False)
        return copy_with_ds_assigned

    def _autolog_mlflow(self, qr: "QueryResult"):
        if not is_mlflow_installed:
            return
        # Run ONLY if there's an active run going on
        active_run = mlflow.active_run()
        if active_run is None:
            return
        source_name = self.source.name

        now_time = qr.query_data_time.strftime("%Y-%m-%dT%H-%M-%S")  # Not ISO format to make it a valid filename
        uuid_chunk = str(uuid.uuid4())[-4:]

        artifact_name = f"autolog_{source_name}_{now_time}_{uuid_chunk}.dagshub.json"
        threading.Thread(
            target=self.log_to_mlflow,
            kwargs={"artifact_name": artifact_name, "run": active_run, "as_of": qr.query_data_time},
        ).start()

@deprecated("Either use autologging, or QueryResult.log_to_mlflow() if there autologging is turned off")
kbolashev marked this conversation as resolved.
Show resolved Hide resolved
def log_to_mlflow(
self,
artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME,
artifact_name: Optional[str] = None,
run: Optional["mlflow.entities.Run"] = None,
as_of: Optional[datetime.datetime] = None,
) -> "mlflow.Entities.Run":
"""
Logs the current datasource state to MLflow as an artifact.

.. warning::
This function is deprecated. Use autologging or
:func:`QueryResult.log_to_mlflow() <dagshub.data_engine.model.query_result.QueryResult.log_to_mlflow>`
instead.

Args:
artifact_name: Name of the artifact that will be stored in the MLflow run.
run: MLflow run to save to. If ``None``, uses the active MLflow run or creates a new run.
@@ -884,6 +871,35 @@ def log_to_mlflow(
        Returns:
            Run to which the artifact was logged.
        """
        if artifact_name is None:
            as_of = as_of or (self._query.as_of or datetime.datetime.now())
            artifact_name = self._get_mlflow_artifact_name("log", as_of)
        elif not artifact_name.endswith(".dagshub.dataset.json"):
            artifact_name += ".dagshub.dataset.json"

        return self._log_to_mlflow(artifact_name, run, as_of)
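        # For example (hypothetical call): ds.log_to_mlflow(artifact_name="my_snapshot")
        # would store the artifact as "my_snapshot.dagshub.dataset.json".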

    def _autolog_mlflow(self, qr: "QueryResult"):
        if not is_mlflow_installed:
            return
        # Run ONLY if there's an active run going on
        active_run = mlflow.active_run()
        if active_run is None:
            return

        artifact_name = self._get_mlflow_artifact_name("autolog", qr.query_data_time)

        threading.Thread(
            target=self._log_to_mlflow,
            kwargs={"artifact_name": artifact_name, "run": active_run, "as_of": qr.query_data_time},
        ).start()

    def _log_to_mlflow(
        self,
        artifact_name,
        run: Optional["mlflow.entities.Run"] = None,
        as_of: Optional[datetime.datetime] = None,
    ) -> "mlflow.entities.Run":
        if run is None:
            run = mlflow.active_run()
            if run is None:
@@ -896,6 +912,11 @@ def log_to_mlflow(
        log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')
        return run

    def _get_mlflow_artifact_name(self, prefix: str, as_of: datetime.datetime) -> str:
        now_time = as_of.strftime("%Y-%m-%dT%H-%M-%S")  # Not ISO format to make it a valid filename
        uuid_chunk = str(uuid.uuid4())[-4:]
        return f"{prefix}_{self.source.name}_{now_time}_{uuid_chunk}.dagshub.dataset.json"

    def save_to_file(self, path: Union[str, PathLike] = ".") -> Path:
        """
        Saves a JSON file representing the current state of datasource or dataset.
@@ -974,6 +995,8 @@ def load_from_serialized_state(state_dict: Dict) -> "Datasource":
        """

        state = DatasourceSerializedState.from_dict(state_dict)
        # dataclasses_json's from_dict() doesn't respect the default value hints, so we fill them in ourselves
        state.query._fill_out_defaults()

        ds_state = DatasourceState(repo=state.repo, name=state.datasource_name, id=state.datasource_id)
        ds_state.get_from_dagshub()
@@ -983,6 +1006,9 @@ def load_from_serialized_state(state_dict: Dict) -> "Datasource":
        if state.dataset_id is not None:
            ds.load_from_dataset(state.dataset_id, state.dataset_name, change_query=False)

        if state.timestamp is not None:
            ds = ds.as_of(datetime.datetime.fromtimestamp(state.timestamp, tz=datetime.timezone.utc))

        return ds

    def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
@@ -1888,6 +1914,10 @@ class DatasourceQuery(DataClassJsonMixin):
        default=None, metadata=config(exclude=exclude_if_none, letter_case=LetterCase.CAMEL)
    )

    def _fill_out_defaults(self):
        """For deserializers that don't utilize the default value hints of the dataclass"""
        if self.filter is None:
            self.filter = QueryFilterTree()

    def __deepcopy__(self, memodict={}):
        other = DatasourceQuery(
            as_of=self.as_of,
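Taken together, a migration sketch for callers of the deprecated method (the repo and datasource names are hypothetical):

from dagshub.data_engine import datasources

ds = datasources.get_datasource("user/repo", name="my-source")

# Before (still works, but now warns and auto-generates the artifact name):
ds.log_to_mlflow()

# After: materialize a query result and log that instead
qr = ds.all()
qr.log_to_mlflow()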
19 changes: 19 additions & 0 deletions dagshub/data_engine/model/query_result.py
@@ -47,6 +47,7 @@
    import datasets as hf_ds
    import tensorflow as tf
    import mlflow
    import mlflow.entities
else:
    plugin_server_module = lazy_load("dagshub.data_engine.voxel_plugin_server.server")
    fo = lazy_load("fiftyone")
@@ -942,3 +943,21 @@ def _load_autoload_fields(self, documents=True, annotations=True):

        if annotations:
            self.get_annotations()

    def log_to_mlflow(self, run: Optional["mlflow.entities.Run"] = None) -> "mlflow.entities.Run":
        """
        Logs the query result information to MLflow as an artifact.
        The artifact is saved at the root of the run with a name of the form
        ``log_{datasource_name}_{query_time}_{random_chunk}.dagshub.dataset.json``.

        You can later load the dataset back from MLflow using :func:`dagshub.data_engine.datasources.get_from_mlflow`.

        Args:
            run: MLflow run to save to. If ``None``, uses the active MLflow run or creates a new run.

        Returns:
            Run to which the artifact was logged.
        """
        assert self.query_data_time is not None
        artifact_name = self.datasource._get_mlflow_artifact_name("log", self.query_data_time)
        return self.datasource._log_to_mlflow(artifact_name, run, self.query_data_time)
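
A round-trip sketch combining this with datasources.get_from_mlflow() (the repo and datasource names are hypothetical):

import mlflow
from dagshub.data_engine import datasources

ds = datasources.get_datasource("user/repo", name="my-source")

with mlflow.start_run() as run:
    ds.all().log_to_mlflow(run)

# Later: recover every dataset that was logged to that run
loaded = datasources.get_from_mlflow(run=run.info.run_id)
for artifact, restored in loaded.items():
    print(artifact, restored)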