diff --git a/tests/parser-cases/foo.bar.thrift b/tests/parser-cases/foo.bar.thrift new file mode 100644 index 0000000..d9b3174 --- /dev/null +++ b/tests/parser-cases/foo.bar.thrift @@ -0,0 +1 @@ +include "foo/bar.thrift" \ No newline at end of file diff --git a/tests/parser-cases/foo/bar.thrift b/tests/parser-cases/foo/bar.thrift new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_parser.py b/tests/test_parser.py index cb51a71..a0388dc 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -40,6 +40,12 @@ def test_include(): assert thrift.datetime == 1422009523 +def test_include_conflict(): + with pytest.raises(ThriftParserError) as excinfo: + load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift') + assert 'Module name conflict between "parser-cases/foo.bar.thrift" and "parser-cases/foo/bar.thrift"' == str(excinfo.value) + + def test_cpp_include(): load('parser-cases/cpp_include.thrift') diff --git a/thriftpy2/parser/__init__.py b/thriftpy2/parser/__init__.py index 16bb8fa..856941a 100644 --- a/thriftpy2/parser/__init__.py +++ b/thriftpy2/parser/__init__.py @@ -41,16 +41,20 @@ def load(path, # add sub modules to sys.modules recursively if real_module: sys.modules[module_name] = thrift - include_thrifts = thrift.__thrift_meta__["includes"][:] + include_thrifts = list(zip(thrift.__thrift_meta__["includes"][:], + thrift.__thrift_meta__["sub_modules"][:])) while include_thrifts: include_thrift = include_thrifts.pop() - lost_sub_modules = [ - m for m in thrift.__thrift_meta__["sub_modules"] if m not in sys.modules - ] - for module in lost_sub_modules: - sys.modules[module.__name__] = include_thrift - if include_thrift.__name__ not in sys.modules: - include_thrifts.extend(include_thrift.__thrift_meta__["includes"]) + registered_thrift = sys.modules.get(include_thrift[1].__name__) + if registered_thrift is None: + sys.modules[include_thrift[1].__name__] = include_thrift[0] + include_thrifts.extend(include_thrift[0].__thrift_meta__["includes"]) + else: + if registered_thrift.__thrift_file__ != include_thrift[0].__thrift_file__: + raise ThriftParserError( + 'Module name conflict between "%s" and "%s"' % + (registered_thrift.__thrift_file__, include_thrift[0].__thrift_file__) + ) return thrift diff --git a/thriftpy2/parser/parser.py b/thriftpy2/parser/parser.py index 441f401..be44975 100644 --- a/thriftpy2/parser/parser.py +++ b/thriftpy2/parser/parser.py @@ -62,15 +62,8 @@ def p_include(p): for include_dir in replace_include_dirs: path = os.path.join(include_dir, p[2]) if os.path.exists(path): - child_path = os.path.normpath( - os.path.dirname(str(thrift.__name__).replace("_thrift", "").replace(".", os.sep)) + os.sep + p[2]) - - child_path = child_path.lstrip(os.sep) - - child_module_name = str( - child_path).replace(os.sep, - ".").replace( - ".thrift", "_thrift") + child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__)) + child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift") child = parse(path, module_name=child_module_name) setattr(thrift, str(child.__name__).replace("_thrift", ""), child)