Skip to content

Commit

Permalink
validation of gw happens inside validate_input_config_mtm
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jan 17, 2025
1 parent c55750f commit 4baba98
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 40 deletions.
6 changes: 0 additions & 6 deletions src/cryo_challenge/_commands/run_map2map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from .._map_to_map.map_to_map_pipeline import run
from ..data._validation.config_validators import validate_input_config_mtm
from ..data._validation.config_validators import GWConfig


def add_args(parser):
Expand Down Expand Up @@ -37,11 +36,6 @@ def main(args):
config = yaml.safe_load(file)

validate_input_config_mtm(config)
if "gromov_wasserstein_extra_params" in config["analysis"]:
_ = GWConfig(
**config["analysis"]["gromov_wasserstein_extra_params"]
).model_dump()

warnexists(config["output"])
mkbasedir(os.path.dirname(config["output"]))

Expand Down
80 changes: 46 additions & 34 deletions src/cryo_challenge/data/_validation/config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,44 @@ def validate_config_mtm_analysis_normalize(config_analysis_normalize: dict) -> N
return


class GWConfig(BaseModel):
"""
Gromov-Wasserstein extra parameters config
Parameters
----------
top_k : int
Number of voxels to use (ranked according to highest mass)
n_downsample_pix : int
Number of pixels to downsample to (in each dimension)
exponent : float
exponential weighting of GW cost. Base cost is Euclindean distance
cost_scale_factor : float
multiplicative scaling factor for the cost (before exponentiation)
element_wise : bool
dask parralelization: whether to call dask.compute on each map-to-map computation, or naively loop through each (80) submitted maps row-wise
slurm : bool
parallelization configuration: whether to use dask_hpc_runner.SlurmRunner as a runner for dask.Client(runner)
scheduler : Optional[str]
string argument to dask.compute
local_directory : Optional[str]
directory for dask.distributed.Client
"""

top_k: int
n_downsample_pix: int
exponent: float
cost_scale_factor: float
element_wise: bool
slurm: bool
scheduler: Optional[str] = None
local_directory: Optional[str] = None


def validate_config_mtm_analysis_gromov_wasserstein_extra_arams(config):
_ = GWConfig(**config).model_dump()


def validate_config_mtm_analysis(config_analysis: dict) -> None:
"""
Validate the analysis part of the config dictionary for the MapToMap config.
Expand All @@ -167,6 +205,14 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None:

validate_generic_config(config_analysis, keys_and_types)
validate_config_mtm_analysis_normalize(config_analysis["normalize"])
if "gromov_wasserstein_extra_params" in config_analysis:
validate_config_mtm_analysis_gromov_wasserstein_extra_arams(
config_analysis["gromov_wasserstein_extra_params"]
)
else:
Warning(
"No Gromov-Wasserstein extra parameters found in config, so not validated."
)
return


Expand Down Expand Up @@ -375,37 +421,3 @@ def check_voxel_size(cls, value):
if value <= 0:
raise ValueError("Voxel size must be positive.")
return value


class GWConfig(BaseModel):
"""
Config for the GW distance computation
Parameters
----------
top_k : int
Number of voxels to use (ranked according to highest mass)
n_downsample_pix : int
Number of pixels to downsample to (in each dimension)
exponent : float
exponential weighting of GW cost. Base cost is Euclindean distance
cost_scale_factor : float
multiplicative scaling factor for the cost (before exponentiation)
element_wise : bool
dask parralelization: whether to call dask.compute on each map-to-map computation, or naively loop through each (80) submitted maps row-wise
slurm : bool
parallelization configuration: whether to use dask_hpc_runner.SlurmRunner as a runner for dask.Client(runner)
scheduler : Optional[str]
string argument to dask.compute
local_directory : Optional[str]
directory for dask.distributed.Client
"""

top_k: int
n_downsample_pix: int
exponent: float
cost_scale_factor: float
element_wise: bool
slurm: bool
scheduler: Optional[str] = None
local_directory: Optional[str] = None

0 comments on commit 4baba98

Please sign in to comment.