diff --git a/src/ebtelplusplus/tests/util.py b/src/ebtelplusplus/util.py similarity index 58% rename from src/ebtelplusplus/tests/util.py rename to src/ebtelplusplus/util.py index 77741de..2d73440 100644 --- a/src/ebtelplusplus/tests/util.py +++ b/src/ebtelplusplus/util.py @@ -1,71 +1,12 @@ """ Utility functions for configuring and running ebtel++ simulations """ -import os -import subprocess import warnings -from collections import OrderedDict -import tempfile -import xml.etree.ElementTree as ET import xml.dom.minidom as xdm +import xml.etree.ElementTree as ET +from collections import OrderedDict -import astropy.units as u -import numpy as np - -import ebtelplusplus - -__all__ = ['run_ebtel', 'read_xml', 'write_xml'] - - -class EbtelPlusPlusError(Exception): - """ - Raise this exception when there's an ebtel++ error - """ - pass - - -def run_ebtel(config): - """ - Run an ebtel++ simulation - - Parameters - ---------- - config: `dict` - Dictionary of configuration options - ebtel_dir: `str` - Path to directory containing ebtel++ source code. - """ - with tempfile.TemporaryDirectory() as tmpdir: - config_filename = os.path.join(tmpdir, 'ebtelplusplus.tmp.xml') - results_filename = os.path.join(tmpdir, 'ebtelplusplus.tmp') - config['output_filename'] = results_filename - write_xml(config, config_filename) - res = ebtelplusplus.run(config_filename) - data = np.loadtxt(results_filename) - - results = { - 'time': data[:, 0]*u.s, - 'electron_temperature': data[:, 1]*u.K, - 'ion_temperature': data[:, 2]*u.K, - 'density': data[:, 3]*u.cm**(-3), - 'electron_pressure': data[:, 4]*u.dyne/(u.cm**2), - 'ion_pressure': data[:, 5]*u.dyne/(u.cm**2), - 'velocity': data[:, 6]*u.cm/u.s, - 'heat': data[:, 7]*u.erg/(u.cm**3*u.s), - } - - results_dem = {} - if config['calculate_dem']: - results_dem['dem_tr'] = np.loadtxt( - config['output_filename'] + '.dem_tr') - results_dem['dem_corona'] = np.loadtxt( - config['output_filename'] + '.dem_corona') - # The first row of both is the temperature bins - results_dem['dem_temperature'] = results_dem['dem_tr'][0, :]*u.K - results_dem['dem_tr'] = results_dem['dem_tr'][1:, :]*u.Unit('cm-5 K-1') - results_dem['dem_corona'] = results_dem['dem_corona'][1:, :]*u.Unit('cm-5 K-1') - - return {**results, **results_dem} +__all__ = ['read_xml', 'write_xml'] def read_xml(input_filename,): @@ -109,15 +50,15 @@ def read_node(node): for child in node: tmp[child.tag] = read_node(child) return tmp + + if node.text: + return type_checker(node.text) + elif node.attrib: + return {key: type_checker(node.attrib[key]) for key in node.attrib} else: - if node.text: - return type_checker(node.text) - elif node.attrib: - return {key: type_checker(node.attrib[key]) for key in node.attrib} - else: - warnings.warn( - f'Unrecognized node format for {node.tag}. Returning None.') - return None + warnings.warn( + f'Unrecognized node format for {node.tag}. Returning None.') + return None def bool_filter(val):