Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix security issue with XML parsing #5686

Merged
merged 5 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions model-optimizer/mo/back/ie_ir_ver_2/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import hashlib
from xml.etree.ElementTree import Element, SubElement, tostring, ElementTree

from defusedxml import defuse_stdlib
import defusedxml.ElementTree as ET
mvafin marked this conversation as resolved.
Show resolved Hide resolved
from defusedxml.minidom import parseString

from mo.graph.graph import *
Expand All @@ -12,6 +13,13 @@
from mo.utils.utils import refer_to_faq_msg
from mo.utils.version import get_version

# defuse_stdlib() provides a patched version of xml.etree.ElementTree, which allows objects from
# xml.etree.ElementTree to be used in a safe manner without including the unsafe xml.etree.ElementTree
ET_defused = defuse_stdlib()[ET]
mvafin marked this conversation as resolved.
Show resolved Hide resolved
Element = ET_defused.Element
SubElement = ET_defused.SubElement
tostring = ET_defused.tostring


def serialize_constants(graph: Graph, bin_file_name: str, data_type=np.float32):
"""
Expand Down Expand Up @@ -444,8 +452,7 @@ def append_ir_info(file: str, meta_info: dict = dict(), mean_data: [list, None]
path_to_xml = file + ".xml"
path_to_bin = file + ".bin"

et = ElementTree()
et.parse(path_to_xml)
et = ET.parse(path_to_xml)
net = et.getroot()

if mean_data:
Expand All @@ -462,4 +469,4 @@ def append_ir_info(file: str, meta_info: dict = dict(), mean_data: [list, None]

pretty_xml_as_string = parseString(tostring(net)).toprettyxml()
with open(path_to_xml, 'wb') as file:
file.write(bytes(pretty_xml_as_string, "UTF-8"))
file.write(bytes(pretty_xml_as_string, "UTF-8"))
11 changes: 9 additions & 2 deletions model-optimizer/mo/middle/passes/tensor_names.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from xml.etree.ElementTree import Element, SubElement, tostring

from defusedxml import defuse_stdlib
from defusedxml.minidom import parseString
import defusedxml.ElementTree as ET
mvafin marked this conversation as resolved.
Show resolved Hide resolved

from mo.graph.graph import Node, Graph

# defuse_stdlib() provides a patched version of xml.etree.ElementTree, which allows objects from
# xml.etree.ElementTree to be used in a safe manner without including the unsafe xml.etree.ElementTree
ET_defused = defuse_stdlib()[ET]
Element = ET_defused.Element
SubElement = ET_defused.SubElement
tostring = ET_defused.tostring


def propagate_op_name_to_tensor(graph: Graph):
for node in graph.nodes():
Expand Down
10 changes: 7 additions & 3 deletions model-optimizer/mo/utils/ir_engine/ir_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import logging as log
import os
import sys
import xml.etree.ElementTree as ET

from defusedxml import defuse_stdlib
import defusedxml.ElementTree as ET
from argparse import Namespace
from collections import namedtuple, defaultdict
from pathlib import Path
Expand All @@ -17,6 +19,9 @@

log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.DEBUG, stream=sys.stdout)

# defuse_stdlib() provides a patched version of xml.etree.ElementTree, which allows objects from
# xml.etree.ElementTree to be used in a safe manner without including the unsafe xml.etree.ElementTree
ElementTree = defuse_stdlib()[ET].ElementTree

class IREngine(object):
def __init__(self, path_to_xml: str, path_to_bin=None, precision="FP32", xml_tree=None):
Expand Down Expand Up @@ -88,7 +93,6 @@ def __load_xml(self):
elif elem.tag in ['version', 'cli_params']:
self.meta_data['quantization_parameters'][elem.tag] = elem.attrib['value']


self.graph.graph['cmd_params'] = Namespace(**self.meta_data) # TODO check what we need all this attrs

if len(statistics):
Expand Down Expand Up @@ -237,7 +241,7 @@ def __load_layer(self, layer):

body_ir = IREngine(path_to_xml=None,
path_to_bin=self.path_to_bin,
xml_tree=ET.ElementTree(xml_body_child[0]))
xml_tree=ElementTree(xml_body_child[0]))
self.graph.graph['hashes'].update(body_ir.graph.graph['hashes'])

# Find port_map section and take an input_port_map & output_port_map
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import unittest
import tempfile

from mo.utils.ir_reader.restore_graph import restore_graph_from_ir
from defusedxml.common import EntitiesForbidden


class TestIRReader(unittest.TestCase):
    """Checks that reading an IR file rejects XML that declares entities."""

    def setUp(self):
        """Prepare an XML payload whose DOCTYPE defines nested, self-multiplying entities."""
        self.xml_bomb = b'<?xml version="1.0"?>\n' \
                        b'<!DOCTYPE lolz [\n' \
                        b' <!ENTITY lol "lol">\n' \
                        b' <!ELEMENT lolz (#PCDATA)>\n' \
                        b' <!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">\n' \
                        b' <!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">\n' \
                        b' <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">\n' \
                        b' <!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">\n' \
                        b' <!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">\n' \
                        b' <!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">\n' \
                        b' <!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">\n' \
                        b' <!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">\n' \
                        b' <!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">\n' \
                        b']>\n' \
                        b'<lolz>&lol9;</lolz>'

    def test_read_xml_bomb(self):
        """restore_graph_from_ir must raise EntitiesForbidden for the prepared payload."""
        # delete=False keeps the file on disk after closing so it can be
        # reopened by name by the code under test.
        with tempfile.NamedTemporaryFile(delete=False) as bomb_file:
            bomb_file.write(self.xml_bomb)
        # Register removal before asserting so the file is deleted even when
        # the assertion fails or the call raises an unexpected exception.
        self.addCleanup(os.remove, bomb_file.name)
        self.assertRaises(EntitiesForbidden, restore_graph_from_ir, bomb_file.name)