Skip to content

Commit 55d2a5c

Browse files
Adrian Collistercopybara-github
authored andcommitted
Add from_zip to parser.
PiperOrigin-RevId: 809703253 Change-Id: I12b77a7d7c539419c2e3ee4fce50ed04e604a3cc
1 parent 696bc70 commit 55d2a5c

File tree

5 files changed

+66
-2
lines changed

5 files changed

+66
-2
lines changed

dm_control/mjcf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from dm_control.mjcf.parser import from_file
3030
from dm_control.mjcf.parser import from_path
3131
from dm_control.mjcf.parser import from_xml_string
32+
from dm_control.mjcf.parser import from_zip
3233

3334
from dm_control.mjcf.physics import Physics
3435

dm_control/mjcf/parser.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
"""Functions for parsing XML into an MJCF object model."""
1717

18+
import io
1819
import os
1920
import sys
21+
import zipfile
2022

2123
from dm_control.mjcf import constants
2224
from dm_control.mjcf import debugging
@@ -107,6 +109,60 @@ def from_path(path, escape_separators=False, resolve_references=True,
107109
assets=assets)
108110

109111

112+
def from_zip(path, model_file='model.xml', escape_separators=False,
113+
resolve_references=True):
114+
"""Parses a zipped XML file into an MJCF object model.
115+
116+
Args:
117+
path: A path to a zip file containing an MJCF model and its assets.
118+
model_file: If the zip contains multiple XML files, specify the name of the
119+
main model file. Ignored if the zip only contains one XML file.
120+
escape_separators: (optional) A boolean, whether to replace '/' characters
121+
in element identifiers. If `False`, any '/' present in the XML causes a
122+
ValueError to be raised.
123+
resolve_references: (optional) A boolean indicating whether the parser
124+
should attempt to resolve reference attributes to a corresponding element.
125+
126+
Returns:
127+
An `mjcf.RootElement`.
128+
129+
Raises:
130+
ValueError: If:
131+
- the path does not point to a zip file
132+
- the zip file contains no XML files
133+
- the zip file contains more than one XML file and none of them have the
134+
name specified in `model_file`.
135+
"""
136+
contents = resources.GetResource(path)
137+
if not zipfile.is_zipfile(io.BytesIO(contents)):
138+
raise ValueError(f'File {path} is not a zip file.')
139+
with zipfile.ZipFile(io.BytesIO(contents), 'r') as zf:
140+
xml_files = [f for f in zf.namelist() if f.endswith('.xml')]
141+
if not xml_files:
142+
raise ValueError(f'No XML file found in {path}.')
143+
elif len(xml_files) > 1:
144+
model_files = [f for f in xml_files if f == model_file]
145+
if not model_files:
146+
raise ValueError(
147+
f'Multiple XML files found in {path}, but none named {model_file}.'
148+
)
149+
xml_path = model_files[0]
150+
else:
151+
xml_path = xml_files[0]
152+
xml_string = zf.read(xml_path)
153+
154+
model_dir = os.path.dirname(xml_path)
155+
assets = {
156+
os.path.relpath(name, model_dir): zf.read(name)
157+
for name in zf.namelist()
158+
if not (name.endswith(os.path.sep) or name == xml_path)
159+
}
160+
161+
xml_root = etree.fromstring(xml_string)
162+
return _parse(xml_root, escape_separators,
163+
resolve_references=resolve_references, assets=assets)
164+
165+
110166
def _parse(xml_root, escape_separators=False,
111167
model_dir='', resolve_references=True, assets=None):
112168
"""Parses a complete MJCF model from an XML.
5.04 KB
Binary file not shown.

dm_control/mjcf/xml_validation_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
_ARENA_XML = os.path.join(ASSETS_DIR, 'arena.xml')
2626
_LEGO_BRICK_XML = os.path.join(ASSETS_DIR, 'lego_brick.xml')
2727
_ROBOT_XML = os.path.join(ASSETS_DIR, 'robot_arm.xml')
28+
_ZIPPED_MODEL = os.path.join(ASSETS_DIR, 'model_with_assetdir.zip')
2829

2930

30-
def validate(xml_string):
31+
def validate(xml_string, assets=None):
3132
"""Validates that an XML string is a valid MJCF.
3233
3334
Validation is performed by constructing Mujoco model from the string.
@@ -36,9 +37,10 @@ def validate(xml_string):
3637
3738
Args:
3839
xml_string: XML string to validate
40+
assets: Optional dict of assets to use for the model.
3941
"""
4042

41-
mjmodel = wrapper.MjModel.from_xml_string(xml_string)
43+
mjmodel = wrapper.MjModel.from_xml_string(xml_string, assets)
4244
wrapper.MjData(mjmodel)
4345

4446

@@ -61,6 +63,10 @@ def testXmlAttach(self):
6163
# validate
6264
validate(arena.to_xml_string())
6365

66+
def testXmlFromZip(self):
67+
model = parser.from_zip(_ZIPPED_MODEL)
68+
validate(model.to_xml_string(), model.get_assets())
69+
6470

6571
if __name__ == '__main__':
6672
absltest.main()

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def is_excluded(s):
222222
'*.xml',
223223
'*.textproto',
224224
'*.h5',
225+
'*.zip',
225226
],
226227
excludes=[
227228
'*/dog_assets/extras/*',

0 commit comments

Comments
 (0)