Skip to content

Commit

Permalink
move util script to main package
Browse files Browse the repository at this point in the history
  • Loading branch information
wtbarnes committed Aug 23, 2024
1 parent fde4f6b commit 4ccfcac
Showing 1 changed file with 11 additions and 70 deletions.
81 changes: 11 additions & 70 deletions src/ebtelplusplus/tests/util.py → src/ebtelplusplus/util.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,12 @@
"""
Utility functions for configuring and running ebtel++ simulations
"""
import os
import subprocess
import warnings
from collections import OrderedDict
import tempfile
import xml.etree.ElementTree as ET
import xml.dom.minidom as xdm
import xml.etree.ElementTree as ET
from collections import OrderedDict

import astropy.units as u
import numpy as np

import ebtelplusplus

__all__ = ['run_ebtel', 'read_xml', 'write_xml']


class EbtelPlusPlusError(Exception):
"""
Raise this exception when there's an ebtel++ error
"""
pass


def run_ebtel(config):
"""
Run an ebtel++ simulation
Parameters
----------
config: `dict`
Dictionary of configuration options
ebtel_dir: `str`
Path to directory containing ebtel++ source code.
"""
with tempfile.TemporaryDirectory() as tmpdir:
config_filename = os.path.join(tmpdir, 'ebtelplusplus.tmp.xml')
results_filename = os.path.join(tmpdir, 'ebtelplusplus.tmp')
config['output_filename'] = results_filename
write_xml(config, config_filename)
res = ebtelplusplus.run(config_filename)
data = np.loadtxt(results_filename)

results = {
'time': data[:, 0]*u.s,
'electron_temperature': data[:, 1]*u.K,
'ion_temperature': data[:, 2]*u.K,
'density': data[:, 3]*u.cm**(-3),
'electron_pressure': data[:, 4]*u.dyne/(u.cm**2),
'ion_pressure': data[:, 5]*u.dyne/(u.cm**2),
'velocity': data[:, 6]*u.cm/u.s,
'heat': data[:, 7]*u.erg/(u.cm**3*u.s),
}

results_dem = {}
if config['calculate_dem']:
results_dem['dem_tr'] = np.loadtxt(
config['output_filename'] + '.dem_tr')
results_dem['dem_corona'] = np.loadtxt(
config['output_filename'] + '.dem_corona')
# The first row of both is the temperature bins
results_dem['dem_temperature'] = results_dem['dem_tr'][0, :]*u.K
results_dem['dem_tr'] = results_dem['dem_tr'][1:, :]*u.Unit('cm-5 K-1')
results_dem['dem_corona'] = results_dem['dem_corona'][1:, :]*u.Unit('cm-5 K-1')

return {**results, **results_dem}
__all__ = ['read_xml', 'write_xml']


def read_xml(input_filename,):
Expand Down Expand Up @@ -109,15 +50,15 @@ def read_node(node):
for child in node:
tmp[child.tag] = read_node(child)
return tmp

if node.text:
return type_checker(node.text)
elif node.attrib:
return {key: type_checker(node.attrib[key]) for key in node.attrib}
else:
if node.text:
return type_checker(node.text)
elif node.attrib:
return {key: type_checker(node.attrib[key]) for key in node.attrib}
else:
warnings.warn(
f'Unrecognized node format for {node.tag}. Returning None.')
return None
warnings.warn(
f'Unrecognized node format for {node.tag}. Returning None.')
return None


def bool_filter(val):
Expand Down

0 comments on commit 4ccfcac

Please sign in to comment.