Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnalysisCard encoder/decoder refactor #2643

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.utils.common.base import Base


class AnalysisCardLevel(Enum):
Expand All @@ -21,7 +22,7 @@ class AnalysisCardLevel(Enum):
CRITICAL = 4


class AnalysisCard:
class AnalysisCard(Base):
# Name of the analysis computed, usually the class name of the Analysis which
# produced the card. Useful for grouping by when querying a large collection of
# cards.
Expand Down
45 changes: 14 additions & 31 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
import plotly.io as pio

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -54,7 +51,7 @@
from ax.storage.sqa_store.db import session_scope
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -67,12 +64,7 @@
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.utils import (
AnalysisType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
Expand Down Expand Up @@ -993,28 +985,19 @@ def data_from_sqa(
dat.db_id = data_sqa.id
return dat

def analysis_from_sqa(
def analysis_card_from_sqa(
self,
analysis_sqa: SQAAnalysis,
experiment: Experiment,
) -> BaseAnalysis:
analysis_card_sqa: SQAAnalysisCard,
) -> AnalysisCard:
"""Convert SQLAlchemy Analysis to Ax Analysis Object."""
# TODO: generalize solution for pd dataframe type casting of "arm_name" column.
if analysis_sqa.experiment_analysis_type == AnalysisType.PLOTLY_VISUALIZATION:
return BasePlotlyVisualization(
experiment=experiment,
df_input=read_json(
analysis_sqa.dataframe_json, dtype={"arm_name": "str"}
),
fig_input=pio.from_json(analysis_sqa.fig_json, output_type="Figure"),
)
else:
return BaseAnalysis(
experiment=experiment,
df_input=read_json(
analysis_sqa.dataframe_json, dtype={"arm_name": "str"}
),
)
return AnalysisCard(
name=analysis_card_sqa.name,
title=analysis_card_sqa.title,
subtitle=analysis_card_sqa.subtitle,
level=AnalysisCardLevel(analysis_card_sqa.level),
df=read_json(analysis_card_sqa.dataframe_json),
blob=analysis_card_sqa.blob,
)

def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric:
"""Convert SQLAlchemy Metric to Ax Metric"""
Expand Down
66 changes: 25 additions & 41 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@

# pyre-strict

from datetime import datetime
from enum import Enum

from logging import Logger
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import plotly
import plotly.io as pio

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
Expand Down Expand Up @@ -51,7 +48,7 @@
from ax.storage.json_store.encoder import object_to_json
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -64,12 +61,7 @@
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.utils import (
AnalysisType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -1064,36 +1056,28 @@ def data_to_sqa(
),
)

def analysis_to_sqa(
def analysis_card_to_sqa(
self,
analysis: BaseAnalysis,
) -> SQAAnalysis:
analysis_card: AnalysisCard,
experiment_id: int,
timestamp: datetime,
) -> SQAAnalysisCard:
"""Convert Ax analysis to SQLAlchemy."""
# pyre-fixme: Expected `Base` for 1st...ot `typing.Type[BaseAnalysis]`.
analysis_class: SQAAnalysis = self.config.class_to_sqa_class[BaseAnalysis]

is_plotly_visualization: bool = isinstance(analysis, BasePlotlyVisualization)

# pyre-fixme[29]: `SQAAnalysis` is not a function.
return analysis_class(
id=-1,
analysis_class_name=type(analysis).__name__,
time_analysis_start=-1,
time_analysis_completed=-1,
experiment_analysis_type=(
AnalysisType.PLOTLY_VISUALIZATION
if is_plotly_visualization
else AnalysisType.ANALYSIS
),
dataframe_json=analysis.df.to_json(),
fig_json=(
None
if not is_plotly_visualization
else pio.to_json(
checked_cast(BasePlotlyVisualization, analysis).fig,
validate=True,
remove_uids=False,
)
),
plotly_version=plotly.__version__,
analysis_card_class: SQAAnalysisCard = self.config.class_to_sqa_class[
AnalysisCard
]

# pyre-fixme[29]: `SQAAnalysisCard` is not a function.
return analysis_card_class(
id=analysis_card.db_id,
name=analysis_card.name,
title=analysis_card.title,
subtitle=analysis_card.subtitle,
level=analysis_card.level,
dataframe_json=analysis_card.df.to_json(),
blob=analysis_card.blob,
blob_annotation=analysis_card.blob_annotation,
time_created=timestamp,
experiment_id=experiment_id,
)
89 changes: 30 additions & 59 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from datetime import datetime
from typing import Any, Dict, List, Optional

from ax.analysis.analysis import AnalysisCardLevel

from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import LifecycleStage
from ax.core.parameter import ParameterType
Expand All @@ -34,13 +36,7 @@
)
from ax.storage.sqa_store.sqa_enum import IntEnum, StringEnum
from ax.storage.sqa_store.timestamp import IntTimestamp
from ax.storage.utils import (
AnalysisType,
DataType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DataType, DomainType, MetricIntent, ParameterConstraintType
from sqlalchemy import (
BigInteger,
Boolean,
Expand Down Expand Up @@ -462,6 +458,31 @@ class SQATrial(Base):
)


class SQAAnalysisCard(Base):
__tablename__: str = "analysis_card"

# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
id: int = Column(Integer, primary_key=True)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
title: str = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
subtitle: str = Column(Text, nullable=False)
# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
level: AnalysisCardLevel = Column(IntEnum(AnalysisCardLevel), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
dataframe_json: str = Column(Text(LONGTEXT_BYTES), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
blob: str = Column(Text(LONGTEXT_BYTES), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
blob_annotation: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
time_created: datetime = Column(IntTimestamp, nullable=False)
# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"), nullable=False)


class SQAExperiment(Base):
__tablename__: str = "experiment_v2"

Expand Down Expand Up @@ -520,56 +541,6 @@ class SQAExperiment(Base):
uselist=False,
lazy=True,
)


class SQAAnalysis(Base):
__tablename__: str = "analysis"

# pyre-fixme[8]: Attribute has type `int`; used as `Column[int]`.
id: int = Column(Integer, primary_key=True)

# pyre-fixme[8]: Attribute has type `str`; used as `Column[str]`.
analysis_class_name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`.
time_analysis_start: datetime = Column(IntTimestamp, nullable=False)
# pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`.
time_analysis_completed: datetime = Column(IntTimestamp, nullable=False)

# pyre-fixme[8]: Attribute has type `AnalysisType`; used as
# `Column[typing.Any]`.
experiment_analysis_type: AnalysisType = Column(
StringEnum(AnalysisType), nullable=False
)

# pyre-fixme[8]: Attribute has type `str`; used as `Column[str]`.
dataframe_json: str = Column(Text(LONGTEXT_BYTES), nullable=False)

# pyre-fixme[8]: Attribute has type `Optional[str]`; used as
# `Column[Optional[str]]`.
fig_json: Optional[str] = Column(Text(LONGTEXT_BYTES), nullable=True)
# pyre-fixme[8]: Attribute has type `Optional[str]`; used as
# `Column[Optional[str]]`.
plotly_version: Optional[str] = Column(
String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True
)

# pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`.
experiment_id: int = Column(Integer)
# pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`.
analysis_report_id: int = Column(Integer, ForeignKey("analysis_report_v2.id"))


class SQAAnalysisReport(Base):
__tablename__: str = "analysis_report_v2"

# pyre-fixme[8]: Attribute has type `int`; used as `Column[int]`.
id: int = Column(Integer, primary_key=True)
# pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`.
time_report_start: datetime = Column(IntTimestamp, nullable=False)

analyses: Optional[List[SQAAnalysis]] = relationship(
"SQAAnalysis", cascade="all, delete-orphan", lazy="selectin"
analysis_cards: List[SQAAnalysisCard] = relationship(
"SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin"
)

# pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`.
experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"))
9 changes: 3 additions & 6 deletions ax/storage/sqa_store/sqa_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type, Union

from ax.analysis.old.analysis_report import AnalysisReport
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.batch_trial import AbandonedArm
Expand All @@ -36,8 +35,7 @@
from ax.storage.sqa_store.db import SQABase
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisReport,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand Down Expand Up @@ -80,8 +78,7 @@ def _default_class_to_sqa_class(self=None) -> Dict[Type[Base], Type[SQABase]]:
Metric: SQAMetric,
Runner: SQARunner,
Trial: SQATrial,
BaseAnalysis: SQAAnalysis,
AnalysisReport: SQAAnalysisReport,
AnalysisCard: SQAAnalysisCard,
}

class_to_sqa_class: Dict[Type[Base], Type[SQABase]] = field(
Expand Down
24 changes: 0 additions & 24 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@
update_runner_on_experiment,
)
from ax.storage.sqa_store.sqa_classes import (
AnalysisType,
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisReport,
SQAArm,
SQAExperiment,
SQAGeneratorRun,
Expand Down Expand Up @@ -1929,24 +1926,3 @@ def test_CreateAllTablesException(self) -> None:
engine.dialect.default_schema_name = "ax"
with self.assertRaises(ValueError):
create_all_tables(engine)

def test_CreateAnalysisRecords(self) -> None:

sqa_analysis = SQAAnalysis(
analysis_class_name="CrossValidationPlot",
experiment_analysis_type=AnalysisType.PLOTLY_VISUALIZATION,
time_analysis_start=datetime.now(),
time_analysis_completed=datetime.now(),
dataframe_json="none",
)
with session_scope() as session:
_ = session.merge(sqa_analysis)
session.flush()

def test_CreateAnalysisReport(self) -> None:
sqa_analysis_report = SQAAnalysisReport(
time_report_start=datetime.now(),
)
with session_scope() as session:
_ = session.merge(sqa_analysis_report)
session.flush()
Loading