Skip to content

Commit

Permalink
Adding openmc.read_source_file (#2858)
Browse files Browse the repository at this point in the history
Co-authored-by: Paul Romano <[email protected]>
  • Loading branch information
pshriwise and paulromano authored Feb 16, 2024
1 parent 5005c3c commit 3b575a4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/pythonapi/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Simulation Settings
:nosignatures:
:template: myfunction.rst

openmc.read_source_file
openmc.write_source_file
openmc.wwinp_to_wws

Expand Down
34 changes: 34 additions & 0 deletions openmc/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,37 @@ def write_source_file(
with h5py.File(filename, **kwargs) as fh:
fh.attrs['filetype'] = np.string_("source")
fh.create_dataset('source_bank', data=arr, dtype=source_dtype)


def read_source_file(filename: PathLike) -> typing.List[SourceParticle]:
"""Read a source file and return a list of source particles.
.. versionadded:: 0.14.1
Parameters
----------
filename : str or path-like
Path to source file to read
Returns
-------
list of SourceParticle
Source particles read from file
See Also
--------
openmc.SourceParticle
"""
with h5py.File(filename, 'r') as fh:
filetype = fh.attrs['filetype']
arr = fh['source_bank'][...]

if filetype != b'source':
raise ValueError(f'File {filename} is not a source file')

source_particles = []
for *params, particle in arr:
source_particles.append(SourceParticle(*params, ParticleType(particle)))

return source_particles
24 changes: 24 additions & 0 deletions tests/unit_tests/test_source_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ def test_source_file(run_in_tmpdir):
assert np.all(arr['particle'] == 0)


# Ensure sites read in are consistent
sites = openmc.read_source_file('test_source.h5')

assert filetype == b'source'
xs = np.array([site.r[0] for site in sites])
ys = np.array([site.r[1] for site in sites])
zs = np.array([site.r[2] for site in sites])
assert np.all((xs > 0.0) & (xs < 1.0))
assert np.all(ys == np.arange(1000))
assert np.all(zs == 0.0)
u = np.array([s.u for s in sites])
assert np.all(u[..., 0] == 0.0)
assert np.all(u[..., 1] == 0.0)
assert np.all(u[..., 2] == 1.0)
E = np.array([s.E for s in sites])
assert np.all(E == n - np.arange(n))
wgt = np.array([s.wgt for s in sites])
assert np.all(wgt == 1.0)
dgs = np.array([s.delayed_group for s in sites])
assert np.all(dgs == 0)
p_types = np.array([s.particle for s in sites])
assert np.all(p_types == 0)


def test_wrong_source_attributes(run_in_tmpdir):
# Create a source file with animal attributes
source_dtype = np.dtype([
Expand Down

0 comments on commit 3b575a4

Please sign in to comment.