Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/visualiser #993

Merged
merged 6 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def __setstate__(self, state):
self.__dict__.update(state)

def __getattr__(self, item):
return getattr(self.prior_model, item)
try:
return super().__getattr__(item)
except AttributeError:
return getattr(self.prior_model, item)

def name_for_variable(self, variable):
path = ".".join(self.prior_model.path_for_prior(variable))
Expand Down
38 changes: 21 additions & 17 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from autofit.non_linear.samples.pdf import SamplesPDF
from autofit.non_linear.result import Result
from autofit.non_linear.samples.samples import Samples
from .visualize import Visualizer

logger = logging.getLogger(__name__)

Expand All @@ -26,6 +27,25 @@ class Analysis(ABC):
"""

Result = Result
Visualizer = Visualizer

def __getattr__(self, item: str):
"""
If a method starts with 'visualize_' then we assume it is associated with
the Visualizer and forward the call to the visualizer.

It may be desirable to remove this behaviour as the visualizer component of
the system becomes more sophisticated.
"""
if item.startswith("visualize"):
_method = getattr(Visualizer, item)
else:
raise AttributeError(f"Analysis has no attribute {item}")

def method(*args, **kwargs):
return _method(self, *args, **kwargs)

return method

def compute_all_latent_variables(
self, samples: Samples
Expand Down Expand Up @@ -136,22 +156,6 @@ def should_visualize(
def log_likelihood_function(self, instance):
raise NotImplementedError()

def visualize_before_fit(self, paths: AbstractPaths, model: AbstractPriorModel):
pass

def visualize(self, paths: AbstractPaths, instance, during_analysis):
pass

def visualize_before_fit_combined(
self, analyses, paths: AbstractPaths, model: AbstractPriorModel
):
pass

def visualize_combined(
self, analyses, paths: AbstractPaths, instance, during_analysis
):
pass

def save_attributes(self, paths: AbstractPaths):
pass

Expand Down Expand Up @@ -238,7 +242,7 @@ def make_result(
paths=paths,
samples=samples,
search_internal=search_internal,
analysis=None
analysis=None,
)

def profile_log_likelihood_function(self, paths: AbstractPaths, instance):
Expand Down
45 changes: 45 additions & 0 deletions autofit/non_linear/analysis/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.mapper.prior_model.abstract import AbstractPriorModel


class Visualizer:
"""
Methods associated with visualising analysis, model and data before, during
or after an optimisation.
"""

@staticmethod
def visualize_before_fit(
analysis,
paths: AbstractPaths,
model: AbstractPriorModel,
):
pass

@staticmethod
def visualize(
analysis,
paths: AbstractPaths,
instance,
during_analysis,
):
pass

@staticmethod
def visualize_before_fit_combined(
analysis,
analyses,
paths: AbstractPaths,
model: AbstractPriorModel,
):
pass

@staticmethod
def visualize_combined(
analysis,
analyses,
paths: AbstractPaths,
instance,
during_analysis,
):
pass
30 changes: 13 additions & 17 deletions autofit/non_linear/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def start_resume_fit(self, analysis: Analysis, model: AbstractPriorModel) -> Res
samples_summary=samples.summary(),
paths=self.paths,
samples=samples,
search_internal=search_internal
search_internal=search_internal,
)

if self.is_master:
Expand Down Expand Up @@ -730,10 +730,7 @@ def result_via_completed_fit(
model.freeze()
samples_summary = self.paths.load_samples_summary()

result = analysis.make_result(
samples_summary=samples_summary,
paths=self.paths
)
result = analysis.make_result(samples_summary=samples_summary, paths=self.paths)

if self.is_master:
self.logger.info(f"Fit Already Completed: skipping non-linear search.")
Expand Down Expand Up @@ -912,19 +909,18 @@ def perform_update(
instance = samples_summary.instance
except exc.FitException:
return samples

if self.is_master:

if self.is_master:
self.paths.save_samples_summary(samples_summary=samples_summary)

samples = samples.samples_above_weight_threshold_from(log_message=not during_analysis)
samples = samples.samples_above_weight_threshold_from(
log_message=not during_analysis
)
self.paths.save_samples(samples=samples)

if (
(during_analysis and conf.instance["output"]["latent_during_fit"]) or
(not during_analysis and conf.instance["output"]["latent_after_fit"])
if (during_analysis and conf.instance["output"]["latent_during_fit"]) or (
not during_analysis and conf.instance["output"]["latent_after_fit"]
):

latent_variables = analysis.compute_all_latent_variables(samples)

if latent_variables:
Expand Down Expand Up @@ -979,7 +975,7 @@ def perform_visualization(
self,
model: AbstractPriorModel,
analysis: AbstractPriorModel,
samples_summary : SamplesSummary,
samples_summary: SamplesSummary,
during_analysis: bool,
search_internal=None,
):
Expand Down Expand Up @@ -1014,7 +1010,7 @@ def perform_visualization(
analysis.visualize(
paths=self.paths,
instance=samples_summary.instance,
during_analysis=during_analysis
during_analysis=during_analysis,
)
analysis.visualize_combined(
analyses=None,
Expand All @@ -1024,10 +1020,10 @@ def perform_visualization(
)

if analysis.should_visualize(paths=self.paths, during_analysis=during_analysis):

if not isinstance(self.paths, NullPaths):

samples = self.samples_from(model=model, search_internal=search_internal)
samples = self.samples_from(
model=model, search_internal=search_internal
)

self.plot_results(samples=samples)

Expand Down
Loading