diff --git a/docs/source/pythonapi/base.rst b/docs/source/pythonapi/base.rst index 705aecee9f1..7b924e37499 100644 --- a/docs/source/pythonapi/base.rst +++ b/docs/source/pythonapi/base.rst @@ -35,6 +35,7 @@ Simulation Settings :nosignatures: :template: myfunction.rst + openmc.read_source_file openmc.write_source_file openmc.wwinp_to_wws diff --git a/openmc/source.py b/openmc/source.py index 1ba90c92130..441bf4e76b3 100644 --- a/openmc/source.py +++ b/openmc/source.py @@ -885,3 +885,37 @@ def write_source_file( with h5py.File(filename, **kwargs) as fh: fh.attrs['filetype'] = np.string_("source") fh.create_dataset('source_bank', data=arr, dtype=source_dtype) + + +def read_source_file(filename: PathLike) -> typing.List[SourceParticle]: + """Read a source file and return a list of source particles. + + .. versionadded:: 0.14.1 + + Parameters + ---------- + filename : str or path-like + Path to source file to read + + Returns + ------- + list of SourceParticle + Source particles read from file + + See Also + -------- + openmc.SourceParticle + + """ + with h5py.File(filename, 'r') as fh: + filetype = fh.attrs['filetype'] + arr = fh['source_bank'][...] + + if filetype != b'source': + raise ValueError(f'File {filename} is not a source file') + + source_particles = [] + for *params, particle in arr: + source_particles.append(SourceParticle(*params, ParticleType(particle))) + + return source_particles diff --git a/tests/unit_tests/test_source_file.py b/tests/unit_tests/test_source_file.py index ec10260bafe..d9fb3c1f907 100644 --- a/tests/unit_tests/test_source_file.py +++ b/tests/unit_tests/test_source_file.py @@ -45,6 +45,30 @@ def test_source_file(run_in_tmpdir): assert np.all(arr['particle'] == 0) + # Ensure sites read in are consistent + sites = openmc.read_source_file('test_source.h5') + + assert filetype == b'source' + xs = np.array([site.r[0] for site in sites]) + ys = np.array([site.r[1] for site in sites]) + zs = np.array([site.r[2] for site in sites]) + assert np.all((xs > 0.0) & (xs < 1.0)) + assert np.all(ys == np.arange(1000)) + assert np.all(zs == 0.0) + u = np.array([s.u for s in sites]) + assert np.all(u[..., 0] == 0.0) + assert np.all(u[..., 1] == 0.0) + assert np.all(u[..., 2] == 1.0) + E = np.array([s.E for s in sites]) + assert np.all(E == n - np.arange(n)) + wgt = np.array([s.wgt for s in sites]) + assert np.all(wgt == 1.0) + dgs = np.array([s.delayed_group for s in sites]) + assert np.all(dgs == 0) + p_types = np.array([s.particle for s in sites]) + assert np.all(p_types == 0) + + def test_wrong_source_attributes(run_in_tmpdir): # Create a source file with animal attributes source_dtype = np.dtype([