Skip to content

Commit

Permalink
standard slurmrunner from dask
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jan 17, 2025
1 parent 526c549 commit f7257e2
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ dependencies = [
"ecos",
"dask",
"dask[distributed]",
"dask_hpc_runner @ git+https://github.com/jacobtomlinson/dask-hpc-runner.git@main",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask import delayed, compute
from dask.distributed import Client
from dask.diagnostics import ProgressBar
from dask_hpc_runner import SlurmRunner
from dask_jobqueue.slurm import SLURMRunner

from cryo_challenge._preprocessing.fourier_utils import downsample_volume

Expand Down Expand Up @@ -259,6 +259,9 @@ def parse_args():
parser.add_argument(
"--n_i", type=int, default=80, help="Number of volumes in set i"
)
parser.add_argument(
"--n_j", type=int, default=80, help="Number of volumes in set j"
)
parser.add_argument(
"--n_downsample_pix", type=int, default=20, help="Number of downsample pixels"
)
Expand Down Expand Up @@ -378,7 +381,7 @@ def main(args):
submission = torch.load(fname, weights_only=False)
volumes = submission["volumes"].to(torch_dtype)
volumes_i = volumes[: args.n_i]
volumes_j = volumes
volumes_j = volumes[: args.n_j]
n_downsample_pix = args.n_downsample_pix
top_k = args.top_k
exponent = args.exponent
Expand Down Expand Up @@ -411,7 +414,7 @@ def main(args):
args = parse_args()
if args.slurm:
job_id = os.environ["SLURM_JOB_ID"]
with SlurmRunner(
with SLURMRunner(
scheduler_file=args.scheduler_file,
) as runner:
# The runner object contains the scheduler address and can be passed directly to a client
Expand Down
22 changes: 22 additions & 0 deletions src/cryo_challenge/_map_to_map/gromov_wasserstein/submission.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash
#SBATCH --job-name=exponent
#SBATCH --output=slurm/logs/%j.out
#SBATCH --error=slurm/logs/%j.err
#SBATCH --partition=ccb
#SBATCH -n 40
#SBATCH --time=99:00:00


for COST_SCALE_FACTOR in 0.03 0.1 0.3 3.0 10.0 30.0 100.0
do
for N_DOWNSAMPLE_PIX in 20
do
for TOP_K in 500
do
for EXPONENT in 0.1 0.2 1.0 1.5 2.0
do
srun python /mnt/home/gwoollard/ceph/repos/Cryo-EM-Heterogeneity-Challenge-1/src/cryo_challenge/_map_to_map/gromov_wasserstein/gw_weighted_voxels.py --n_downsample_pix ${N_DOWNSAMPLE_PIX} --top_k ${TOP_K} --exponent $EXPONENT --slurm --cost_scale_factor ${COST_SCALE_FACTOR}
done
done
done
done
4 changes: 2 additions & 2 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import mrcfile
import numpy as np
from dask.distributed import Client
from dask_hpc_runner import SlurmRunner
from dask_jobqueue.slurm import SLURMRunner

from .gromov_wasserstein.gw_weighted_voxels import get_distance_matrix_dask_gw

Expand Down Expand Up @@ -503,7 +503,7 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
scheduler_file = os.path.join(
extra_params["scheduler_file_dir"], f"scheduler-{job_id}.json"
)
with SlurmRunner(
with SLURMRunner(
scheduler_file=scheduler_file,
) as runner:
# The runner object contains the scheduler address and can be passed directly to a client
Expand Down

0 comments on commit f7257e2

Please sign in to comment.