From 111b1d25e2ab6e09facd491a26957ed06e719c3e Mon Sep 17 00:00:00 2001 From: Guy Azran Date: Thu, 3 Aug 2023 12:22:06 +0300 Subject: [PATCH] PyMJCF recursive include tags relative to base model --- dm_control/mjcf/parser.py | 112 ++++++++++++++++++++++++++++++-------- 1 file changed, 90 insertions(+), 22 deletions(-) diff --git a/dm_control/mjcf/parser.py b/dm_control/mjcf/parser.py index c710f179..4ec96b2e 100644 --- a/dm_control/mjcf/parser.py +++ b/dm_control/mjcf/parser.py @@ -26,7 +26,8 @@ def from_xml_string(xml_string, escape_separators=False, - model_dir='', resolve_references=True, assets=None): + model_dir='', resolve_references=True, assets=None, + base_model_dir=None): """Parses an XML string into an MJCF object model. Args: @@ -41,6 +42,9 @@ def from_xml_string(xml_string, escape_separators=False, assets: (optional) A dictionary of pre-loaded assets, of the form `{filename: bytestring}`. If present, PyMJCF will search for assets in this dictionary before attempting to load them from the filesystem. + base_model_dir: (optional) Path to the directory containing the base model. + This is used to prefix the paths of elements' file attributes + to support nested includes as in the MuJoCo compiler. Returns: An `mjcf.RootElement`. @@ -49,11 +53,12 @@ def from_xml_string(xml_string, escape_separators=False, return _parse(xml_root, escape_separators, model_dir=model_dir, resolve_references=resolve_references, - assets=assets) + assets=assets, base_model_dir=base_model_dir) def from_file(file_handle, escape_separators=False, - model_dir='', resolve_references=True, assets=None): + model_dir='', resolve_references=True, assets=None, + base_model_dir=None): """Parses an XML file into an MJCF object model. Args: @@ -68,6 +73,9 @@ def from_file(file_handle, escape_separators=False, assets: (optional) A dictionary of pre-loaded assets, of the form `{filename: bytestring}`. 
If present, PyMJCF will search for assets in this dictionary before attempting to load them from the filesystem. + base_model_dir: (optional) Path to the directory containing the base model. + This is used to prefix the paths of elements' file attributes + to support nested includes as in the MuJoCo compiler. Returns: An `mjcf.RootElement`. @@ -76,11 +84,11 @@ def from_file(file_handle, escape_separators=False, return _parse(xml_root, escape_separators, model_dir=model_dir, resolve_references=resolve_references, - assets=assets) + assets=assets, base_model_dir=base_model_dir) def from_path(path, escape_separators=False, resolve_references=True, - assets=None): + assets=None, base_model_dir=None): """Parses an XML file into an MJCF object model. Args: @@ -94,6 +102,9 @@ def from_path(path, escape_separators=False, resolve_references=True, assets: (optional) A dictionary of pre-loaded assets, of the form `{filename: bytestring}`. If present, PyMJCF will search for assets in this dictionary before attempting to load them from the filesystem. + base_model_dir: (optional) Path to the directory containing the base model. + This is used to prefix the paths of elements' file attributes + to support nested includes as in the MuJoCo compiler. Returns: An `mjcf.RootElement`. @@ -103,11 +114,12 @@ def from_path(path, escape_separators=False, resolve_references=True, xml_root = etree.fromstring(contents) return _parse(xml_root, escape_separators, model_dir=model_dir, resolve_references=resolve_references, - assets=assets) + assets=assets, base_model_dir=base_model_dir) def _parse(xml_root, escape_separators=False, - model_dir='', resolve_references=True, assets=None): + model_dir='', resolve_references=True, assets=None, + base_model_dir=None): """Parses a complete MJCF model from an XML. Args: @@ -122,6 +134,9 @@ def _parse(xml_root, escape_separators=False, assets: (optional) A dictionary of pre-loaded assets, of the form `{filename: bytestring}`. 
def _parse_include(include_tag, escape_separators, model_dir,
                   resolve_references, assets, base_model_dir):
    """Parses the XML referenced by a single `<include>` tag.

    The include is resolved against `model_dir` first and then, if one was
    given and it differs from `model_dir`, against `base_model_dir`. The first
    directory for which parsing does not raise `FileNotFoundError` wins. The
    directory being tried is forwarded as `base_model_dir` to the nested
    parse.

    Args:
      include_tag: An `etree.Element` object with tag 'include'. Its `file`
        attribute names the XML to include.
      escape_separators: A boolean, whether to replace '/' characters in
        element identifiers. Forwarded unchanged to the nested parse.
      model_dir: Path to the directory containing the model XML file that
        holds the include tag. Always searched first.
      resolve_references: A boolean indicating whether the parser should
        attempt to resolve reference attributes to a corresponding element.
        Forwarded unchanged to the nested parse.
      assets: A dictionary of pre-loaded assets, of the form
        `{filename: bytestring}`, or `None`. If the include's `file` attribute
        is a key of this dictionary the XML is parsed from the stored string
        instead of being read from the filesystem.
      base_model_dir: Path to the directory containing the base model, or
        `None`. Used as a fallback search directory to support nested
        includes.

    Returns:
      The object returned by `from_xml_string` or `from_path` for the included
      XML (an `mjcf.RootElement`).

    Raises:
      KeyError: If `include_tag` has no `file` attribute.
      FileNotFoundError: If parsing raised `FileNotFoundError` for every
        candidate directory. The last such error is chained as the cause.
    """
    include_file = include_tag.attrib['file']

    # Whether the include is served from the assets dict does not depend on
    # the directory being tried, so the lookup is done once, up front. A `None`
    # assets argument is treated as "no pre-loaded assets".
    try:
        xml_string = ({} if assets is None else assets)[include_file]
    except KeyError:
        in_assets = False
    else:
        in_assets = True

    # Always look in the current model dir first, then in the base model dir.
    # Skip the latter when it is the same directory: retrying it cannot give
    # a different outcome.
    search_dirs = [model_dir]
    if base_model_dir is not None and base_model_dir != model_dir:
        search_dirs.append(base_model_dir)

    last_error = None
    for working_dir in search_dirs:
        parsing_kwargs = dict(
            escape_separators=escape_separators,
            resolve_references=resolve_references,
            assets=assets,
            base_model_dir=working_dir,
        )
        try:
            if in_assets:
                # `from_xml_string` cannot infer a directory from its input,
                # so the model dir has to be passed explicitly.
                return from_xml_string(
                    xml_string, model_dir=working_dir, **parsing_kwargs)
            return from_path(
                os.path.join(working_dir, include_file), **parsing_kwargs)
        except FileNotFoundError as e:
            # This directory did not resolve the include (or one of the paths
            # nested inside it); remember why and try the next candidate.
            last_error = e

    raise FileNotFoundError(
        'Could not find an appropriate base path for include tag '
        '(file={!r}, searched directories: {!r})'.format(
            include_file, search_dirs)) from last_error