Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ Added
- Support ``Deque`` and ``FrozenSet`` in type hints (`#905
<https://github.com/omni-us/jsonargparse/pull/905>`__).

Fixed
^^^^^
- Detect loops in config files that recursively load subconfig files and raise
an error showing the config chain instead of recursing indefinitely (`#910
<https://github.com/omni-us/jsonargparse/pull/910>`__).

Changed
^^^^^^^
- Docs now reference methods via the public ``ArgumentParser`` class instead of
Expand Down
3 changes: 2 additions & 1 deletion jsonargparse/_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import_object,
indent_text,
iter_to_set_str,
load_config_path_context,
parse_value_or_config,
)

Expand Down Expand Up @@ -251,7 +252,7 @@ def _load_config(self, value, parser):
cfg, cfg_path = parse_value_or_config(value)
if not isinstance(cfg, dict):
raise TypeError(f'Parser key "{self.dest}": Unable to load config "{value}"')
with change_to_path_dir(cfg_path):
with load_config_path_context(cfg_path), change_to_path_dir(cfg_path):
cfg = parser._apply_actions(cfg, parent_key=self.dest)
return cfg
except (TypeError,) + get_loader_exceptions() as ex:
Expand Down
15 changes: 10 additions & 5 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
get_argument_group_class,
get_private_kwargs,
identity,
load_config_path_context,
return_parser_if_captured,
)

Expand Down Expand Up @@ -621,7 +622,7 @@ def parse_path(
ArgumentError: If the parsing fails and ``exit_on_error=False``.
"""
fpath = Path(cfg_path, mode=_get_config_read_mode())
with change_to_path_dir(fpath):
with load_config_path_context(fpath), change_to_path_dir(fpath):
cfg_str = fpath.read_text()
parsed_cfg = self.parse_string(
cfg_str=cfg_str,
Expand Down Expand Up @@ -1029,10 +1030,14 @@ def get_defaults(self, skip_validation: bool = False, **kwargs) -> Namespace:

default_config_files = self._get_default_config_files()
for default_config_file in default_config_files:
default_config_file_content = default_config_file.read_text()
if not default_config_file_content.strip():
continue
with change_to_path_dir(default_config_file), parser_context(parent_parser=self, parsing_defaults=True):
with (
load_config_path_context(default_config_file),
change_to_path_dir(default_config_file),
parser_context(parent_parser=self, parsing_defaults=True),
):
default_config_file_content = default_config_file.read_text()
if not default_config_file_content.strip():
continue
cfg_file = self._load_config_parser_mode(default_config_file_content, prev_cfg=cfg)
cfg = self.merge_config(cfg_file, cfg)
try:
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import_object,
indent_text,
iter_to_set_str,
load_config_path_context,
object_path_serializer,
parse_value_or_config,
warning,
Expand Down Expand Up @@ -635,14 +636,14 @@ def _check_type(self, value, append=False, cfg=None, mode=None):
"logger": self.logger,
}
try:
with change_to_path_dir(config_path):
with load_config_path_context(config_path), change_to_path_dir(config_path):
val = adapt_typehints(val, self._typehint, **kwargs)
except ValueError as ex:
if orig_val == "-" and isinstance(getattr(ex, "parent", None), PathError):
raise ex
try:
if isinstance(orig_val, str):
with change_to_path_dir(config_path):
with load_config_path_context(config_path), change_to_path_dir(config_path):
val = adapt_typehints(orig_val, self._typehint, default=self.default, **kwargs)
ex = None
except ValueError:
Expand Down
39 changes: 38 additions & 1 deletion jsonargparse/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import warnings
from argparse import ArgumentError
from collections import namedtuple
from contextlib import contextmanager
from contextvars import ContextVar
from importlib import import_module
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import (
Any,
Callable,
Iterator,
Optional,
Type,
Union,
Expand All @@ -36,6 +39,7 @@


default_config_option_help = "Path to a configuration file."
config_load_stack: ContextVar[tuple[tuple[str, str], ...]] = ContextVar("config_load_stack", default=())


def argument_error(message: str, default_config_file: Optional[str] = None) -> ArgumentError:
Expand All @@ -45,6 +49,39 @@ def argument_error(message: str, default_config_file: Optional[str] = None) -> A
return ex


def _config_path_id(cfg_path: Path) -> tuple[str, str]:
path_id = cfg_path.absolute
if not (cfg_path.is_url or cfg_path.is_fsspec):
path_id = os.path.realpath(path_id)
return path_id, str(cfg_path)


def _format_config_load_chain(stack: tuple[tuple[str, str], ...], path_id: tuple[str, str]) -> str:
chain = list(stack) + [path_id]
for num, (stack_path, _) in enumerate(chain):
if stack_path == path_id[0]:
chain = chain[num:]
break
return " -> ".join(display for _, display in chain)


@contextmanager
def load_config_path_context(cfg_path: Optional[Path]) -> Iterator[None]:
if cfg_path is None:
yield
return
path_id = _config_path_id(cfg_path)
stack = config_load_stack.get()
if path_id[0] in {path for path, _ in stack}:
chain = _format_config_load_chain(stack, path_id)
raise TypeError(f"Config file loop detected: {chain}")
token = config_load_stack.set(stack + (path_id,))
try:
yield
finally:
config_load_stack.reset(token)


class JsonargparseWarning(UserWarning):
pass

Expand Down Expand Up @@ -115,7 +152,7 @@ def parse_value_or_config(
except TypeError:
pass
else:
with cfg_path.relative_path_context():
with load_config_path_context(cfg_path), cfg_path.relative_path_context():
value = load_value(cfg_path.read_text(), simple_types=simple_types)
if type(value) is str and value.strip() != "":
parsed_val = load_value(value, simple_types=simple_types)
Expand Down
20 changes: 20 additions & 0 deletions jsonargparse_tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,26 @@ def test_action_parser_parse_path(composed_parsers):
assert expected == cfg2.as_dict()


def test_action_parser_config_loop(composed_parsers):
parser, yaml_main, yaml_inner2 = composed_parsers[:3]
yaml_main.write_text(json_or_yaml_dump({"inner2": "main.yaml"}))
with pytest.raises(ArgumentError) as ctx:
parser.parse_path(yaml_main)
ctx.match(r"Config file loop detected: .+main\.yaml -> main\.yaml")

yaml_main.write_text(json_or_yaml_dump({"inner2": "inner2.yaml"}))
yaml_inner2.write_text(json_or_yaml_dump({"inner3": "main.yaml"}))
with pytest.raises(ArgumentError) as ctx:
parser.parse_path(yaml_main)
ctx.match(r"Config file loop detected: .+main\.yaml -> inner2\.yaml -> main\.yaml")

yaml_main.write_text(json_or_yaml_dump({"inner2": "main.yaml"}))
parser.default_config_files = [str(yaml_main)]
with pytest.raises(ArgumentError) as ctx:
parser.parse_args([])
ctx.match(r"Config file loop detected: .+main\.yaml -> main\.yaml")


def test_action_parser_parse_env_inner(composed_parsers):
parser, _, yaml_inner2, yaml_inner3 = composed_parsers
assert "opt2_env" == parser.parse_env({"LV1_INNER2__OPT2": "opt2_env"}).inner2.opt2
Expand Down
Loading