Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions startle/_inspect/make_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _reserve_short_names(params: Sequence[Param]):
return used_short_names, short_name_assignments


def make_arg_from_param(param: Param, name: Name, kw_only: bool = False) -> Arg:
def _make_arg_from_param(param: Param, name: Name, kw_only: bool = False) -> Arg:
return Arg(
name=name,
type_=param.normalized_hint, # type: ignore
Expand All @@ -95,7 +95,7 @@ def make_arg_from_param(param: Param, name: Name, kw_only: bool = False) -> Arg:
)


def make_args_from_params_flat(
def _make_args_from_params_flat(
params: Sequence[Param], brief: str = "", program_name: str = ""
) -> Args:
args = Args(brief=brief, program_name=program_name)
Expand All @@ -114,7 +114,7 @@ def make_args_from_params_flat(
used_short_names.add(first_char)
short_name_assignments[param.name] = first_char
short = first_char
arg = make_arg_from_param(
arg = _make_arg_from_param(
param=param,
name=Name(long=param.name.replace("_", "-"), short=short),
)
Expand All @@ -130,7 +130,7 @@ def make_args_from_params_flat(
return args


def make_args_from_params_recursive(
def _make_args_from_params_recursive(
params: Sequence[Param],
brief: str = "",
program_name: str = "",
Expand Down Expand Up @@ -180,7 +180,7 @@ def traverse(node: TreeNode[Param], args: Args, parent_name: str = ""):
short = first_char
name = Name(long=param.name.replace("_", "-"), short=short)

arg = make_arg_from_param(
arg = _make_arg_from_param(
param=param,
name=name,
kw_only=kw_only,
Expand Down Expand Up @@ -279,21 +279,21 @@ def make_args_from_func(
]

if not recurse:
return make_args_from_params_flat(
return _make_args_from_params_flat(
params=params,
brief=brief,
program_name=program_name,
)
else:
return make_args_from_params_recursive(
return _make_args_from_params_recursive(
params=params,
brief=brief,
program_name=program_name,
naming=naming,
)


def make_params_from_class(cls: type) -> list[Param]:
def _make_params_from_class(cls: type) -> list[Param]:
params = get_initializer_parameters(cls)
hints = get_type_hints(cls.__init__, include_extras=True)
_, arg_helps = parse_docstring(cls)
Expand All @@ -311,7 +311,7 @@ def make_params_from_class(cls: type) -> list[Param]:
]


def make_params_from_td(cls: type) -> list[Param]:
def _make_params_from_td(cls: type) -> list[Param]:
params = get_type_hints(cls, include_extras=True).items()
optional_keys = cast(frozenset[str], cls.__optional_keys__) # type: ignore
required_keys = cast(frozenset[str], cls.__required_keys__) # type: ignore
Expand Down Expand Up @@ -351,18 +351,18 @@ def make_args_from_class(
# TODO: check if cls is a class?

if is_typeddict(cls):
params = make_params_from_td(cls)
params = _make_params_from_td(cls)
else:
params = make_params_from_class(cls)
params = _make_params_from_class(cls)

if not recurse:
return make_args_from_params_flat(
return _make_args_from_params_flat(
params=params,
brief=brief,
program_name=program_name,
)
else:
return make_args_from_params_recursive(
return _make_args_from_params_recursive(
params=params,
brief=brief,
program_name=program_name,
Expand Down
4 changes: 4 additions & 0 deletions startle/_inspect/tree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Utilities for recursive inspection when parsing nested structures.
"""

from collections.abc import Iterable
from dataclasses import dataclass, is_dataclass
from typing import Generic, TypeVar, cast, get_type_hints
Expand Down
8 changes: 6 additions & 2 deletions startle/_parse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import Literal, TypeVar

from rich.console import Console
from rich.text import Text
Expand All @@ -16,6 +16,8 @@ def parse(
args: list[str] | None = None,
brief: str = "",
catch: bool = True,
recurse: bool = False,
naming: Literal["flat", "nested"] = "flat",
) -> T:
"""
Given a class `cls`, parse arguments from the command-line according to the
Expand All @@ -36,7 +38,9 @@ class definition and construct an instance.
An instance of the class `cls`.
"""
# first, make Args object from the class
args_ = make_args_from_class(cls, brief=brief, program_name=name or "")
args_ = make_args_from_class(
cls, brief=brief, program_name=name or "", recurse=recurse, naming=naming
)

try:
# then, parse the arguments from the CLI
Expand Down
10 changes: 6 additions & 4 deletions tests/test_help/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ def check_help_from_func(
assert remove_trailing_spaces(result) == remove_trailing_spaces(expected)


def check_help_from_class(cls: type, brief: str, program_name: str, expected: str):
def check_help_from_class(
cls: type, brief: str, program_name: str, expected: str, recurse: bool = False
):
console = Console(width=120, highlight=False, force_terminal=True)
with console.capture() as capture:
make_args_from_class(cls, program_name=program_name, brief=brief).print_help(
console
)
make_args_from_class(
cls, program_name=program_name, brief=brief, recurse=recurse
).print_help(console)
result = capture.get()

console = Console(width=120, highlight=False, force_terminal=True)
Expand Down
Loading