Skip to content

Commit

Permalink
AnalysisCard encoder/decoder refactor (#2643)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2643

Add encoder/decoder logic for `AnalysisCard`

Reviewed By: mpolson64

Differential Revision: D60316107
  • Loading branch information
Cesar-Cardoso authored and facebook-github-bot committed Aug 8, 2024
1 parent 89a3d0c commit eff3685
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.utils.common.base import Base


class AnalysisCardLevel(Enum):
Expand All @@ -21,7 +22,7 @@ class AnalysisCardLevel(Enum):
CRITICAL = 4


class AnalysisCard:
class AnalysisCard(Base):
# Name of the analysis computed, usually the class name of the Analysis which
# produced the card. Useful for grouping by when querying a large collection of
# cards.
Expand Down
17 changes: 17 additions & 0 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -50,6 +51,7 @@
from ax.storage.sqa_store.db import session_scope
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -66,6 +68,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pandas import read_json
from pyre_extensions import assert_is_instance
from sqlalchemy.orm.exc import DetachedInstanceError

Expand Down Expand Up @@ -982,6 +985,20 @@ def data_from_sqa(
dat.db_id = data_sqa.id
return dat

def analysis_card_from_sqa(
self,
analysis_card_sqa: SQAAnalysisCard,
) -> AnalysisCard:
"""Convert SQLAlchemy Analysis to Ax Analysis Object."""
return AnalysisCard(
name=analysis_card_sqa.name,
title=analysis_card_sqa.title,
subtitle=analysis_card_sqa.subtitle,
level=AnalysisCardLevel(analysis_card_sqa.level),
df=read_json(analysis_card_sqa.dataframe_json),
blob=analysis_card_sqa.blob,
)

def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric:
"""Convert SQLAlchemy Metric to Ax Metric"""
if metric_sqa.metric_type not in self.config.reverse_metric_registry:
Expand Down
30 changes: 30 additions & 0 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

# pyre-strict

from datetime import datetime
from enum import Enum

from logging import Logger
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import AbandonedArm, BatchTrial
Expand Down Expand Up @@ -45,6 +48,7 @@
from ax.storage.json_store.encoder import object_to_json
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand Down Expand Up @@ -1051,3 +1055,29 @@ def data_to_sqa(
)
),
)

def analysis_card_to_sqa(
self,
analysis_card: AnalysisCard,
experiment_id: int,
timestamp: datetime,
) -> SQAAnalysisCard:
"""Convert Ax analysis to SQLAlchemy."""
# pyre-fixme: Expected `Base` for 1st...ot `typing.Type[BaseAnalysis]`.
analysis_card_class: SQAAnalysisCard = self.config.class_to_sqa_class[
AnalysisCard
]

# pyre-fixme[29]: `SQAAnalysisCard` is not a function.
return analysis_card_class(
id=analysis_card.db_id,
name=analysis_card.name,
title=analysis_card.title,
subtitle=analysis_card.subtitle,
level=analysis_card.level,
dataframe_json=analysis_card.df.to_json(),
blob=analysis_card.blob,
blob_annotation=analysis_card.blob_annotation,
time_created=timestamp,
experiment_id=experiment_id,
)

0 comments on commit eff3685

Please sign in to comment.