diff --git a/src/cryo_challenge/_commands/run_map2map_pipeline.py b/src/cryo_challenge/_commands/run_map2map_pipeline.py index 2e684d9..90db1aa 100644 --- a/src/cryo_challenge/_commands/run_map2map_pipeline.py +++ b/src/cryo_challenge/_commands/run_map2map_pipeline.py @@ -8,7 +8,6 @@ from .._map_to_map.map_to_map_pipeline import run from ..data._validation.config_validators import validate_input_config_mtm -from ..data._validation.config_validators import GWConfig def add_args(parser): @@ -37,11 +36,6 @@ def main(args): config = yaml.safe_load(file) validate_input_config_mtm(config) - if "gromov_wasserstein_extra_params" in config["analysis"]: - _ = GWConfig( - **config["analysis"]["gromov_wasserstein_extra_params"] - ).model_dump() - warnexists(config["output"]) mkbasedir(os.path.dirname(config["output"])) diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index 760b63e..63da1c9 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -146,6 +146,44 @@ def validate_config_mtm_analysis_normalize(config_analysis_normalize: dict) -> N return +class GWConfig(BaseModel): + """ + Gromov-Wasserstein extra parameters config + + Parameters + ---------- + top_k : int + Number of voxels to use (ranked according to highest mass) + n_downsample_pix : int + Number of pixels to downsample to (in each dimension) + exponent : float + exponential weighting of GW cost. 
Base cost is Euclidean distance + cost_scale_factor : float + multiplicative scaling factor for the cost (before exponentiation) + element_wise : bool + dask parallelization: whether to call dask.compute on each map-to-map computation, or naively loop through each (80) submitted maps row-wise + slurm : bool + parallelization configuration: whether to use dask_hpc_runner.SlurmRunner as a runner for dask.Client(runner) + scheduler : Optional[str] + string argument to dask.compute + local_directory : Optional[str] + directory for dask.distributed.Client + """ + + top_k: int + n_downsample_pix: int + exponent: float + cost_scale_factor: float + element_wise: bool + slurm: bool + scheduler: Optional[str] = None + local_directory: Optional[str] = None + + +def validate_config_mtm_analysis_gromov_wasserstein_extra_arams(config): + _ = GWConfig(**config).model_dump() + + def validate_config_mtm_analysis(config_analysis: dict) -> None: """ Validate the analysis part of the config dictionary for the MapToMap config. @@ -167,6 +205,15 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: validate_generic_config(config_analysis, keys_and_types) validate_config_mtm_analysis_normalize(config_analysis["normalize"]) + if "gromov_wasserstein_extra_params" in config_analysis: + validate_config_mtm_analysis_gromov_wasserstein_extra_arams( + config_analysis["gromov_wasserstein_extra_params"] + ) + else: + import warnings + warnings.warn( + "No Gromov-Wasserstein extra parameters found in config, so not validated." + ) return @@ -375,37 +422,3 @@ def check_voxel_size(cls, value): if value <= 0: raise ValueError("Voxel size must be positive.") return value - - -class GWConfig(BaseModel): - """ - Config for the GW distance computation - - Parameters - ---------- - top_k : int - Number of voxels to use (ranked according to highest mass) - n_downsample_pix : int - Number of pixels to downsample to (in each dimension) - exponent : float - exponential weighting of GW cost. 
Base cost is Euclindean distance - cost_scale_factor : float - multiplicative scaling factor for the cost (before exponentiation) - element_wise : bool - dask parralelization: whether to call dask.compute on each map-to-map computation, or naively loop through each (80) submitted maps row-wise - slurm : bool - parallelization configuration: whether to use dask_hpc_runner.SlurmRunner as a runner for dask.Client(runner) - scheduler : Optional[str] - string argument to dask.compute - local_directory : Optional[str] - directory for dask.distributed.Client - """ - - top_k: int - n_downsample_pix: int - exponent: float - cost_scale_factor: float - element_wise: bool - slurm: bool - scheduler: Optional[str] = None - local_directory: Optional[str] = None