Skip to content

Commit

Permalink
add verbose option to async and adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Feb 23, 2023
1 parent bc1321e commit 9e228df
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
1 change: 1 addition & 0 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def test_adjoint_setup_fwd(use_emulated_run):
folder_name="default",
path="simulation_data.hdf5",
callback_url=None,
verbose=False,
)
sim_orig = sim_data_orig.simulation
sim_fwd = sim_data_fwd.simulation
Expand Down
22 changes: 20 additions & 2 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ def tidy3d_run_async_fn(simulations: Dict[str, Simulation], **kwargs) -> BatchDa
""" Running a single simulation using web.run. """


@partial(custom_vjp, nondiff_argnums=tuple(range(1, 5)))
# pylint:disable=too-many-arguments
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
def run(
simulation: JaxSimulation,
task_name: str,
folder_name: str = "default",
path: str = "simulation_data.hdf5",
callback_url: str = None,
verbose: bool = True,
) -> JaxSimulationData:
"""Submits a :class:`.JaxSimulation` to server, starts running, monitors progress, downloads,
and loads results as a :class:`.JaxSimulationData` object.
Expand All @@ -62,6 +64,8 @@ def run(
callback_url : str = None
Http PUT url to receive simulation finish event. The body content is a json file with
fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
verbose : bool = True
If `True`, will print progressbars and status, otherwise, will run silently.
Returns
-------
Expand All @@ -79,18 +83,21 @@ def run(
folder_name=folder_name,
path=path,
callback_url=callback_url,
verbose=verbose,
)

# convert back to jax type and return
return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)


# pylint:disable=too-many-arguments
def run_fwd(
simulation: JaxSimulation,
task_name: str,
folder_name: str,
path: str,
callback_url: str,
verbose: bool,
) -> Tuple[JaxSimulationData, tuple]:
"""Run forward pass and stash extra objects for the backwards pass."""

Expand All @@ -103,6 +110,7 @@ def run_fwd(
folder_name=folder_name,
path=path,
callback_url=callback_url,
verbose=verbose,
)

# remove the gradient data from the returned version (not needed)
Expand All @@ -116,6 +124,7 @@ def run_bwd(
folder_name: str,
path: str,
callback_url: str,
verbose: bool,
res: tuple,
sim_data_vjp: JaxSimulationData,
) -> Tuple[JaxSimulation]:
Expand All @@ -135,6 +144,7 @@ def run_bwd(
folder_name=folder_name,
path=path,
callback_url=callback_url,
verbose=verbose,
)
grad_data_adj = sim_data_adj.grad_data

Expand All @@ -159,12 +169,13 @@ def _task_name_orig(index: int, task_name_suffix: str = None):


# pylint:disable=too-many-locals
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 7)))
def run_async(
simulations: Tuple[JaxSimulation, ...],
folder_name: str = "default",
path_dir: str = DEFAULT_DATA_DIR,
callback_url: str = None,
verbose: bool = True,
num_workers: int = None,
task_name_suffix: str = None,
) -> Tuple[JaxSimulationData, ...]:
Expand All @@ -184,6 +195,8 @@ def run_async(
callback_url : str = None
Http PUT url to receive simulation finish event. The body content is a json file with
fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
verbose : bool = True
If `True`, will print progressbars and status, otherwise, will run silently.
num_workers: int = None
Number of tasks to submit at once in a batch, if None, will run all at the same time.
Expand Down Expand Up @@ -212,6 +225,7 @@ def run_async(
folder_name=folder_name,
path_dir=path_dir,
callback_url=callback_url,
verbose=verbose,
num_workers=num_workers,
)

Expand All @@ -234,6 +248,7 @@ def run_async_fwd(
folder_name: str,
path_dir: str,
callback_url: str,
verbose: bool,
num_workers: int,
task_name_suffix: str,
) -> Tuple[Dict[str, JaxSimulationData], tuple]:
Expand All @@ -252,6 +267,7 @@ def run_async_fwd(
folder_name=folder_name,
path_dir=path_dir,
callback_url=callback_url,
verbose=verbose,
num_workers=num_workers,
task_name_suffix=task_name_suffix_fwd,
)
Expand All @@ -271,6 +287,7 @@ def run_async_bwd(
folder_name: str,
path_dir: str,
callback_url: str,
verbose: bool,
num_workers: int,
task_name_suffix: str,
res: tuple,
Expand Down Expand Up @@ -303,6 +320,7 @@ def run_async_bwd(
folder_name=folder_name,
path_dir=path_dir,
callback_url=callback_url,
verbose=verbose,
num_workers=num_workers,
task_name_suffix=task_name_suffix_adj,
)
Expand Down
15 changes: 11 additions & 4 deletions tidy3d/web/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
nest_asyncio.apply()


# pylint:disable=too-many-arguments, too-many-locals
async def _run_async(
simulations: Dict[str, Simulation],
folder_name: str = "default",
path_dir: str = DEFAULT_DATA_DIR,
callback_url: str = None,
num_workers: int = None,
# verbose: bool = True,
verbose: bool = True,
) -> BatchData:
"""Submits a set of :class:`.Simulation` objects to server, starts running,
monitors progress, downloads, and loads results as a :class:`.BatchData` object.
Expand All @@ -36,6 +37,8 @@ async def _run_async(
fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
num_workers: int = None
Number of tasks to submit at once in a batch, if None, will run all at the same time.
verbose : bool = True
If `True`, will print progressbars and status, otherwise, will run silently.
Note
----
Expand Down Expand Up @@ -88,7 +91,7 @@ async def worker(queue):
task_name=task_name,
callback_url=callback_url,
folder_name=folder_name,
# verbose=verbose,
verbose=verbose,
)
job.start()

Expand Down Expand Up @@ -116,16 +119,17 @@ async def worker(queue):
await asyncio.gather(*tasks, return_exceptions=True)

# return the batch data containing all of the job run details for loading later
return BatchData(task_ids=task_ids, task_paths=task_paths)
return BatchData(task_ids=task_ids, task_paths=task_paths, verbose=verbose)


# pylint:disable=too-many-arguments
def run_async(
simulations: Dict[str, Simulation],
folder_name: str = "default",
path_dir: str = DEFAULT_DATA_DIR,
callback_url: str = None,
num_workers: int = None,
# verbose: bool = True,
verbose: bool = True,
) -> BatchData:
"""Submits a set of :class:`.Simulation` objects to server, starts running,
monitors progress, downloads, and loads results as a :class:`.BatchData` object.
Expand All @@ -144,6 +148,8 @@ def run_async(
fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
num_workers: int = None
Number of tasks to submit at once in a batch, if None, will run all at the same time.
verbose : bool = True
If `True`, will print progressbars and status, otherwise, will run silently.
Note
----
Expand All @@ -163,5 +169,6 @@ def run_async(
path_dir=path_dir,
callback_url=callback_url,
num_workers=num_workers,
verbose=verbose,
)
)

0 comments on commit 9e228df

Please sign in to comment.