Skip to content

Commit

Permalink
[feat] Store Run sequence metadata by separate types (#3260)
Browse files Browse the repository at this point in the history
* [feat] Use type-based sequence info when iterating Run sequence data

* [feat] Add CLI command to update Run sequence metadata

* [fix] Resolve code formatting issues
  • Loading branch information
alberttorosyan authored Dec 3, 2024
1 parent 2c00b44 commit 4a227b1
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 17 deletions.
42 changes: 42 additions & 0 deletions aim/cli/runs/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from aim.cli.runs.utils import make_zip_archive, match_runs, upload_repo_runs
from aim.sdk.repo import Repo
from aim.sdk.index_manager import RepoIndexManager
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from psutil import cpu_count


Expand Down Expand Up @@ -169,3 +171,43 @@ def close_runs(ctx, hashes, yes):

for _ in tqdm.tqdm(pool.imap_unordered(repo._close_run, hashes), desc='Closing runs', total=len(hashes)):
pass


@runs.command(name='update-metrics')
@click.pass_context
@click.option('-y', '--yes', is_flag=True, help='Automatically confirm prompt')
def update_metrics(ctx, yes):
    """Separate Sequence metadata for optimal read."""
    # NOTE: the docstring is deliberately left as a single line; Click uses it
    # as the command's help text, so extra detail lives in comments instead.
    #
    # For every Run in the repo this command writes a
    # ('typed_traces', <sequence type>, <context idx>, <sequence name>) = 1
    # entry for each sequence found under 'traces', then asks the index
    # manager to index the Run. Runs that already have a 'typed_traces' key
    # are skipped.
    #
    # ctx: Click context; ctx.obj['repo'] holds the repo path.
    # yes: when True, the confirmation prompt is skipped.
    repo_path = ctx.obj['repo']
    repo = Repo.from_path(repo_path)

    click.secho(
        f"This command will update Runs from Aim Repo '{repo_path}' to the latest data format to ensure better "
        'performance. Please make sure no Runs are active and Aim UI is not running.'
    )
    # Ask for confirmation only when '-y/--yes' was not passed.
    if not yes and not click.confirm('Do you want to proceed?'):
        return

    index_manager = RepoIndexManager.get_index_manager(repo)
    hashes = repo.list_all_runs()
    for run_hash in tqdm.tqdm(hashes, desc='Updating runs', total=len(hashes)):
        meta_tree = repo.request_tree('meta', run_hash, read_only=False, from_union=False)
        meta_run_tree = meta_tree.subtree(('meta', 'chunks', run_hash))

        # Probe for the new layout: first_key() raising KeyError is treated as
        # "no 'typed_traces' yet", i.e. the Run still needs to be updated.
        try:
            meta_run_tree.first_key('typed_traces')
        except KeyError:
            pass
        else:
            click.secho(f'Run {run_hash} is up-to-date. Skipping.')
            continue

        for ctx_idx, run_ctx_dict in meta_run_tree.subtree('traces').items():
            assert isinstance(ctx_idx, int)
            for seq_name in run_ctx_dict.keys():
                assert isinstance(seq_name, str)
                # Sequences with no recorded dtype are treated as 'float';
                # dtypes missing from the map fall back to the generic
                # 'sequence' type.
                dtype = run_ctx_dict[seq_name].get('dtype', 'float')
                seq_type = SEQUENCE_TYPE_MAP.get(dtype, 'sequence')
                meta_run_tree['typed_traces', seq_type, ctx_idx, seq_name] = 1

        # Re-index once per updated Run, after all its sequences are written.
        index_manager.index(run_hash)
6 changes: 4 additions & 2 deletions aim/sdk/reporter/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@

class FileManager(object):
@abstractmethod
def poll(self, pattern: str) -> Optional[str]: ...
def poll(self, pattern: str) -> Optional[str]:
...

@abstractmethod
def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): ...
def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None):
...


