From 4a227b1a8afe7272ceb7608568e05a45076e4692 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:51:38 +0400 Subject: [PATCH] [feat] Store Run sequence metadata by separate types (#3260) * [feat] Use type-based sequence info when iterating Run sequence data * [feat] Add CLI command to update Run sequence metadata * [fix] Resolve code formatting issues --- aim/cli/runs/commands.py | 42 ++++++++++++++++++++++++++ aim/sdk/reporter/file_manager.py | 6 ++-- aim/sdk/run.py | 42 +++++++++++++++++++------- aim/sdk/run_status_watcher.py | 9 ++++-- aim/sdk/sequences/sequence_type_map.py | 14 +++++++++ aim/sdk/tracker.py | 9 +++++- 6 files changed, 105 insertions(+), 17 deletions(-) create mode 100644 aim/sdk/sequences/sequence_type_map.py diff --git a/aim/cli/runs/commands.py b/aim/cli/runs/commands.py index 6793c209d4..b881e4e2a2 100644 --- a/aim/cli/runs/commands.py +++ b/aim/cli/runs/commands.py @@ -7,6 +7,8 @@ from aim.cli.runs.utils import make_zip_archive, match_runs, upload_repo_runs from aim.sdk.repo import Repo +from aim.sdk.index_manager import RepoIndexManager +from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP from psutil import cpu_count @@ -169,3 +171,43 @@ def close_runs(ctx, hashes, yes): for _ in tqdm.tqdm(pool.imap_unordered(repo._close_run, hashes), desc='Closing runs', total=len(hashes)): pass + + +@runs.command(name='update-metrics') +@click.pass_context +@click.option('-y', '--yes', is_flag=True, help='Automatically confirm prompt') +def update_metrics(ctx, yes): + """Separate Sequence metadata for optimal read.""" + repo_path = ctx.obj['repo'] + repo = Repo.from_path(repo_path) + + click.secho( + f"This command will update Runs from Aim Repo '{repo_path}' to the latest data format to ensure better " + f'performance. Please make sure no Runs are active and Aim UI is not running.' 
+ ) + if yes: + confirmed = True + else: + confirmed = click.confirm('Do you want to proceed?') + if not confirmed: + return + + index_manager = RepoIndexManager.get_index_manager(repo) + hashes = repo.list_all_runs() + for run_hash in tqdm.tqdm(hashes, desc='Updating runs', total=len(hashes)): + meta_tree = repo.request_tree('meta', run_hash, read_only=False, from_union=False) + meta_run_tree = meta_tree.subtree(('meta', 'chunks', run_hash)) + try: + # check if the Run has already been updated. + meta_run_tree.first_key('typed_traces') + click.secho(f'Run {run_hash} is up-to-date. Skipping.') + continue + except KeyError: + for ctx_idx, run_ctx_dict in meta_run_tree.subtree('traces').items(): + assert isinstance(ctx_idx, int) + for seq_name in run_ctx_dict.keys(): + assert isinstance(seq_name, str) + dtype = run_ctx_dict[seq_name].get('dtype', 'float') + seq_type = SEQUENCE_TYPE_MAP.get(dtype, 'sequence') + meta_run_tree['typed_traces', seq_type, ctx_idx, seq_name] = 1 + index_manager.index(run_hash) diff --git a/aim/sdk/reporter/file_manager.py b/aim/sdk/reporter/file_manager.py index 80c2d9a855..72633f084b 100644 --- a/aim/sdk/reporter/file_manager.py +++ b/aim/sdk/reporter/file_manager.py @@ -10,10 +10,12 @@ class FileManager(object): @abstractmethod - def poll(self, pattern: str) -> Optional[str]: ... + def poll(self, pattern: str) -> Optional[str]: + ... @abstractmethod - def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): ... + def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): + ...
class LocalFileManager(FileManager): diff --git a/aim/sdk/run.py b/aim/sdk/run.py index d21d2a882e..d277498831 100644 --- a/aim/sdk/run.py +++ b/aim/sdk/run.py @@ -27,6 +27,7 @@ from aim.sdk.reporter import RunStatusReporter, ScheduledStatusReporter from aim.sdk.reporter.file_manager import LocalFileManager from aim.sdk.sequence import Sequence +from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP from aim.sdk.sequence_collection import SingleRunSequenceCollection from aim.sdk.tracker import RunTracker from aim.sdk.types import AimObject @@ -496,17 +497,36 @@ def iter_sequence_info_by_type(self, dtypes: Union[str, Tuple[str, ...]]) -> Ite """ if isinstance(dtypes, str): dtypes = (dtypes,) - for ctx_idx, run_ctx_dict in self.meta_run_tree.subtree('traces').items(): - assert isinstance(ctx_idx, int) - ctx = self.idx_to_ctx(ctx_idx) - # run_ctx_view = run_meta_traces.view(ctx_idx) - for seq_name in run_ctx_dict.keys(): - assert isinstance(seq_name, str) - # skip sequences not matching dtypes. - # sequences with no dtype are considered to be float sequences. 
- # '*' stands for all data types - if '*' in dtypes or run_ctx_dict[seq_name].get('dtype', 'float') in dtypes: - yield seq_name, ctx, self + try: + self.meta_run_tree.first_key('typed_traces') + has_trace_type_info = True + except KeyError: + has_trace_type_info = False + + if has_trace_type_info and all(dtype in SEQUENCE_TYPE_MAP for dtype in dtypes): + # use set to remove duplicates for overlapping types (such as int and float for metric) + trace_types = set() + for dtype in dtypes: + trace_types.add(SEQUENCE_TYPE_MAP.get(dtype)) + for trace_type in trace_types: + for ctx_idx, run_ctx_dict in self.meta_run_tree.subtree('typed_traces').get(trace_type, {}).items(): + assert isinstance(ctx_idx, int) + ctx = self.idx_to_ctx(ctx_idx) + for seq_name in run_ctx_dict.keys(): + assert isinstance(seq_name, str) + yield seq_name, ctx, self + else: + for ctx_idx, run_ctx_dict in self.meta_run_tree.subtree('traces').items(): + assert isinstance(ctx_idx, int) + ctx = self.idx_to_ctx(ctx_idx) + # run_ctx_view = run_meta_traces.view(ctx_idx) + for seq_name in run_ctx_dict.keys(): + assert isinstance(seq_name, str) + # skip sequences not matching dtypes. + # sequences with no dtype are considered to be float sequences. + # '*' stands for all data types + if '*' in dtypes or run_ctx_dict[seq_name].get('dtype', 'float') in dtypes: + yield seq_name, ctx, self def metrics(self) -> 'SequenceCollection': """Get iterable object for all run tracked metrics. diff --git a/aim/sdk/run_status_watcher.py b/aim/sdk/run_status_watcher.py index 422cbff128..ccf203bd51 100644 --- a/aim/sdk/run_status_watcher.py +++ b/aim/sdk/run_status_watcher.py @@ -83,13 +83,16 @@ def __init__(self, *, obj_idx: Optional[str] = None, rank: Optional[int] = None, self.message = message @abstractmethod - def is_sent(self): ... + def is_sent(self): + ... @abstractmethod - def update_last_sent(self): ... + def update_last_sent(self): + ... @abstractmethod - def get_msg_details(self): ... + def get_msg_details(self): + ...
class StatusNotification(Notification): diff --git a/aim/sdk/sequences/sequence_type_map.py b/aim/sdk/sequences/sequence_type_map.py new file mode 100644 index 0000000000..5f7476fed3 --- /dev/null +++ b/aim/sdk/sequences/sequence_type_map.py @@ -0,0 +1,14 @@ +SEQUENCE_TYPE_MAP = { + 'int': 'metric', + 'float': 'metric', + 'float64': 'metric', + 'number': 'metric', + 'aim.image': 'images', + 'list(aim.image)': 'images', + 'aim.audio': 'audios', + 'list(aim.audio)': 'audios', + 'aim.text': 'texts', + 'list(aim.text)': 'texts', + 'aim.distribution': 'distributions', + 'aim.figure': 'figures', +} diff --git a/aim/sdk/tracker.py b/aim/sdk/tracker.py index 43b78a5204..90b87dd26b 100644 --- a/aim/sdk/tracker.py +++ b/aim/sdk/tracker.py @@ -15,6 +15,7 @@ from aim.storage.hashing import hash_auto from aim.storage.object import CustomObject from aim.storage.types import AimObject +from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP if TYPE_CHECKING: @@ -152,7 +153,7 @@ def _load_sequence_info(self, ctx_id: int, name: str): try: seq_info.version = self.meta_run_tree['traces', ctx_id, name, 'version'] except KeyError: - self.meta_run_tree['traces', ctx_id, name, 'version'] = seq_info.dtype = 1 + self.meta_run_tree['traces', ctx_id, name, 'version'] = seq_info.version = 1 try: seq_info.dtype = self.meta_run_tree['traces', ctx_id, name, 'dtype'] except KeyError: @@ -213,7 +214,11 @@ def _update_sequence_info(self, ctx_id: int, name: str, val, step: int): def update_trace_dtype(old_dtype: str, new_dtype: str): logger.warning(f"Updating sequence '{name}' data type from {old_dtype} to {new_dtype}.") + new_trace_type = SEQUENCE_TYPE_MAP.get( + new_dtype, 'sequence' + ) # use mapping from value type to sequence type self.meta_tree['traces_types', new_dtype, ctx_id, name] = 1 + self.meta_run_tree['typed_traces', new_trace_type, ctx_id, name] = 1 self.meta_run_tree['traces', ctx_id, name, 'dtype'] = new_dtype seq_info.dtype = new_dtype @@ -222,11 +227,13 @@ def
update_trace_dtype(old_dtype: str, new_dtype: str): raise ValueError(f"Cannot log value '{val}' on sequence '{name}'. Incompatible data types.") if seq_info.count == 0: + trace_type = SEQUENCE_TYPE_MAP.get(dtype, 'sequence') # use mapping from value type to sequence type self.meta_tree['traces_types', dtype, ctx_id, name] = 1 self.meta_run_tree['traces', ctx_id, name, 'dtype'] = dtype self.meta_run_tree['traces', ctx_id, name, 'version'] = seq_info.version self.meta_run_tree['traces', ctx_id, name, 'first'] = val self.meta_run_tree['traces', ctx_id, name, 'first_step'] = step + self.meta_run_tree['typed_traces', trace_type, ctx_id, name] = 1 seq_info.dtype = dtype if step >= seq_info.count: