Skip to content

Commit

Permalink
fix(reset): Properly reset state on task change
Browse files Browse the repository at this point in the history
Also start using class annotation

fix #10
  • Loading branch information
jourdain committed Aug 4, 2023
1 parent f25300f commit 90d7ba0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 35 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ keywords =
packages = find:
include_package_data = True
install_requires =
trame>3
trame>=3.1.0
trame-vuetify
trame-components>=2.0.4
trame-plotly
Expand Down
40 changes: 20 additions & 20 deletions xaitk_saliency_demo/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .ml.models import get_model
from .ml.xai import get_saliency

from trame.decorators import TrameApp, change, life_cycle
from trame.app import get_server
from trame_client.encoders import numpy
from trame.ui.vuetify import SinglePageLayout
Expand All @@ -18,6 +19,7 @@
logger = logging.getLogger("xaitks_saliency_demo")


@TrameApp()
class XaitkSaliency:
def __init__(self, server):
if server is None:
Expand Down Expand Up @@ -59,24 +61,6 @@ def __init__(self, server):
)
self.state.client_only("xai_viz_heatmap_opacity")

# Bind method to controller
self.ctrl.on_server_ready.add(self.on_ready)

# Bind state change
self.state.change("task_active")(self.on_task_change)
self.state.change("model_active")(self.on_model_change)
self.state.change("TOP_K")(self.on_nb_class_change)
self.state.change("saliency_active")(self.on_xai_algo_change)
self.state.change("input_file")(self.on_input_file_change)
self.state.change("input_1_img_url", "input_2_img_url")(self.reset_image)
self.state.change(
*[f"xai_param__{k}" for k in config.ALL_SALIENCY_PARAMS.keys()]
)(self.on_saliency_param_update)

self.state.change("xai_viz_color_min", "xai_viz_color_max")(
self.xai_viz_color_range_change
)

# Build GUI
self.ui()

Expand Down Expand Up @@ -135,12 +119,14 @@ def run_saliency(self):

# Create saliency and run it
xaitk = get_saliency(self._task, self._model, **self._xaitk_config)
self.state.xai_ready = True
return xaitk.run(self._image_1, self._image_2)

# -----------------------------------------------------
# Exec API
# -----------------------------------------------------

