diff --git a/codebasin/util.py b/codebasin/util.py index 122cbbc..e12220f 100644 --- a/codebasin/util.py +++ b/codebasin/util.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019-2024 Intel Corporation +# Copyright (C) 2019-2025 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause """ Contains utility functions for common operations, including: @@ -21,7 +21,7 @@ log = logging.getLogger(__name__) -def ensure_ext(path: os.PathLike[str], extensions: Iterable[str]): +def ensure_ext(path: os.PathLike[str], extensions: Iterable[str]) -> None: """ Ensure that a path has one of the specified extensions. @@ -33,11 +33,6 @@ def ensure_ext(path: os.PathLike[str], extensions: Iterable[str]): extensions: Iterable[str] The valid extensions to test against. - Returns - ------- - bool - True if `path` is a file with one of the specified extensions. - Raises ------ TypeError @@ -56,10 +51,10 @@ def ensure_ext(path: os.PathLike[str], extensions: Iterable[str]): extension = "".join(path.suffixes) if extension not in extensions: exts = ", ".join([f"'{ext}'" for ext in extensions]) - raise ValueError(f"{path} does not have a valid extension: f{exts}") + raise ValueError(f"{path} does not have a valid extension: {exts}") -def safe_open_write_binary(fname): +def safe_open_write_binary(fname: os.PathLike[str]) -> typing.BinaryIO: """Open fname for (binary) writing. Truncate if not a symlink.""" fpid = os.open( fname, @@ -69,8 +64,34 @@ def safe_open_write_binary(fname): return os.fdopen(fpid, "wb") -def valid_path(path): - """Return true if the path passed in is valid""" +def valid_path(path: os.PathLike[str]) -> bool: + """ + Check if a given file path is valid. + + This function ensures that the file path does not contain + potentially dangerous characters such as null bytes (`\x00`) + or carriage returns/line feeds (`\n`, `\r`). + + Parameters + ---------- + path : os.PathLike[str] + The file path to be validated. + + Returns + ------- + bool + A boolean value indicating whether the path is valid + (`True`) or invalid (`False`). + + Examples + -------- + >>> valid_path("/home/user/file.txt") + True + >>> valid_path("/home/user/\x00file.txt") + False + >>> valid_path("/home/user/file\n.txt") + False + """ valid = True # Check for null byte character(s) @@ -201,7 +222,10 @@ def _load_json(file_object: typing.TextIO, schema_name: str) -> object: return json_object -def _load_toml(file_object: typing.TextIO, schema_name: str) -> object: +def _load_toml( + file_object: typing.TextIO, + schema_name: str, +) -> dict[str, typing.Any]: """ Load TOML from file and validate it against a schema. @@ -215,8 +239,9 @@ def _load_toml(file_object: typing.TextIO, schema_name: str) -> object: Returns ------- - Object - The loaded TOML. + dict[str, Any] + The loaded TOML object, represented as a Python + dict with str key/value mappings. Raises ------