class LocalFileManager(FileManager):
Expand Down
42 changes: 31 additions & 11 deletions aim/sdk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from aim.sdk.reporter import RunStatusReporter, ScheduledStatusReporter
from aim.sdk.reporter.file_manager import LocalFileManager
from aim.sdk.sequence import Sequence
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.sequence_collection import SingleRunSequenceCollection
from aim.sdk.tracker import RunTracker
from aim.sdk.types import AimObject
Expand Down Expand Up @@ -496,17 +497,36 @@ def iter_sequence_info_by_type(self, dtypes: Union[str, Tuple[str, ...]]) -> Ite
"""
if isinstance(dtypes, str):
dtypes = (dtypes,)
for ctx_idx, run_ctx_dict in self.meta_run_tree.subtree('traces').items():
assert isinstance(ctx_idx, int)
ctx = self.idx_to_ctx(ctx_idx)
# run_ctx_view = run_meta_traces.view(ctx_idx)
for seq_name in run_ctx_dict.keys():
assert isinstance(seq_name, str)
# skip sequences not matching dtypes.
# sequences with no dtype are considered to be float sequences.
# '*' stands for all data types
if '*' in dtypes or run_ctx_dict[seq_name].get('dtype', 'float') in dtypes:
yield seq_name, ctx, self
try:
self.meta_run_tree.first_key('typed_traces')
has_trace_type_info = True
except KeyError:
has_trace_type_info = False

if has_trace_type_info:
# use set to remove duplicates for overlapping types (such as int and float for metric)
trace_types = set()
for dtype in dtypes:
trace_types.add(SEQUENCE_TYPE_MAP.get(dtype))
for trace_type in trace_types:
for ctx_idx, run_ctx_dict in self.meta_run_tree.subtree('typed_traces').get(trace_type, {}).items():
assert isinstance(ctx_idx, int)
ctx = self.idx_to_ctx(ctx_idx)
for seq_name in run_ctx_dict.keys():
assert isinstance(seq_name, str)
yield seq_name, ctx, self
else:
for ctx_idx, run_ctx_dict in self.meta_run_tree.subtree('traces').items():
assert isinstance(ctx_idx, int)
ctx = self.idx_to_ctx(ctx_idx)
# run_ctx_view = run_meta_traces.view(ctx_idx)
for seq_name in run_ctx_dict.keys():
assert isinstance(seq_name, str)
# skip sequences not matching dtypes.
# sequences with no dtype are considered to be float sequences.
# '*' stands for all data types
if '*' in dtypes or run_ctx_dict[seq_name].get('dtype', 'float') in dtypes:
yield seq_name, ctx, self

def metrics(self) -> 'SequenceCollection':
"""Get iterable object for all run tracked metrics.
Expand Down
9 changes: 6 additions & 3 deletions aim/sdk/run_status_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,16 @@ def __init__(self, *, obj_idx: Optional[str] = None, rank: Optional[int] = None,
self.message = message

@abstractmethod
def is_sent(self): ...
def is_sent(self):
...

@abstractmethod
def update_last_sent(self): ...
def update_last_sent(self):
...

@abstractmethod
def get_msg_details(self): ...
def get_msg_details(self):
...


class StatusNotification(Notification):
Expand Down
14 changes: 14 additions & 0 deletions aim/sdk/sequences/sequence_type_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Value dtypes grouped by the sequence type they are stored under.
# Group and member order is kept so the flattened map below has the same
# key order as a hand-written literal would.
_DTYPES_BY_SEQUENCE_TYPE = (
    ('metric', ('int', 'float', 'float64', 'number')),
    ('images', ('aim.image', 'list(aim.image)')),
    ('audios', ('aim.audio', 'list(aim.audio)')),
    ('texts', ('aim.text', 'list(aim.text)')),
    ('distributions', ('aim.distribution',)),
    ('figures', ('aim.figure',)),
)

# Flat lookup: value dtype name -> sequence type name.
SEQUENCE_TYPE_MAP = {
    value_dtype: sequence_type
    for sequence_type, value_dtypes in _DTYPES_BY_SEQUENCE_TYPE
    for value_dtype in value_dtypes
}
9 changes: 8 additions & 1 deletion aim/sdk/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from aim.storage.hashing import hash_auto
from aim.storage.object import CustomObject
from aim.storage.types import AimObject
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP


if TYPE_CHECKING:
Expand Down Expand Up @@ -152,7 +153,7 @@ def _load_sequence_info(self, ctx_id: int, name: str):
try:
seq_info.version = self.meta_run_tree['traces', ctx_id, name, 'version']
except KeyError:
self.meta_run_tree['traces', ctx_id, name, 'version'] = seq_info.dtype = 1
self.meta_run_tree['traces', ctx_id, name, 'version'] = seq_info.version = 1
try:
seq_info.dtype = self.meta_run_tree['traces', ctx_id, name, 'dtype']
except KeyError:
Expand Down Expand Up @@ -213,7 +214,11 @@ def _update_sequence_info(self, ctx_id: int, name: str, val, step: int):

def update_trace_dtype(old_dtype: str, new_dtype: str):
logger.warning(f"Updating sequence '{name}' data type from {old_dtype} to {new_dtype}.")
new_trace_type = SEQUENCE_TYPE_MAP.get(
dtype, 'sequence'
) # use mapping from value type to sequence type
self.meta_tree['traces_types', new_dtype, ctx_id, name] = 1
self.meta_run_tree['typed_traces', new_trace_type, ctx_id, name] = 1
self.meta_run_tree['traces', ctx_id, name, 'dtype'] = new_dtype
seq_info.dtype = new_dtype

Expand All @@ -222,11 +227,13 @@ def update_trace_dtype(old_dtype: str, new_dtype: str):
raise ValueError(f"Cannot log value '{val}' on sequence '{name}'. Incompatible data types.")

if seq_info.count == 0:
trace_type = SEQUENCE_TYPE_MAP.get(dtype, 'sequence') # use mapping from value type to sequence type
self.meta_tree['traces_types', dtype, ctx_id, name] = 1
self.meta_run_tree['traces', ctx_id, name, 'dtype'] = dtype
self.meta_run_tree['traces', ctx_id, name, 'version'] = seq_info.version
self.meta_run_tree['traces', ctx_id, name, 'first'] = val
self.meta_run_tree['traces', ctx_id, name, 'first_step'] = step
self.meta_run_tree['typed_traces', trace_type, ctx_id, name] = 1
seq_info.dtype = dtype

if step >= seq_info.count:
Expand Down

0 comments on commit 4a227b1

Please sign in to comment.