|
15 | 15 |
|
16 | 16 | """Functions for parsing XML into an MJCF object model.""" |
17 | 17 |
|
| 18 | +import io |
18 | 19 | import os |
19 | 20 | import sys |
| 21 | +import zipfile |
20 | 22 |
|
21 | 23 | from dm_control.mjcf import constants |
22 | 24 | from dm_control.mjcf import debugging |
@@ -107,6 +109,60 @@ def from_path(path, escape_separators=False, resolve_references=True, |
107 | 109 | assets=assets) |
108 | 110 |
|
109 | 111 |
|
| 112 | +def from_zip(path, model_file='model.xml', escape_separators=False, |
| 113 | + resolve_references=True): |
| 114 | + """Parses a zipped XML file into an MJCF object model. |
| 115 | +
|
| 116 | + Args: |
| 117 | + path: A path to a zip file containing an MJCF model and its assets. |
| 118 | + model_file: If the zip contains multiple XML files, specify the name of the |
| 119 | + main model file. Ignored if the zip only contains one XML file. |
| 120 | + escape_separators: (optional) A boolean, whether to replace '/' characters |
| 121 | + in element identifiers. If `False`, any '/' present in the XML causes a |
| 122 | + ValueError to be raised. |
| 123 | + resolve_references: (optional) A boolean indicating whether the parser |
| 124 | + should attempt to resolve reference attributes to a corresponding element. |
| 125 | +
|
| 126 | + Returns: |
| 127 | + An `mjcf.RootElement`. |
| 128 | +
|
| 129 | + Raises: |
| 130 | + ValueError: If: |
| 131 | + - the path does not point to a zip file |
| 132 | + - the zip file contains no XML files |
| 133 | + - the zip file contains more than one XML file and none of them have the |
| 134 | + name specified in `model_file`. |
| 135 | + """ |
| 136 | + contents = resources.GetResource(path) |
| 137 | + if not zipfile.is_zipfile(io.BytesIO(contents)): |
| 138 | + raise ValueError(f'File {path} is not a zip file.') |
| 139 | + with zipfile.ZipFile(io.BytesIO(contents), 'r') as zf: |
| 140 | + xml_files = [f for f in zf.namelist() if f.endswith('.xml')] |
| 141 | + if not xml_files: |
| 142 | + raise ValueError(f'No XML file found in {path}.') |
| 143 | + elif len(xml_files) > 1: |
| 144 | + model_files = [f for f in xml_files if f == model_file] |
| 145 | + if not model_files: |
| 146 | + raise ValueError( |
| 147 | + f'Multiple XML files found in {path}, but none named {model_file}.' |
| 148 | + ) |
| 149 | + xml_path = model_files[0] |
| 150 | + else: |
| 151 | + xml_path = xml_files[0] |
| 152 | + xml_string = zf.read(xml_path) |
| 153 | + |
| 154 | + model_dir = os.path.dirname(xml_path) |
| 155 | + assets = { |
| 156 | + os.path.relpath(name, model_dir): zf.read(name) |
| 157 | + for name in zf.namelist() |
| 158 | + if not (name.endswith(os.path.sep) or name == xml_path) |
| 159 | + } |
| 160 | + |
| 161 | + xml_root = etree.fromstring(xml_string) |
| 162 | + return _parse(xml_root, escape_separators, |
| 163 | + resolve_references=resolve_references, assets=assets) |
| 164 | + |
| 165 | + |
110 | 166 | def _parse(xml_root, escape_separators=False, |
111 | 167 | model_dir='', resolve_references=True, assets=None): |
112 | 168 | """Parses a complete MJCF model from an XML. |
|
0 commit comments