diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 19cad4e16..7c6dcde6d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,25 @@ +default_language_version: + python: python3.8 + repos: - - repo: https://github.com/timothycrosley/isort - rev: 5.0.9 + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 23.7.0 hooks: - id: black - language_version: python3.8 - - repo: https://gitlab.com/pycqa/flake8 - rev: 3.7.9 + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 hooks: - id: flake8 - additional_dependencies: [-e, 'git+https://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v1.4.1 hooks: - id: mypy args: [--strict] - additional_dependencies: ['pytest'] + additional_dependencies: ['attrs', 'pytest'] diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 49be30329..1b14b2825 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -10,10 +10,13 @@ _get_value, is_primitive_container, is_structured_config, + maybe_escape, ) -def _resolve_container_value(cfg: Container, key: Any) -> None: +def _resolve_container_value( + cfg: Container, key: Any, escape_interpolation_strings: bool +) -> None: node = cfg._get_child(key) assert isinstance(node, Node) if node._is_interpolation(): @@ -23,7 +26,7 @@ def _resolve_container_value(cfg: Container, key: Any) -> None: node._set_value(MISSING) else: if isinstance(resolved, Container): - _resolve(resolved) + _resolve(resolved, escape_interpolation_strings) if isinstance(resolved, InterpolationResultNode): resolved_value = _get_value(resolved) if is_primitive_container(resolved_value) or is_structured_config( @@ -33,12 +36,15 @@ def _resolve_container_value(cfg: Container, key: Any) -> None: if isinstance(resolved, Container) and isinstance(node, ValueNode): cfg[key] = resolved else: - node._set_value(_get_value(resolved)) + val = _get_value(resolved) + if escape_interpolation_strings: + val = maybe_escape(val) + node._set_value(val) else: - _resolve(node) + _resolve(node, escape_interpolation_strings) -def _resolve(cfg: Node) -> Node: +def _resolve(cfg: Node, escape_interpolation_strings: bool) -> Node: assert isinstance(cfg, Node) if cfg._is_interpolation(): try: @@ -46,15 +52,18 @@ def _resolve(cfg: Node) -> Node: except InterpolationToMissingValueError: cfg._set_value(MISSING) else: - cfg._set_value(resolved._value()) + val = resolved._value() + if escape_interpolation_strings: + val = maybe_escape(val) + cfg._set_value(val) if isinstance(cfg, DictConfig): for k in cfg.keys(): - _resolve_container_value(cfg, k) + _resolve_container_value(cfg, k, escape_interpolation_strings) elif isinstance(cfg, ListConfig): for i in range(len(cfg)): - _resolve_container_value(cfg, i) + _resolve_container_value(cfg, i, escape_interpolation_strings) return cfg diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 3452f48ca..388f8cfaa 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -683,6 +683,43 @@ def is_primitive_container(obj: Any) -> bool: return is_primitive_list(obj) or is_primitive_dict(obj) +def maybe_escape(value: Any) -> Any: + """Escape interpolation strings and return other values unchanged. + + When the input value is an interpolation string, the returned value is such that + it yields the original input string when resolved. + """ + if not isinstance(value, str) or not _is_interpolation_string( + value, strict_interpolation_validation=False + ): + return value + start = 0 + tokens = [] + while True: + # Find next ${ that needs escaping. + first_inter = value.find("${", start) + if first_inter < 0: + tokens.append(value[start:]) # ensure we keep the end of the string + break + # Any backslash that comes before ${ will need to be escaped as well. + count_esc = 0 + while ( + first_inter - count_esc - 1 >= 0 + and value[first_inter - count_esc - 1] == "\\" + ): + count_esc += 1 + tokens += [ + # Characters that need not be changed. + value[start : first_inter - count_esc], + # Escaped backslashes before the interpolation. + "\\" * (count_esc * 2), + # Escaped interpolation. + "\\${", + ] + start = first_inter + 2 + return "".join(tokens) + + def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any: args = getattr(ref_type, "__args__", None) if ref_type is not List and args is not None and args[0]: diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 041602879..918a60604 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -785,7 +785,7 @@ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str: ) @staticmethod - def resolve(cfg: Container) -> None: + def resolve(cfg: Container, escape_interpolation_strings: bool = False) -> None: """ Resolves all interpolations in the given config object in-place. @@ -800,7 +800,10 @@ def resolve(cfg: Container) -> None: raise ValueError( f"Invalid config type ({type(cfg).__name__}), expected an OmegaConf Container" ) - omegaconf._impl._resolve(cfg) + omegaconf._impl._resolve(cfg, escape_interpolation_strings=True) + if not escape_interpolation_strings: + # Do a second pass without escaping. + omegaconf._impl._resolve(cfg, escape_interpolation_strings=False) @staticmethod def missing_keys(cfg: Any) -> Set[str]: diff --git a/requirements/dev.txt b/requirements/dev.txt index eedeeb1aa..63421ed03 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,12 +1,12 @@ -r base.txt -r docs.txt attrs -black +black>=23.7.0 build coveralls -flake8>=4 -isort~=5.0 -mypy +flake8>=6.0.0 +isort>=5.12.0 +mypy>=1.4.1 nox pre-commit pyflakes @@ -15,6 +15,6 @@ pytest-benchmark pytest-lazy-fixture pytest-mock towncrier +types-setuptools # makes mypy happy twine pydevd - diff --git a/setup.cfg b/setup.cfg index f1a1fff8b..e124d046e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ test=pytest [mypy] -python_version = 3.7 +python_version = 3.8 mypy_path=.stubs exclude = build/