Skip to content

Commit

Permalink
Refactoring for package (#72)
Browse files Browse the repository at this point in the history
* more refactoring for packaging

* refactor for package

* fixing vfp test

* vfp test

* update badge
  • Loading branch information
joglekara authored Sep 15, 2024
1 parent 603dccd commit 4dad107
Show file tree
Hide file tree
Showing 27 changed files with 457 additions and 544 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/cpu-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
push:
branches:
- main
- feature/package

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
Expand All @@ -34,4 +35,4 @@ jobs:
- name: Test with pytest
run: |
CPU_ONLY=True pytest tests/test_lpse2d tests/test_vlasov1d
CPU_ONLY=True pytest tests/test_lpse2d tests/test_vlasov1d tests/test_vfp1d
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ADEPT

![Docs](https://readthedocs.org/projects/adept/badge/?version=latest)
![Tests](https://github.com/ergodicio/adept/actions/workflows/test.yaml/badge.svg)
![Tests](https://github.com/ergodicio/adept/actions/workflows/cpu-tests.yaml/badge.svg)

![ADEPT](./docs/source/adept-logo.png)

Expand Down
410 changes: 2 additions & 408 deletions adept/__init__.py

Large diffs are not rendered by default.

409 changes: 409 additions & 0 deletions adept/_base_.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions adept/_lpse2d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .modules import BaseLPSE2D as BaseLPSE2D, save_driver as save_driver
from .helpers import calc_threshold_intensity as calc_threshold_intensity
4 changes: 1 addition & 3 deletions adept/_lpse2d/core/driver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Dict
import jax
import equinox as eqx
from jax import numpy as jnp
from adept import get_envelope
from adept._base_ import get_envelope


class Driver:
Expand Down
2 changes: 1 addition & 1 deletion adept/_lpse2d/core/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import numpy as jnp
import numpy as np

from adept import get_envelope
from adept._base_ import get_envelope
from adept._lpse2d.core import epw, laser


Expand Down
2 changes: 1 addition & 1 deletion adept/_lpse2d/core/vector_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import numpy as jnp, Array
import numpy as np

from adept import get_envelope
from adept._base_ import get_envelope
from adept._lpse2d.core import epw, laser


Expand Down
4 changes: 1 addition & 3 deletions adept/_lpse2d/helpers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import os
from typing import Dict, Tuple
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
from jax import Array, numpy as jnp
import numpy as np
import equinox as eqx
import xarray as xr
import interpax

from astropy.units import Quantity as _Q

from adept import get_envelope
from adept._base_ import get_envelope


def write_units(cfg: Dict) -> Dict:
Expand Down
2 changes: 2 additions & 0 deletions adept/_lpse2d/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base import BaseLPSE2D as BaseLPSE2D
from .driver import save as save_driver
3 changes: 2 additions & 1 deletion adept/_lpse2d/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from diffrax import diffeqsolve, SaveAt, ODETerm, SubSaveAt
from equinox import filter_jit

from adept import ADEPTModule, Stepper
from adept import ADEPTModule
from adept._base_ import Stepper
from adept._lpse2d.helpers import (
write_units,
post_process,
Expand Down
15 changes: 8 additions & 7 deletions adept/_vlasov1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@
import numpy as np
import xarray, mlflow, pint
from jax import numpy as jnp
from scipy.special import gamma
from diffrax import Solution
from matplotlib import pyplot as plt

from adept import get_envelope
from adept._base_ import get_envelope
from adept._vlasov1d.storage import store_f, store_fields

gamma_da = xarray.open_dataarray(os.path.join(os.path.dirname(__file__), "gamma_func_for_sg.nc"))
m_ax = gamma_da.coords["m"].data
g_3_m = np.squeeze(gamma_da.loc[{"gamma": "3/m"}].data)
g_5_m = np.squeeze(gamma_da.loc[{"gamma": "5/m"}].data)
# gamma_da = xarray.open_dataarray(os.path.join(os.path.dirname(__file__), "gamma_func_for_sg.nc"))
# m_ax = gamma_da.coords["m"].data
# g_3_m = np.squeeze(gamma_da.loc[{"gamma": "3/m"}].data)
# g_5_m = np.squeeze(gamma_da.loc[{"gamma": "5/m"}].data)


def gamma_3_over_m(m):
return np.interp(m, m_ax, g_3_m)
return gamma(3.0 / m) # np.interp(m, m_ax, g_3_m)


def gamma_5_over_m(m):
return np.interp(m, m_ax, g_5_m)
return gamma(5.0 / m) # np.interp(m, m_ax, g_5_m)


def _initialize_distribution_(
Expand Down
3 changes: 2 additions & 1 deletion adept/_vlasov1d/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from jax import numpy as jnp
from diffrax import ODETerm, SubSaveAt, diffeqsolve, SaveAt

from adept import Stepper, ADEPTModule
from adept import ADEPTModule
from adept._base_ import Stepper
from adept._vlasov1d.storage import get_save_quantities
from adept._vlasov1d.helpers import _initialize_total_distribution_, post_process
from adept._vlasov1d.solvers.vector_field import VlasovMaxwell
Expand Down
2 changes: 1 addition & 1 deletion adept/_vlasov1d/solvers/pushers/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict
from jax import numpy as jnp

from adept import get_envelope
from adept._base_ import get_envelope


class Driver:
Expand Down
2 changes: 1 addition & 1 deletion adept/_vlasov1d/solvers/vector_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from jax import numpy as jnp, Array

from adept import get_envelope
from adept._base_ import get_envelope
from adept._vlasov1d.solvers.pushers import field, fokker_planck, vlasov


Expand Down
4 changes: 1 addition & 3 deletions adept/lpse2d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from adept._lpse2d.modules.base import BaseLPSE2D
from adept._lpse2d.helpers import calc_threshold_intensity
from adept._lpse2d.modules.driver import save as save_driver
from ._lpse2d import BaseLPSE2D, save_driver, calc_threshold_intensity
2 changes: 1 addition & 1 deletion adept/tf1d/solvers/pushers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import equinox as eqx

from adept.electrostatic import get_complex_frequency_table
from adept import get_envelope
from adept._base_ import get_envelope


class WaveSolver(eqx.Module):
Expand Down
88 changes: 0 additions & 88 deletions adept/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,91 +175,3 @@ def export_run(run_id, prefix="individual", step=0):
t0 = time.time()
upload_dir_to_s3(td2, "remote-mlflow-staging", f"artifacts/{run_id}", run_id, prefix, step)
# print(f"Uploading took {round(time.time() - t0, 2)} s")


def setup_parsl(parsl_provider="local", num_gpus=4, nodes=1):
import parsl
from parsl.config import Config
from parsl.providers import SlurmProvider, LocalProvider
from parsl.launchers import SrunLauncher
from parsl.executors import HighThroughputExecutor

if parsl_provider == "local":

if nodes == 1:
this_provider = LocalProvider
provider_args = dict(
worker_init="source /pscratch/sd/a/archis/venvs/adept-gpu/bin/activate; \
module load cudnn/8.9.3_cuda12.lua; \
export PYTHONPATH='$PYTHONPATH:/global/homes/a/archis/adept/'; \
export BASE_TEMPDIR='/pscratch/sd/a/archis/tmp/'; \
export MLFLOW_TRACKING_URI='/pscratch/sd/a/archis/mlflow'; \
export MLFLOW_EXPORT=True",
init_blocks=1,
max_blocks=1,
nodes_per_block=1,
)
htex = HighThroughputExecutor(
available_accelerators=num_gpus,
label="tpd",
provider=this_provider(**provider_args),
cpu_affinity="block",
)
print(f"{htex.workers_per_node=}")
else:
this_provider = LocalProvider
provider_args = dict(
worker_init="source /pscratch/sd/a/archis/venvs/adept-gpu/bin/activate; \
module load cudnn/8.9.3_cuda12.lua; \
export PYTHONPATH='$PYTHONPATH:/global/homes/a/archis/adept/'; \
export BASE_TEMPDIR='/pscratch/sd/a/archis/tmp/'; \
export MLFLOW_TRACKING_URI='/pscratch/sd/a/archis/mlflow'; \
export MLFLOW_EXPORT=True",
nodes_per_block=nodes,
launcher=SrunLauncher(overrides="-c 32 --gpus-per-node 4"),
cmd_timeout=120,
init_blocks=1,
max_blocks=1,
)

htex = HighThroughputExecutor(
available_accelerators=num_gpus * nodes,
label="tpd",
provider=this_provider(**provider_args),
max_workers_per_node=4,
cpu_affinity="block",
)
print(f"{htex.workers_per_node=}")

elif parsl_provider == "gpu":

this_provider = SlurmProvider
sched_args = ["#SBATCH -C gpu", "#SBATCH --qos=regular"]
provider_args = dict(
partition=None,
account="m4490_g",
scheduler_options="\n".join(sched_args),
worker_init="export SLURM_CPU_BIND='cores';\
source /pscratch/sd/a/archis/venvs/adept-gpu/bin/activate; \
module load cudnn/8.9.3_cuda12.lua; \
export PYTHONPATH='$PYTHONPATH:/global/homes/a/archis/adept/'; \
export BASE_TEMPDIR='/pscratch/sd/a/archis/tmp/'; \
export MLFLOW_TRACKING_URI='/pscratch/sd/a/archis/mlflow';\
export MLFLOW_EXPORT=True",
launcher=SrunLauncher(overrides="--gpus-per-node 4 -c 128"),
walltime="12:00:00",
cmd_timeout=120,
nodes_per_block=1,
# init_blocks=1,
max_blocks=nodes,
)

htex = HighThroughputExecutor(
available_accelerators=4, label="tpd-learn", provider=this_provider(**provider_args), cpu_affinity="block"
)
print(f"{htex.workers_per_node=}")

config = Config(executors=[htex], retries=4)

# load the Parsl config
parsl.load(config)
2 changes: 1 addition & 1 deletion adept/vfp1d/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diffrax import diffeqsolve, SaveAt, ODETerm, SubSaveAt
from jax import numpy as jnp, tree_util as jtu

from adept import ADEPTModule, Stepper
from adept._base_ import ADEPTModule, Stepper
from adept.vfp1d.vector_field import OSHUN1D
from adept.vfp1d.helpers import _initialize_total_distribution_, calc_logLambda
from adept.vfp1d.storage import get_save_quantities, post_process
Expand Down
15 changes: 4 additions & 11 deletions adept/vfp1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,13 @@
# [email protected]

from typing import Dict, Tuple
import os

import numpy as np
from jax import Array
import xarray, yaml
from astropy import units as u, constants as csts
from scipy.special import gamma
from astropy.units import Quantity as _Q
from jax import numpy as jnp
from adept import get_envelope

gamma_da = xarray.open_dataarray(os.path.join(os.path.dirname(__file__), "..", "vlasov1d", "gamma_func_for_sg.nc"))
m_ax = gamma_da.coords["m"].data
g_3_m = np.squeeze(gamma_da.loc[{"gamma": "3/m"}].data)
g_5_m = np.squeeze(gamma_da.loc[{"gamma": "5/m"}].data)
from adept._base_ import get_envelope


def gamma_3_over_m(m: float) -> Array:
Expand All @@ -26,7 +19,7 @@ def gamma_3_over_m(m: float) -> Array:
:return: Array
"""
return np.interp(m, m_ax, g_3_m)
return gamma(3.0 / m) # np.interp(m, m_ax, g_3_m)


def gamma_5_over_m(m: float) -> Array:
Expand All @@ -36,7 +29,7 @@ def gamma_5_over_m(m: float) -> Array:
:param m: float between 2 and 5
:return: Array
"""
return np.interp(m, m_ax, g_5_m)
return gamma(5.0 / m) # np.interp(m, m_ax, g_5_m)


def calc_logLambda(cfg: Dict, ne: float, Te: float, Z: int, ion_species: str) -> Tuple[float, float]:
Expand Down
5 changes: 0 additions & 5 deletions tests/test_lpse2d/test_tpd_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
from adept.lpse2d import calc_threshold_intensity

import numpy as np
from adept.utils import setup_parsl

# from parsl.app.app import python_app


# @python_app
def run_once(L, Te, I0):
import yaml
from adept import ergoExo
Expand Down Expand Up @@ -36,7 +32,6 @@ def test_threshold():
if "CPU_ONLY" in os.environ:
pass
else:
setup_parsl()
ess = []
c = 3e8
lam0 = 351e-9
Expand Down
4 changes: 3 additions & 1 deletion tests/test_tf1d/test_against_vlasov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy as np
from jax import config

from adept import ergoExo

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)

import xarray as xr
from adept import ergoExo, electrostatic
from adept import electrostatic


def _modify_defaults_(defaults):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_tf1d/test_landau_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import numpy as np
from jax import config

from adept import ergoExo

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)

from jax import numpy as jnp
import mlflow

from adept import ergoExo, electrostatic
from adept import electrostatic


def _modify_defaults_(defaults, rng):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_tf1d/test_resonance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy as np
from jax import config

from adept import ergoExo

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)

from jax import numpy as jnp
from adept import ergoExo, electrostatic
from adept import electrostatic


def _modify_defaults_(defaults, rng, gamma):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vfp1d/epp-short.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
units:
laser_wavelength: 351nm
reference electron temperature: 2000eV
reference electron temperature: 300eV
reference ion temperature: 300eV
reference electron density: 1.5e21/cm^3
Z: 6
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vlasov1d/test_absorbing_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from diffrax import ODETerm, diffeqsolve

from adept._vlasov1d.solvers.pushers.field import Driver, WaveSolver
from adept import Stepper
from adept._base_ import Stepper


class VectorField(eqx.Module):
Expand Down
Loading

0 comments on commit 4dad107

Please sign in to comment.