@life_cycle.server_ready
def on_ready(self, task_active, **kwargs):
"""Executed only once when application start"""
logger.info("on_ready", task_active)
Expand Down Expand Up @@ -181,6 +167,11 @@ def update_model_execution(self):
-> btn press in model section
-> state.change(TOP_K, input_file, model_active)
"""

# We don't have input to deal with
if not self.state.input_1_img_url:
return

results = {}

if self.can_run():
Expand Down Expand Up @@ -284,38 +275,44 @@ def reset_model_execution(self):
self.state.model_viz_detection_areas = []

def reset_all(self):
self.state.xai_ready = False
self.state.input_needed = True
self.state.input_1_img_url = None
self.state.input_2_img_url = None
self.reset_model_execution()

@change("task_active")
def on_task_change(self, task_active, **kwargs):
# Use static dependency to update state values
if task_active in config.TASK_DEPENDENCY:
for key, value in config.TASK_DEPENDENCY[task_active].items():
self.state[key] = value

self.set_task(task_active)

# New task => clear UI content
self.reset_all()

self.set_task(task_active)

@change("model_active")
def on_model_change(self, model_active, **kwargs):
if model_active:
logger.info("set model to", model_active)
self.set_model(model_active)
self.reset_model_execution()
self.update_model_execution()

@change("TOP_K")
def on_nb_class_change(self, **kwargs):
self.update_model_execution()

@change("saliency_active")
def on_xai_algo_change(self, saliency_active, **kwargs):
if saliency_active in config.SALIENCY_PARAMS:
# Show/hide parameters relevant to current algo
self.state.xai_params_to_show = config.SALIENCY_PARAMS[saliency_active]
self.update_active_xai_algorithm()

@change("input_file")
def on_input_file_change(
self, input_file, input_1_img_url, input_2_img_url, input_expected, **kwargs
):
Expand All @@ -337,6 +334,7 @@ def on_input_file_change(
self.set_image_2(input_file.get("content"))
self.update_model_execution()

@change("input_1_img_url", "input_2_img_url")
def reset_image(self, input_1_img_url, input_2_img_url, input_expected, **kwargs):
"""Method called when input_x_img_url is changed which can happen when setting them but also when cleared on the client side"""
count = 0
Expand All @@ -348,9 +346,11 @@ def reset_image(self, input_1_img_url, input_2_img_url, input_expected, **kwargs
# Hide button if we have all the inputs we need
self.state.input_needed = count < input_expected

@change(*[f"xai_param__{k}" for k in config.ALL_SALIENCY_PARAMS.keys()])
def on_saliency_param_update(self, **kwargs):
self.update_active_xai_algorithm()

@change("xai_viz_color_min", "xai_viz_color_max")
def xai_viz_color_range_change(
self, xai_viz_color_min, xai_viz_color_max, **kwargs
):
Expand Down
29 changes: 15 additions & 14 deletions xaitk_saliency_demo/app/xaitk_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,6 @@ def __init__(self, input, **kwargs):
)


def create_card_container(**kwargs):
with vuetify.VCard(**kwargs) as _card:
_header = vuetify.VCardTitle()
vuetify.VDivider()
_content = vuetify.VCardText()

return _card, _header, _content


# -----------------------------------------------------------------------------
# Section builders
#
Expand Down Expand Up @@ -173,23 +164,25 @@ def __init__(self, run=None):
# classes UI
_chart = plotly.Figure(
style="width: 100%; height: 100%;",
v_show=("task_active === 'classification'",),
v_show=("task_active === 'classification' && !input_needed",),
display_mode_bar=False,
)
ctrl.classification_chart_update = _chart.update

# similarity UI
vuetify.VProgressCircular(
"{{ Math.round(model_viz_similarity) }} %",
v_show=("task_active === 'similarity'",),
v_show=("task_active === 'similarity' && !input_needed",),
size=192,
width=15,
color="teal",
value=("model_viz_similarity", 0),
)

# object detection UI
with vuetify.VRow(v_show=("task_active === 'detection'",), align="center"):
with vuetify.VRow(
v_show=("task_active === 'detection' && !input_needed",), align="center"
):
trame.XaiImage(
classes="ma-2",
src=("input_1_img_url",),
Expand Down Expand Up @@ -415,12 +408,14 @@ def __init__(self):
)
with self:
vuetify.VSelect(
v_show=("xai_ready", False),
v_model=("xai_viz_classification_selected",),
items=("xai_viz_classification_selected_available", []),
**config.STYLE_COMPACT,
classes="mb-2",
)
trame.XaiImage(
v_show=("xai_ready", False),
v_if=("input_1_img_url",),
src=("input_1_img_url",),
max_height=400,
Expand All @@ -438,9 +433,12 @@ def __init__(self):

class XaiSimilarityResults(html.Div):
def __init__(self):
super().__init__(v_if=("xai_viz_type == 'similarity'",))
super().__init__(
v_if="xai_viz_type == 'similarity'",
)
with self:
trame.XaiImage(
v_show=("xai_ready", False),
v_if=("input_2_img_url",),
src=("input_1_img_url",),
max_height=400,
Expand All @@ -459,17 +457,20 @@ def __init__(self):
class XaiDetectionResults(html.Div):
def __init__(self):
super().__init__(
v_if=("xai_viz_type == 'detection'",), classes="d-flex flex-column"
v_if="xai_viz_type == 'detection'",
classes="d-flex flex-column",
)
with self:
vuetify.VSelect(
v_show=("xai_ready", False),
v_model=("xai_viz_detection_selected",),
items=("model_viz_detection_areas", []),
change="model_viz_detection_area_actives = [1 + Number(xai_viz_detection_selected.split('_')[1])]",
**config.STYLE_COMPACT,
classes="mb-2",
)
trame.XaiImage(
v_show=("xai_ready", False),
v_if=("input_1_img_url",),
src=("input_1_img_url",),
max_height=400,
Expand Down

0 comments on commit 90d7ba0

Please sign in to comment.