diff --git a/.coveragerc b/.coveragerc index 7121c9090..86515ed88 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,7 +4,9 @@ omit = *tests* docs/* omegaconf/grammar/gen/* + omegaconf/vendor/* omegaconf/version.py + omegaconf/typing.py .stubs [report] diff --git a/.flake8 b/.flake8 index 7a36d6a4e..78ef79828 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -exclude = .git,.nox,.tox,omegaconf/grammar/gen,build +exclude = .git,.nox,.tox,omegaconf/grammar/gen,build,omegaconf/vendor max-line-length = 119 select = E,F,W,C ignore=W503,E203 diff --git a/.isort.cfg b/.isort.cfg index 8d4c54412..1fa164cdb 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -7,4 +7,4 @@ line_length=88 ensure_newline_before_comments=True known_third_party=attr,pytest known_first_party=omegaconf -skip=.eggs,.nox,omegaconf/grammar/gen,build +skip=.eggs,.nox,omegaconf/grammar/gen,build,omegaconf/vendor diff --git a/build_helpers/bin/antlr-4.11.1-complete.jar b/build_helpers/bin/antlr-4.11.1-complete.jar new file mode 100644 index 000000000..bb96df951 Binary files /dev/null and b/build_helpers/bin/antlr-4.11.1-complete.jar differ diff --git a/build_helpers/bin/antlr-4.9.3-complete.jar b/build_helpers/bin/antlr-4.9.3-complete.jar deleted file mode 100644 index 749296fe7..000000000 Binary files a/build_helpers/bin/antlr-4.9.3-complete.jar and /dev/null differ diff --git a/build_helpers/build_helpers.py b/build_helpers/build_helpers.py index 6419e2604..25b057055 100644 --- a/build_helpers/build_helpers.py +++ b/build_helpers/build_helpers.py @@ -6,6 +6,7 @@ import shutil import subprocess import sys +from functools import partial from pathlib import Path from typing import List, Optional @@ -30,7 +31,7 @@ def run(self) -> None: command = [ "java", "-jar", - str(build_dir / "bin" / "antlr-4.9.3-complete.jar"), + str(build_dir / "bin" / "antlr-4.11.1-complete.jar"), "-Dlanguage=Python3", "-o", str(project_root / "omegaconf" / "grammar" / "gen"), @@ -46,12 +47,44 @@ def run(self) -> None: 
subprocess.check_call(command)

+        self.announce(
+            "Fixing imports for generated parsers",
+            level=distutils.log.INFO,
+        )
+        self._fix_imports()
+
     def initialize_options(self) -> None:
         pass

     def finalize_options(self) -> None:
         pass

+    def _fix_imports(self) -> None:
+        """Fix imports from the generated parsers to use the vendored antlr4 instead"""
+        build_dir = Path(__file__).parent.absolute()
+        project_root = build_dir.parent
+        lib = "antlr4"
+        pkgname = 'omegaconf.vendor'
+
+        replacements = [
+            partial(  # import antlr4 -> from omegaconf.vendor import antlr4
+                re.compile(r'(^\s*)import {}\n'.format(lib), flags=re.M).sub,
+                r'\1from {} import {}\n'.format(pkgname, lib)
+            ),
+            partial(  # from antlr4 -> from omegaconf.vendor.antlr4
+                re.compile(r'(^\s*)from {}(\.|\s+)'.format(lib), flags=re.M).sub,
+                r'\1from {}.{}\2'.format(pkgname, lib)
+            ),
+        ]
+
+        path = project_root / "omegaconf" / "grammar" / "gen"
+        for item in path.iterdir():
+            if item.is_file() and item.name.endswith(".py"):
+                text = item.read_text('utf8')
+                for replacement in replacements:
+                    text = replacement(text)
+                item.write_text(text, 'utf8')
+

 class BuildPyCommand(build_py.build_py):  # pragma: no cover
     def run(self) -> None:
diff --git a/build_helpers/get_vendored.py b/build_helpers/get_vendored.py
new file mode 100644
index 000000000..2f5d65913
--- /dev/null
+++ b/build_helpers/get_vendored.py
@@ -0,0 +1,145 @@
+import re
+import shutil
+import subprocess
+import sys
+from functools import partial
+from itertools import chain
+from pathlib import Path
+from typing import Callable, FrozenSet, Generator, List, Set, Tuple, Union
+
+WHITELIST = {'README.txt', '__init__.py', 'vendor.txt'}
+
+
+def delete_all(*paths: Path, whitelist: Union[Set[str], FrozenSet[str]] = frozenset()) -> None:
+    """Delete every given path: directories recursively, files by unlinking.
+
+    Files whose name is in `whitelist` are kept; directories are always removed
+    (the whitelist is only consulted for files).
+    """
+    for item in paths:
+        if item.is_dir():
+            shutil.rmtree(item, ignore_errors=True)
+        elif item.is_file() and item.name not in whitelist:
+            item.unlink()
+
+
+def iter_subtree(path: Path, depth: int = 0) -> Generator[Tuple[Path, int], None, None]:
+    """Recursively yield `(file, depth)` for all files in a subtree, depth-first.
+
+    A file passed directly is yielded at `depth`; files found inside a directory
+    are yielded at that directory's depth plus one.
+    """
+    if not path.is_dir():
+        if path.is_file():
+            yield path, depth
+        return
+    for item in path.iterdir():
+        if item.is_dir():
+            yield from iter_subtree(item, depth + 1)
+        elif item.is_file():
+            yield item, depth + 1
+
+
+def patch_vendor_imports(file: Path, replacements: List[Callable[[str], str]]) -> None:
+    """Apply a list of replacements/patches to a given file, rewriting it in place"""
+    text = file.read_text('utf8')
+    for replacement in replacements:
+        text = replacement(text)
+    file.write_text(text, 'utf8')
+
+
+def find_vendored_libs(vendor_dir: Path, whitelist: Set[str]) -> Tuple[List[str], List[Path]]:
+    """Return the names and paths of the libraries found directly in `vendor_dir`.
+
+    Every directory counts as a library (named after the directory) and so does
+    every file that is not in `whitelist` (named after its stem).
+    """
+    vendored_libs = []
+    paths = []
+    for item in vendor_dir.iterdir():
+        if item.is_dir():
+            vendored_libs.append(item.name)
+        elif item.is_file() and item.name not in whitelist:
+            vendored_libs.append(item.stem)  # without extension
+        else:  # neither a dir nor a file outside the whitelist
+            continue
+        paths.append(item)
+    return vendored_libs, paths
+
+
+def vendor(vendor_dir: Path, relative_imports: bool = False) -> None:
+    """Re-install the libraries listed in `vendor_dir/vendor.txt` into `vendor_dir`
+    and rewrite their imports to point at the vendored copies.
+
+    With `relative_imports` the rewritten imports are relative; otherwise they are
+    absolute imports rooted at `PARENT.VENDOR_DIR` (e.g. `foo.vendor`).
+
+    Raises `subprocess.CalledProcessError` if the pip installation fails.
+    """
+    # target package is PARENT.VENDOR_DIR; foo/vendor -> foo.vendor
+    pkgname = f'{vendor_dir.parent.name}.{vendor_dir.name}'
+
+    # remove everything
+    delete_all(*vendor_dir.iterdir(), whitelist=WHITELIST)
+
+    # install with the pip of the running interpreter; fail loudly instead of
+    # going on to patch a vendor directory that was wiped but not re-populated
+    subprocess.run([
+        sys.executable, '-m', 'pip', 'install', '-t', str(vendor_dir),
+        '-r', str(vendor_dir / 'vendor.txt'),
+        '--no-compile', '--no-deps'
+    ], check=True)
+
+    # delete stuff that's not needed
+    delete_all(
+        *vendor_dir.glob('*.dist-info'),
+        *vendor_dir.glob('*.egg-info'),
+        vendor_dir / 'bin')
+
+    vendored_libs, paths = find_vendored_libs(vendor_dir, WHITELIST)
+
+    if not relative_imports:
+        replacements: List[Callable[[str], str]] = []
+        for lib in vendored_libs:
+            replacements += (
+                partial(  # import bar -> from foo.vendor import bar
+                    re.compile(r'(^\s*)import {}\n'.format(lib), flags=re.M).sub,
+                    r'\1from {} import {}\n'.format(pkgname, lib)
+                ),
+                partial(  # from bar -> from foo.vendor.bar
+                    re.compile(r'(^\s*)from {}(\.|\s+)'.format(lib), flags=re.M).sub,
+                    r'\1from {}.{}\2'.format(pkgname, lib)
+                ),
+            )
+
+    for file, depth in chain.from_iterable(map(iter_subtree, paths)):
+        if relative_imports:
+            # extra dots needed, on top of the leading one written in the patterns
+            # below, to climb from the file's package to its library's top level
+            pkgname = '.' * (depth - 1)
+            replacements = []
+            for lib in vendored_libs:
+                replacements += (
+                    partial(  # import bar -> from .. import bar (one dot per level up to the vendor dir)
+                        re.compile(r'(^\s*)import {}\n'.format(lib), flags=re.M).sub,
+                        r'\1from {} import {}\n'.format('.' * (depth + 1), lib)
+                    ),
+                    partial(  # from bar import x -> from . import x
+                        re.compile(r'(^\s*)from {}(\s+)'.format(lib), flags=re.M).sub,
+                        r'\1from .{}\2'.format(pkgname)
+                    ),
+                    partial(  # from bar.baz import x -> from .baz import x
+                        re.compile(r'(^\s*)from {}(\.+)'.format(lib), flags=re.M).sub,
+                        r'\1from {}\2'.format(pkgname)
+                    ),
+                )
+        patch_vendor_imports(file, replacements)
+
+
+if __name__ == '__main__':
+    # this script lives in `build_helpers`, one level below the project root
+    here = Path(__file__).resolve().parent.parent
+    vendor_dir = here / 'omegaconf' / 'vendor'
+    assert (vendor_dir / 'vendor.txt').exists(), 'omegaconf/vendor/vendor.txt file not found'
+    assert (vendor_dir / '__init__.py').exists(), 'omegaconf/vendor/__init__.py file not found'
+    vendor(vendor_dir, relative_imports=True)
diff --git a/docs/notebook/Tutorial.ipynb b/docs/notebook/Tutorial.ipynb
index d17712c57..1f3cba167 100644
--- a/docs/notebook/Tutorial.ipynb
+++ b/docs/notebook/Tutorial.ipynb
@@ -41,9 +41,14 @@
     }
    ],
    "source": [
+    "import sys\n",
+    "import os\n",
+    "sys.path.insert(0, os.path.abspath(\"../../\"))\n",
+    "\n",
+    "\n",
     "from omegaconf import OmegaConf\n",
     "conf = OmegaConf.create()\n",
-    "print(conf)"
+    "print(conf)\n"
    ]
   },
   {
@@ -78,7 +83,7 @@
    ],
    "source": [
     "conf = OmegaConf.create(dict(k='v',list=[1,dict(a='1',b='2')]))\n",
-    "print(OmegaConf.to_yaml(conf))"
+    "print(OmegaConf.to_yaml(conf))\n"
    ]
   },
   {
@@ -112,7 +117,7 @@
    ],
    "source": [
     "conf = OmegaConf.create([1, dict(a=10, b=dict(a=10))])\n",
-    "print(OmegaConf.to_yaml(conf))"
+    "print(OmegaConf.to_yaml(conf))\n"
    ]
   },
   {
@@ -150,7
+155,7 @@ ], "source": [ "conf = OmegaConf.load('../source/example.yaml')\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -192,7 +197,7 @@ "- item2\n", "\"\"\"\n", "conf = OmegaConf.create(yaml)\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -230,7 +235,7 @@ "source": [ "dot_list = [\"a.aa.aaa=1\", \"a.aa.bbb=2\", \"a.bb.aaa=3\", \"a.bb.bbb=4\"]\n", "conf = OmegaConf.from_dotlist(dot_list)\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -269,7 +274,7 @@ "import sys\n", "sys.argv = ['your-program.py', 'server.port=82', 'log.file=log2.txt']\n", "conf = OmegaConf.from_cli()\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -308,7 +313,7 @@ ], "source": [ "conf = OmegaConf.load('../source/example.yaml')\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -340,7 +345,7 @@ } ], "source": [ - "conf.server.port" + "conf.server.port\n" ] }, { @@ -372,7 +377,7 @@ } ], "source": [ - "conf['log']['rotation']" + "conf['log']['rotation']\n" ] }, { @@ -404,7 +409,7 @@ } ], "source": [ - "conf.users[0]" + "conf.users[0]\n" ] }, { @@ -425,7 +430,7 @@ }, "outputs": [], "source": [ - "conf.server.port = 81" + "conf.server.port = 81\n" ] }, { @@ -446,7 +451,7 @@ }, "outputs": [], "source": [ - "conf.server.hostname = \"localhost\"" + "conf.server.hostname = \"localhost\"\n" ] }, { @@ -467,7 +472,7 @@ }, "outputs": [], "source": [ - "conf.database = {'hostname': 'database01', 'port': 3306}" + "conf.database = {'hostname': 'database01', 'port': 3306}\n" ] }, { @@ -499,7 +504,7 @@ } ], "source": [ - "conf.get('missing_key', 'a default value')" + "conf.get('missing_key', 'a default value')\n" ] }, { @@ -536,7 +541,7 @@ "try:\n", " conf.log.file\n", "except MissingMandatoryValue as exc:\n", - " print(exc)" + " print(exc)\n" ] }, { @@ -588,7 +593,7 @@ ], "source": [ "conf = 
OmegaConf.load('../source/config_interpolation.yaml')\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -609,7 +614,7 @@ "# Primitive interpolation types are inherited from the referenced value\n", "print(\"conf.client.server_port: \", conf.client.server_port, type(conf.client.server_port).__name__)\n", "# Composite interpolation types are always string\n", - "print(\"conf.client.url: \", conf.client.url, type(conf.client.url).__name__)" + "print(\"conf.client.url: \", conf.client.url, type(conf.client.url).__name__)\n" ] }, { @@ -643,7 +648,7 @@ } ], "source": [ - "print(OmegaConf.to_yaml(conf, resolve=True))" + "print(OmegaConf.to_yaml(conf, resolve=True))\n" ] }, { @@ -678,7 +683,7 @@ ")\n", "print(f\"Default: cfg.plan = {cfg.plan}\")\n", "cfg.selected_plan = \"B\"\n", - "print(f\"After selecting plan B: cfg.plan = {cfg.plan}\")" + "print(f\"After selecting plan B: cfg.plan = {cfg.plan}\")\n" ] }, { @@ -712,7 +717,7 @@ " \"player\": \"${john}\",\n", " }\n", ")\n", - "(cfg.player.height, cfg.player.weight)" + "(cfg.player.height, cfg.player.weight)\n" ] }, { @@ -733,7 +738,7 @@ "source": [ "# Let's set up the environment first (only needed for this demonstration)\n", "import os\n", - "os.environ['USER'] = 'omry'" + "os.environ['USER'] = 'omry'\n" ] }, { @@ -766,7 +771,7 @@ ], "source": [ "conf = OmegaConf.load('../source/env_interpolation.yaml')\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -791,7 +796,7 @@ ], "source": [ "conf = OmegaConf.load('../source/env_interpolation.yaml')\n", - "print(OmegaConf.to_yaml(conf, resolve=True))" + "print(OmegaConf.to_yaml(conf, resolve=True))\n" ] }, { @@ -832,7 +837,7 @@ ")\n", "print(repr(cfg.database.password1))\n", "print(repr(cfg.database.password2))\n", - "print(repr(cfg.database.password3))" + "print(repr(cfg.database.password3))\n" ] }, { @@ -899,7 +904,7 @@ "print(\"timeout (missing variable):\", repr(cfg.database.timeout))\n", "\n", 
"os.environ[\"DB_TIMEOUT\"] = \"${.port}\"\n", - "print(\"timeout (interpolation):\", repr(cfg.database.timeout))" + "print(\"timeout (interpolation):\", repr(cfg.database.timeout))\n" ] }, { @@ -942,7 +947,7 @@ "source": [ "OmegaConf.register_new_resolver(\"plus_10\", lambda x: x + 10)\n", "conf = OmegaConf.create({'key': '${plus_10:990}'})\n", - "conf.key" + "conf.key\n" ] }, { @@ -972,7 +977,7 @@ "source": [ "OmegaConf.register_new_resolver(\"plus\", lambda x, y: x + y)\n", "conf = OmegaConf.create({\"a\": 1, \"b\": 2, \"a_plus_b\": \"${plus:${a},${b}}\"})\n", - "conf.a_plus_b" + "conf.a_plus_b\n" ] }, { @@ -1030,8 +1035,7 @@ "\n", "# same value even if `uncached` changes, because the cache is based\n", "# on the string literal \"${uncached}\" that remains the same\n", - "print(\"With cache (interpolation):\", cfg.cached_3, \"==\", cfg.cached_3)\n", - "\n" + "print(\"With cache (interpolation):\", cfg.cached_3, \"==\", cfg.cached_3)\n" ] }, { @@ -1080,7 +1084,7 @@ ], "source": [ "base_conf = OmegaConf.load('../source/example2.yaml')\n", - "print(OmegaConf.to_yaml(base_conf))" + "print(OmegaConf.to_yaml(base_conf))\n" ] }, { @@ -1104,7 +1108,7 @@ ], "source": [ "second_conf = OmegaConf.load('../source/example3.yaml')\n", - "print(OmegaConf.to_yaml(second_conf))" + "print(OmegaConf.to_yaml(second_conf))\n" ] }, { @@ -1142,7 +1146,7 @@ "sys.argv = ['program.py', 'server.port=82']\n", "# Merge with cli arguments\n", "conf.merge_with_cli()\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -1196,9 +1200,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 33, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "server:\n", + " port: 80\n", + "users:\n", + "- user1\n", + "- user2\n", + "- user3\n", + "\n" + ] + } + ], "source": [ "from omegaconf import OmegaConf, ListMergeMode\n", "\n", @@ -1206,7 +1224,7 @@ "cfg_example4 = 
OmegaConf.load('../source/example4.yaml')\n", "\n", "conf = OmegaConf.merge(cfg_example2, cfg_example4, list_merge_mode=ListMergeMode.EXTEND_UNIQUE)\n", - "print(OmegaConf.to_yaml(conf))" + "print(OmegaConf.to_yaml(conf))\n" ] }, { @@ -1223,6 +1241,11 @@ "- user3\n", "```" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { @@ -1241,7 +1264,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.11.4" }, "pycharm": { "stem_cell": { diff --git a/docs/source/conf.py b/docs/source/conf.py index 76530d451..d0286e74f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,10 +19,10 @@ from packaging.version import parse -from omegaconf import version as v - sys.path.insert(0, os.path.abspath("../../")) +from omegaconf import version as v # noqa E402 + year = datetime.datetime.now().year parsed_ver = parse(v.__version__) diff --git a/news/1114.misc b/news/1114.misc new file mode 100644 index 000000000..01ee2cc13 --- /dev/null +++ b/news/1114.misc @@ -0,0 +1 @@ +Upgrade dependency on `antlr4-python3-runtime` from `v4.9.*` to `v4.11.*`. `antlr4` runtime is now vendored to prevent conflicts with other dependencies. 
diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index e8b9e369e..a3b040ce6 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -29,6 +29,7 @@ open_dict, read_write, ) +from .typing import Antlr4ParserRuleContext from .version import __version__ __all__ = [ @@ -63,4 +64,5 @@ "MISSING", "SI", "II", + "Antlr4ParserRuleContext", ] diff --git a/omegaconf/base.py b/omegaconf/base.py index 77e951058..ed52ec42c 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -6,8 +6,6 @@ from enum import Enum from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union -from antlr4 import ParserRuleContext - from ._utils import ( _DEFAULT_MARKER_, NoneType, @@ -37,6 +35,7 @@ from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser from .grammar_parser import parse from .grammar_visitor import GrammarVisitor +from .typing import Antlr4ParserRuleContext DictKeyType = Union[str, bytes, int, Enum, float, bool] @@ -629,7 +628,7 @@ def _validate_and_convert_interpolation_result( # If the converted value is of the same type, it means that no conversion # was actually needed. As a result, we can keep the original `resolved` # (and otherwise, the converted value must be wrapped into a new node). 
- if type(conv_value) != type(res_value): + if type(conv_value) is not type(res_value): must_wrap = True resolved = conv_value @@ -726,7 +725,7 @@ def _maybe_resolve_interpolation( def resolve_parse_tree( self, - parse_tree: ParserRuleContext, + parse_tree: Antlr4ParserRuleContext, node: Node, memo: Optional[Set[int]] = None, key: Optional[Any] = None, diff --git a/omegaconf/grammar_parser.py b/omegaconf/grammar_parser.py index 3c883c2cf..6fe59c45d 100644 --- a/omegaconf/grammar_parser.py +++ b/omegaconf/grammar_parser.py @@ -2,9 +2,6 @@ import threading from typing import Any -from antlr4 import CommonTokenStream, InputStream, ParserRuleContext -from antlr4.error.ErrorListener import ErrorListener - from .errors import GrammarParseError # Import from visitor in order to check the presence of generated grammar files @@ -13,6 +10,9 @@ OmegaConfGrammarLexer, OmegaConfGrammarParser, ) +from .typing import Antlr4ParserRuleContext +from .vendor.antlr4 import CommonTokenStream, InputStream # type: ignore[attr-defined] +from .vendor.antlr4.error.ErrorListener import ErrorListener # Used to cache grammar objects to avoid re-creating them on each call to `parse()`. # We use a per-thread cache to make it thread-safe. @@ -39,8 +39,10 @@ # it must not accept anything that isn't a valid interpolation (per the # interpolation grammar defined in `omegaconf/grammar/*.g4`). +# ParserRuleContext: TypeAlias = ParserRuleContext + -class OmegaConfErrorListener(ErrorListener): # type: ignore +class OmegaConfErrorListener(ErrorListener): def syntaxError( self, recognizer: Any, @@ -95,7 +97,7 @@ def reportContextSensitivity( def parse( value: str, parser_rule: str = "configValue", lexer_mode: str = "DEFAULT_MODE" -) -> ParserRuleContext: +) -> Antlr4ParserRuleContext: """ Parse interpolated string `value` (and return the parse tree). """ @@ -116,7 +118,7 @@ def parse( # The two lines below could be enabled in the future if we decide to switch # to SLL prediction mode. 
Warning though, it has not been fully tested yet! - # from antlr4 import PredictionMode + # from omegaconf.vendor.antlr4 import PredictionMode # parser._interp.predictionMode = PredictionMode.SLL # Note that although the input stream `istream` is implicitly cached within @@ -133,7 +135,7 @@ def parse( parser.reset() try: - return getattr(parser, parser_rule)() + return getattr(parser, parser_rule)() # type: ignore except Exception as exc: if type(exc) is Exception and str(exc) == "Empty Stack": # This exception is raised by antlr when trying to pop a mode while diff --git a/omegaconf/grammar_visitor.py b/omegaconf/grammar_visitor.py index 55a6907ed..198bd6b8c 100644 --- a/omegaconf/grammar_visitor.py +++ b/omegaconf/grammar_visitor.py @@ -14,9 +14,8 @@ Union, ) -from antlr4 import TerminalNode - from .errors import InterpolationResolutionError +from .vendor.antlr4 import TerminalNode # type: ignore[attr-defined] if TYPE_CHECKING: from .base import Node # noqa F401 @@ -91,9 +90,9 @@ def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str: return res else: assert isinstance(child, TerminalNode) and isinstance( - child.symbol.text, str + child.symbol.text, str # type: ignore[attr-defined] ) - return child.symbol.text + return child.symbol.text # type: ignore[attr-defined] def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any: # text EOF @@ -138,7 +137,7 @@ def visitInterpolationNode( inter_key_tokens = [] # parsed elements of the dot path for child in ctx.getChildren(): if isinstance(child, TerminalNode): - s = child.symbol + s = child.symbol # type: ignore[attr-defined] if s.type in [ OmegaConfGrammarLexer.DOT, OmegaConfGrammarLexer.BRACKET_OPEN, @@ -168,7 +167,7 @@ def visitInterpolationResolver( args = [] args_str = [] if isinstance(maybe_seq, TerminalNode): # means there are no args - assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE + assert maybe_seq.symbol.type == 
OmegaConfGrammarLexer.BRACE_CLOSE # type: ignore[attr-defined] else: assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext) for val, txt in self.visitSequence(maybe_seq): @@ -191,7 +190,7 @@ def visitDictKeyValuePair( colon = ctx.getChild(1) assert ( isinstance(colon, TerminalNode) - and colon.symbol.type == OmegaConfGrammarLexer.COLON + and colon.symbol.type == OmegaConfGrammarLexer.COLON # type: ignore[attr-defined] ) value = _get_value(self.visitElement(ctx.getChild(2))) return key, value @@ -224,8 +223,8 @@ def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> items = [] for child in list(ctx.getChildren())[::2]: if isinstance(child, TerminalNode): - assert child.symbol.type == OmegaConfGrammarLexer.ID - items.append(child.symbol.text) + assert child.symbol.type == OmegaConfGrammarLexer.ID # type: ignore[attr-defined] + items.append(child.symbol.text) # type: ignore[attr-defined] else: assert isinstance(child, OmegaConfGrammarParser.InterpolationContext) item = _get_value(self.visitInterpolation(child)) @@ -268,7 +267,7 @@ def empty_str_warning() -> None: else: assert ( isinstance(child, TerminalNode) - and child.symbol.type == OmegaConfGrammarLexer.COMMA + and child.symbol.type == OmegaConfGrammarLexer.COMMA # type: ignore[attr-defined] ) if is_previous_comma: empty_str_warning() @@ -312,7 +311,7 @@ def _createPrimitive( if isinstance(child, OmegaConfGrammarParser.InterpolationContext): return self.visitInterpolation(child) assert isinstance(child, TerminalNode) - symbol = child.symbol + symbol = child.symbol # type: ignore[attr-defined] # Parse primitive types. 
if symbol.type in ( OmegaConfGrammarLexer.ID, @@ -351,7 +350,7 @@ def _unescape( chrs = [] for node, next_node in zip_longest(seq, seq[1:]): if isinstance(node, TerminalNode): - s = node.symbol + s = node.symbol # type: ignore if s.type == OmegaConfGrammarLexer.ESC_INTER: # `ESC_INTER` is of the form `\\...\${`: the formula below computes # the number of characters to keep at the end of the string to remove diff --git a/omegaconf/typing.py b/omegaconf/typing.py new file mode 100644 index 000000000..6e47b88d3 --- /dev/null +++ b/omegaconf/typing.py @@ -0,0 +1,18 @@ +import sys + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + +from .vendor.antlr4.ParserRuleContext import ParserRuleContext + +# The antlr4 class `ParserRulerContext` is not a valid type for `mypy` for a number of +# reasons, including lack of strict typing in antlr4-python3-runtime and dynamically +# generated attributes that can only be checked at runtime. +# To be able to use the class for type hinting without raising `valid-type` mypy errors, +# we define a type alias that we can use throughout the code. +# Note that type aliases cannot shadow the name of the class they are aliasing, +# so we need to name the aliased class something different from `ParserRuleContext` + +Antlr4ParserRuleContext: TypeAlias = ParserRuleContext # type: ignore[valid-type] diff --git a/omegaconf/vendor/README.txt b/omegaconf/vendor/README.txt new file mode 100644 index 000000000..8123b7d1f --- /dev/null +++ b/omegaconf/vendor/README.txt @@ -0,0 +1,8 @@ +# Vendored dependencies + +This folder contains libraries that are vendored from third party packages. +To add or modify a vendored library, just add or edit the corresponding dependency +in `vendor.txt`, and then run the `build_helpers/get_vendored.py` script. 
+ +**NOTE** all files in this folder apart from `__init__.py`, `vendor.txt` and `README.txt` are +dynamically generated; any manual modifications will be lost when running the `get_vendored` script diff --git a/omegaconf/vendor/__init__.py b/omegaconf/vendor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/omegaconf/vendor/antlr4/BufferedTokenStream.py b/omegaconf/vendor/antlr4/BufferedTokenStream.py new file mode 100644 index 000000000..4fe39e262 --- /dev/null +++ b/omegaconf/vendor/antlr4/BufferedTokenStream.py @@ -0,0 +1,302 @@ +# +# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. +# Use of this file is governed by the BSD 3-clause license that +# can be found in the LICENSE.txt file in the project root. + +# This implementation of {@link TokenStream} loads tokens from a +# {@link TokenSource} on-demand, and places the tokens in a buffer to provide +# access to any previous token by index. +# +#

+# This token stream ignores the value of {@link Token#getChannel}. If your +# parser requires the token stream filter tokens to only those on a particular +# channel, such as {@link Token#DEFAULT_CHANNEL} or +# {@link Token#HIDDEN_CHANNEL}, use a filtering token stream such a +# {@link CommonTokenStream}.

+from io import StringIO +from .Token import Token +from .error.Errors import IllegalStateException + +# need forward declaration +Lexer = None + +# this is just to keep meaningful parameter types to Parser +class TokenStream(object): + + pass + + +class BufferedTokenStream(TokenStream): + __slots__ = ('tokenSource', 'tokens', 'index', 'fetchedEOF') + + def __init__(self, tokenSource:Lexer): + # The {@link TokenSource} from which tokens for this stream are fetched. + self.tokenSource = tokenSource + + # A collection of all tokens fetched from the token source. The list is + # considered a complete view of the input once {@link #fetchedEOF} is set + # to {@code true}. + self.tokens = [] + + # The index into {@link #tokens} of the current token (next token to + # {@link #consume}). {@link #tokens}{@code [}{@link #p}{@code ]} should be + # {@link #LT LT(1)}. + # + #

This field is set to -1 when the stream is first constructed or when + # {@link #setTokenSource} is called, indicating that the first token has + # not yet been fetched from the token source. For additional information, + # see the documentation of {@link IntStream} for a description of + # Initializing Methods.

+ self.index = -1 + + # Indicates whether the {@link Token#EOF} token has been fetched from + # {@link #tokenSource} and added to {@link #tokens}. This field improves + # performance for the following cases: + # + #