From a2165c00dd73634828cae27b781d11b5828959d1 Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Tue, 7 May 2024 17:03:25 -0700 Subject: [PATCH 1/7] Add viewerplot element --- nerfstudio/viewer/viewer_elements.py | 100 ++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/nerfstudio/viewer/viewer_elements.py b/nerfstudio/viewer/viewer_elements.py index 654503a3c4..ea73a28abe 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -32,6 +32,7 @@ GuiButtonHandle, GuiDropdownHandle, GuiInputHandle, + GuiPlotlyHandle, ScenePointerEvent, ViserServer, ) @@ -43,6 +44,8 @@ if TYPE_CHECKING: from nerfstudio.viewer.viewer import Viewer +import plotly.graph_objects as go + TValue = TypeVar("TValue") TString = TypeVar("TString", default=str, bound=str) @@ -284,7 +287,9 @@ def __init__( cb_hook: Callable = lambda element: None, ) -> None: self.name = name - self.gui_handle: Optional[Union[GuiInputHandle[TValue], GuiButtonHandle, GuiButtonGroupHandle]] = None + self.gui_handle: Optional[ + Union[GuiInputHandle[TValue], GuiButtonHandle, GuiButtonGroupHandle, GuiPlotlyHandle] + ] = None self.disabled = disabled self.visible = visible self.cb_hook = cb_hook @@ -710,3 +715,96 @@ def _create_gui_handle(self, viser_server: ViserServer) -> None: self.gui_handle = viser_server.add_gui_vector3( self.name, self.default_value, step=self.step, disabled=self.disabled, visible=self.visible, hint=self.hint ) + + +class ViewerPlot(ViewerElement[go.Figure]): + """Base class for viewer figures, using plotly backend. + Includes misc wrapper methods for setting plotly figure properties. + """ + gui_handle: GuiPlotlyHandle + + _figure: go.Figure + """Figure to be displayed. Do not access this directly, exists only for initial statekeeping.""" + _aspect: float + """Aspect ratio of the plot (h/w). Default is 1.0.""" + + def __init__( + self, + figure: Optional[go.Figure] = None, + aspect: float = 1.0, + visible: bool = True, + ): + """ + Args: + - figure: The plotly figure to display -- if None, an empty figure is created. + - aspect: Aspect ratio of the plot (h/w). Default is 1.0. + - visible: If the plot is visible. + """ + self._figure = go.Figure() if figure is None else figure + self._aspect = aspect + super().__init__(name="", visible=visible) # plots have no name. + + def _create_gui_handle(self, viser_server: ViserServer) -> None: + self.gui_handle = viser_server.add_gui_plotly( + figure=self._figure, visible=self.visible, aspect=self._aspect + ) + + def install(self, viser_server: ViserServer) -> None: + self._create_gui_handle(viser_server) + assert self.gui_handle is not None + + @property + def figure(self): + assert self.gui_handle is not None + return self.gui_handle.figure + + @figure.setter + def figure(self, figure: go.Figure): + assert self.gui_handle is not None + self._figure = figure + self.gui_handle.figure = figure + + @property + def aspect(self): + return self._aspect + + @aspect.setter + def aspect(self, aspect: float): + self._aspect = aspect + if self.gui_handle is not None: + self.gui_handle.aspect = aspect + + @staticmethod + def set_margin(figure: go.Figure, margin: int = 0) -> None: + """Wrapper for setting the margin of a plotly figure.""" + # Set margins. + figure.update_layout( + margin=dict(l=margin, r=margin, t=margin, b=margin), + ) + + # Set automargin for title, so that title doesn't get cut off. + if margin == 0 and getattr(figure.layout, "title", None) is not None: + figure.layout.title.automargin = True # type: ignore + + @staticmethod + def set_dark(figure: go.Figure, dark: bool) -> None: + """Wrapper for setting the dark mode of a plotly figure.""" + if dark: + figure.update_layout(template="plotly_dark") + else: + figure.update_layout(template="plotly") + + @staticmethod + def plot_line(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blue") -> go.Scatter: + """Wrapper for plotting a line in a plotly figure.""" + return go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color)) + + @staticmethod + def plot_scatter(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blue") -> go.Scatter: + """Wrapper for plotting a scatter in a plotly figure.""" + return go.Scatter(x=x, y=y, mode="markers", name=name, marker=dict(color=color)) + + @staticmethod + def plot_image(image: np.ndarray, name: str = "") -> go.Image: + """Wrapper for plotting an image in a plotly figure.""" + return go.Image(z=image, name=name) From ca04cbc0d9e44a2dba41fe18f23d5d428071cb43 Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 9 May 2024 10:52:25 -0700 Subject: [PATCH 2/7] Set darkmode + margin as ViewerPlot properties --- nerfstudio/viewer/viewer_elements.py | 65 ++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 19 deletions(-) diff --git a/nerfstudio/viewer/viewer_elements.py b/nerfstudio/viewer/viewer_elements.py index ea73a28abe..87d82771cb 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -727,11 +727,17 @@ class ViewerPlot(ViewerElement[go.Figure]): """Figure to be displayed. Do not access this directly, exists only for initial statekeeping.""" _aspect: float """Aspect ratio of the plot (h/w). Default is 1.0.""" + _dark_mode: bool + """If the plot is in dark mode (i.e., `plotly_dark` template). Default is True. Uses `plotly` template for light mode.""" + _margin: int + """Margin of the plot. Default is 0.""" def __init__( self, figure: Optional[go.Figure] = None, aspect: float = 1.0, + margin: int = 0, + dark_mode: bool = True, visible: bool = True, ): """ @@ -739,9 +745,12 @@ def __init__( - figure: The plotly figure to display -- if None, an empty figure is created. - aspect: Aspect ratio of the plot (h/w). Default is 1.0. - visible: If the plot is visible. + - margin: Margin of the plot. Default is 0. """ self._figure = go.Figure() if figure is None else figure self._aspect = aspect + self._margin = margin + self._dark_mode = dark_mode super().__init__(name="", visible=visible) # plots have no name. def _create_gui_handle(self, viser_server: ViserServer) -> None: @@ -762,7 +771,7 @@ def figure(self): def figure(self, figure: go.Figure): assert self.gui_handle is not None self._figure = figure - self.gui_handle.figure = figure + self._update_plot() @property def aspect(self): @@ -771,28 +780,46 @@ def aspect(self): @aspect.setter def aspect(self, aspect: float): self._aspect = aspect - if self.gui_handle is not None: - self.gui_handle.aspect = aspect + self._update_plot() - @staticmethod - def set_margin(figure: go.Figure, margin: int = 0) -> None: - """Wrapper for setting the margin of a plotly figure.""" - # Set margins. - figure.update_layout( - margin=dict(l=margin, r=margin, t=margin, b=margin), + @property + def dark(self): + return self._dark_mode + + @dark.setter + def dark(self, dark_mode: bool): + self._dark_mode = dark_mode + self._update_plot() + + @property + def margin(self): + return self._margin + + @margin.setter + def margin(self, margin: int): + self._margin = margin + self._update_plot() + + def _update_plot(self) -> None: + """Refresh the plot with: + - the current figure + - aspect ratio + - dark mode + """ + template = "plotly_dark" if self._dark_mode else "plotly" + self._figure.update_layout(template=template) + + # Set margins. Also, set automargin for title, so that title doesn't get cut off. + self._figure.update_layout( + margin=dict(l=self._margin, r=self._margin, t=self._margin, b=self._margin), ) + if self._margin == 0 and getattr(self._figure.layout, "title", None) is not None: + self._figure.layout.title.automargin = True # type: ignore - # Set automargin for title, so that title doesn't get cut off. - if margin == 0 and getattr(figure.layout, "title", None) is not None: - figure.layout.title.automargin = True # type: ignore + if self.gui_handle is not None: + self.gui_handle.aspect = self._aspect + self.gui_handle.figure = self._figure - @staticmethod - def set_dark(figure: go.Figure, dark: bool) -> None: - """Wrapper for setting the dark mode of a plotly figure.""" - if dark: - figure.update_layout(template="plotly_dark") - else: - figure.update_layout(template="plotly") @staticmethod def plot_line(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blue") -> go.Scatter: From 5efb685e54a278b0c5a06a917cb4c23b22c13b8b Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 9 May 2024 10:59:36 -0700 Subject: [PATCH 3/7] Lint --- nerfstudio/viewer/viewer_elements.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nerfstudio/viewer/viewer_elements.py b/nerfstudio/viewer/viewer_elements.py index 87d82771cb..894ab3f87f 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -721,6 +721,7 @@ class ViewerPlot(ViewerElement[go.Figure]): """Base class for viewer figures, using plotly backend. Includes misc wrapper methods for setting plotly figure properties. """ + gui_handle: GuiPlotlyHandle _figure: go.Figure @@ -754,9 +755,7 @@ def __init__( super().__init__(name="", visible=visible) # plots have no name. def _create_gui_handle(self, viser_server: ViserServer) -> None: - self.gui_handle = viser_server.add_gui_plotly( - figure=self._figure, visible=self.visible, aspect=self._aspect - ) + self.gui_handle = viser_server.add_gui_plotly(figure=self._figure, visible=self.visible, aspect=self._aspect) def install(self, viser_server: ViserServer) -> None: self._create_gui_handle(viser_server) @@ -820,7 +819,6 @@ def _update_plot(self) -> None: self.gui_handle.aspect = self._aspect self.gui_handle.figure = self._figure - @staticmethod def plot_line(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blue") -> go.Scatter: """Wrapper for plotting a line in a plotly figure.""" From 3607761db394083be45ae5b50049b7908d609ccd Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 9 May 2024 11:30:01 -0700 Subject: [PATCH 4/7] bugfix --- nerfstudio/viewer/viewer_elements.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/nerfstudio/viewer/viewer_elements.py b/nerfstudio/viewer/viewer_elements.py index 894ab3f87f..6140f4f4d4 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -44,6 +44,8 @@ if TYPE_CHECKING: from nerfstudio.viewer.viewer import Viewer +import plotly +import plotly.basedatatypes import plotly.graph_objects as go TValue = TypeVar("TValue") @@ -767,7 +769,9 @@ def figure(self): return self.gui_handle.figure @figure.setter - def figure(self, figure: go.Figure): + def figure(self, figure: Union[go.Figure, plotly.basedatatypes.BaseTraceType]): + if isinstance(figure, plotly.basedatatypes.BaseTraceType): + figure = go.Figure(data=[figure]) assert self.gui_handle is not None self._figure = figure self._update_plot() @@ -812,7 +816,7 @@ def _update_plot(self) -> None: self._figure.update_layout( margin=dict(l=self._margin, r=self._margin, t=self._margin, b=self._margin), ) - if self._margin == 0 and getattr(self._figure.layout, "title", None) is not None: + if self._margin == 0 and self._figure.layout.title.text is not None: # type: ignore self._figure.layout.title.automargin = True # type: ignore if self.gui_handle is not None: @@ -831,5 +835,9 @@ def plot_scatter(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blu @staticmethod def plot_image(image: np.ndarray, name: str = "") -> go.Image: - """Wrapper for plotting an image in a plotly figure.""" + """Wrapper for plotting an image in a plotly figure. + `plotly.graph_object.Image` expects [0...255], so images [0...1] is automatically scaled here. + """ + if image.dtype != np.uint8: + image = (image * 255).astype(np.uint8) return go.Image(z=image, name=name) From 4d6cf66a63991d64facc0d9e742e98f6b9bb8eba Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 9 May 2024 13:31:27 -0700 Subject: [PATCH 5/7] Bump viser version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 95b980863d..cd8ebd5345 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "torchvision>=0.14.1", "torchmetrics[image]>=1.0.1", "typing_extensions>=4.4.0", - "viser==0.1.27", + "viser==0.1.29", "nuscenes-devkit>=1.1.1", "wandb>=0.13.3", "xatlas", From 770d843d9b332dd5c53c77883f692cdd0f31366d Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 24 Jun 2024 16:33:16 -0700 Subject: [PATCH 6/7] wip --- nerfstudio/viewer/metrics_panel.py | 31 ++++++++++++++++++++++++++++ nerfstudio/viewer/viewer_elements.py | 8 +++---- 2 files changed, 35 insertions(+), 4 deletions(-) create mode 100644 nerfstudio/viewer/metrics_panel.py diff --git a/nerfstudio/viewer/metrics_panel.py b/nerfstudio/viewer/metrics_panel.py new file mode 100644 index 0000000000..40a90b12ea --- /dev/null +++ b/nerfstudio/viewer/metrics_panel.py @@ -0,0 +1,31 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path + +from typing import Literal + +import viser + +from nerfstudio.viewer.control_panel import ControlPanel + + +def populate_metrics_tab( + server: viser.ViserServer, + control_panel: ControlPanel, + config_path: Path, +) -> None: + pass \ No newline at end of file diff --git a/nerfstudio/viewer/viewer_elements.py b/nerfstudio/viewer/viewer_elements.py index 6140f4f4d4..d71f6649ee 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -745,10 +745,10 @@ def __init__( ): """ Args: - - figure: The plotly figure to display -- if None, an empty figure is created. - - aspect: Aspect ratio of the plot (h/w). Default is 1.0. - - visible: If the plot is visible. - - margin: Margin of the plot. Default is 0. + figure: The plotly figure to display -- if None, an empty figure is created. + aspect: Aspect ratio of the plot (h/w). Default is 1.0. + visible: If the plot is visible. + margin: Margin of the plot. Default is 0. """ self._figure = go.Figure() if figure is None else figure self._aspect = aspect From 9352d88f17384e344a5b398cad32abcfb771a79d Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 27 Jun 2024 16:52:57 -0700 Subject: [PATCH 7/7] wip --- nerfstudio/viewer/metrics_panel.py | 35 ++++++++++++++++++++++++++++-- nerfstudio/viewer/viewer.py | 5 +++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/nerfstudio/viewer/metrics_panel.py b/nerfstudio/viewer/metrics_panel.py index 40a90b12ea..387e1c0873 100644 --- a/nerfstudio/viewer/metrics_panel.py +++ b/nerfstudio/viewer/metrics_panel.py @@ -14,18 +14,49 @@ from __future__ import annotations +import dataclasses from pathlib import Path - from typing import Literal import viser +from nerfstudio.models.base_model import Model +from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.viewer.control_panel import ControlPanel +from nerfstudio.viewer.viewer_elements import ViewerPlot +@dataclasses.dataclass +class MetricsPanel: + server: viser.ViserServer + control_panel: ControlPanel + config_path: Path + viewer_model: Model def populate_metrics_tab( server: viser.ViserServer, control_panel: ControlPanel, config_path: Path, + viewer_model: Model, +) -> None: + viewing_gsplat = isinstance(viewer_model, SplatfactoModel) + + with server.add_gui_folder("Training Metrics"): + populate_train_metrics_tab(server, control_panel, config_path, viewing_gsplat) + # with server.add_gui_folder("Training Loss"): + # populate_train_loss_tab(server, control_panel, config_path, viewing_gsplat) + #training rays? + # with server.add_gui_folder("Eval Metrics"): + # populate_eval_metrics_tab(server, control_panel, config_path, viewing_gsplat) + # with server.add_gui_folder("Eval Metrics (All Images)"): + # populate_eval_metrics_all_images_tab(server, control_panel, config_path, viewing_gsplat) + # with server.add_gui_folder("Eval Loss"): + # populate_eval_loss_tab(server, control_panel, config_path, viewing_gsplat) + + +def populate_train_metrics_tab( + server: viser.ViserServer, + control_panel: ControlPanel, + config_path: Path, + viewing_gsplat: bool, ) -> None: - pass \ No newline at end of file + ViewerPlot() \ No newline at end of file diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index a5093f0dad..101d72fcef 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -40,6 +40,7 @@ from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName from nerfstudio.viewer.control_panel import ControlPanel from nerfstudio.viewer.export_panel import populate_export_tab +from nerfstudio.viewer.metrics_panel import populate_metrics_tab from nerfstudio.viewer.render_panel import populate_render_tab from nerfstudio.viewer.render_state_machine import RenderAction, RenderStateMachine from nerfstudio.viewer.utils import CameraState, parse_object @@ -199,6 +200,7 @@ def __init__( self._output_split_type_change, default_composite_depth=self.config.default_composite_depth, ) + config_path = self.log_filename.parents[0] / "config.yml" with tabs.add_tab("Render", viser.Icon.CAMERA): self.render_tab_state = populate_render_tab( @@ -208,6 +210,9 @@ def __init__( with tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT): populate_export_tab(self.viser_server, self.control_panel, config_path, self.pipeline.model) + with tabs.add_tab("Metrics", viser.Icon.GRAPH): + populate_metrics_tab(self.viser_server, self.control_panel, config_path, self.pipeline.model) + # Keep track of the pointers to generated GUI folders, because each generated folder holds a unique ID. viewer_gui_folders = dict()