diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 874b7ff2..69921fa9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -64,3 +64,18 @@ repos: - id: bandit name: bandit args: ["-c", ".bandit"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.17.1 + hooks: + - id: mypy + name: mypy + args: [] + additional_dependencies: + [ + "types-jsonschema", + "types-tqdm", + "types-tabulate", + "scipy-stubs", + "matplotlib", # There are no official stubs for matplotlib + ] diff --git a/codebasin/__init__.py b/codebasin/__init__.py index 9889466b..1efbc4bf 100644 --- a/codebasin/__init__.py +++ b/codebasin/__init__.py @@ -1,8 +1,11 @@ # Copyright (C) 2019-2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import importlib.metadata import os import shlex +import typing import warnings from collections.abc import Iterable from pathlib import Path @@ -152,7 +155,7 @@ def __iter__(self): yield from self.commands @classmethod - def from_json(cls, instance: list): + def from_json(cls, instance: list) -> CompilationDatabase: """ Parameters ---------- @@ -174,7 +177,10 @@ def from_json(cls, instance: list): return cls(commands) @classmethod - def from_file(cls, filename: str | os.PathLike[str]): + def from_file( + cls, + filename: str | os.PathLike[str], + ) -> CompilationDatabase: """ Parameters --------- @@ -194,8 +200,8 @@ def from_file(cls, filename: str | os.PathLike[str]): A CompilationDatbase corresponding to the provided JSON file. """ with open(filename) as f: - db = codebasin.util._load_json(f, schema_name="compiledb") - return CompilationDatabase.from_json(db) + db: object = codebasin.util._load_json(f, schema_name="compiledb") + return CompilationDatabase.from_json(typing.cast(list, db)) class CodeBase: diff --git a/codebasin/__main__.py b/codebasin/__main__.py index 1d7849fe..26271828 100755 --- a/codebasin/__main__.py +++ b/codebasin/__main__.py @@ -44,10 +44,10 @@ def _help_string(*lines: str, is_long=False, is_last=False): # argparse.HelpFormatter indents by 24 characters. # We cannot override this directly, but can delete them with backspaces. - lines = ["\b" * 20 + x for x in lines] + modified_lines = ["\b" * 20 + x for x in lines] # The additional space is required for argparse to respect newlines. - result += "\n".join(lines) + result += "\n".join(modified_lines) if not is_last: result += "\n " diff --git a/codebasin/config.py b/codebasin/config.py index e3fd8f1a..d4cb307f 100644 --- a/codebasin/config.py +++ b/codebasin/config.py @@ -4,6 +4,7 @@ Contains functions to build up a configuration dictionary, defining a specific code base configuration. """ +from __future__ import annotations import argparse import logging @@ -12,16 +13,17 @@ import re import string import tomllib +from collections.abc import Sequence from dataclasses import asdict, dataclass, field from itertools import chain from pathlib import Path -from typing import Self +from typing import Any from codebasin import CompilationDatabase, util log = logging.getLogger(__name__) -_compilers = None +_compilers = {} class _StoreSplitAction(argparse.Action): @@ -45,9 +47,9 @@ def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - values: str, - option_string: str, - ): + values: str | Sequence[Any] | None, + option_string: str | None = None, + ) -> None: if not isinstance(values, str): raise TypeError("store_split expects string values") split_values = values.split(self.sep) @@ -84,9 +86,9 @@ def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - value: str, - option_string: str, - ): + value: str | Sequence[Any] | None, + option_string: str | None = None, + ) -> None: if not isinstance(value, str): raise TypeError("extend_match expects string value") matches = re.findall(self.pattern, value) @@ -118,7 +120,7 @@ class _CompilerMode: include_files: list[str] = field(default_factory=list) @classmethod - def from_toml(cls, toml: object) -> Self: + def from_toml(cls, toml: dict[str, Any]) -> _CompilerMode: return _CompilerMode(**toml) @@ -131,7 +133,7 @@ class _CompilerPass: modes: list[str] = field(default_factory=list) @classmethod - def from_toml(cls, toml: object) -> Self: + def from_toml(cls, toml: dict[str, Any]) -> _CompilerPass: return _CompilerPass(**toml) @@ -144,7 +146,7 @@ class _Compiler: passes: dict[str, _CompilerPass] = field(default_factory=dict) @classmethod - def from_toml(cls, toml: object) -> Self: + def from_toml(cls, toml: dict[str, Any]) -> _Compiler: kwargs = toml.copy() if "parser" in kwargs: for option in kwargs["parser"]: diff --git a/codebasin/coverage/__main__.py b/codebasin/coverage/__main__.py index e13a0046..1b16c09b 100755 --- a/codebasin/coverage/__main__.py +++ b/codebasin/coverage/__main__.py @@ -128,11 +128,13 @@ def _compute(args: argparse.Namespace): with open(filename, "rb") as f: digest = hashlib.file_digest(f, "sha512") - used_lines = [] - unused_lines = [] + used_lines: list[int] = [] + unused_lines: list[int] = [] tree = state.get_tree(filename) association = state.get_map(filename) for node in [n for n in tree.walk() if isinstance(n, CodeNode)]: + if not node.lines: + continue if association[node] == frozenset([]): unused_lines.extend(node.lines) else: diff --git a/codebasin/finder.py b/codebasin/finder.py index 3075035c..70a352f7 100644 --- a/codebasin/finder.py +++ b/codebasin/finder.py @@ -96,7 +96,7 @@ def get_setmap(self, codebase: CodeBase) -> dict[frozenset, int]: dict[frozenset, int] The number of lines associated with each platform set. """ - setmap = collections.defaultdict(int) + setmap: dict[frozenset, int] = collections.defaultdict(int) for fn in codebase: # Don't count symlinks if their target is in the code base. # The target will be counted separately. diff --git a/codebasin/preprocessor.py b/codebasin/preprocessor.py index c456b397..1364ef18 100644 --- a/codebasin/preprocessor.py +++ b/codebasin/preprocessor.py @@ -620,7 +620,7 @@ class CodeNode(Node): end_line: int = field(default=-1, init=False) num_lines: int = field(default=0, init=False) source: str | None = field(default=None, init=False, repr=False) - lines: list[str] | None = field( + lines: list[int] | None = field( default_factory=list, init=False, repr=False, diff --git a/codebasin/report.py b/codebasin/report.py index 1ac2dbcf..a3c77aad 100644 --- a/codebasin/report.py +++ b/codebasin/report.py @@ -15,7 +15,7 @@ import sys import warnings from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Sequence from pathlib import Path from typing import Self, TextIO @@ -129,11 +129,11 @@ def average_coverage( if len(platforms) == 0: return float("nan") - total = sum([coverage(setmap, [p]) for p in platforms]) + total = sum([coverage(setmap, {p}) for p in platforms]) return total / len(platforms) -def distance(setmap, p1, p2): +def distance(setmap, p1, p2) -> float: """ Compute distance between two platforms """ @@ -148,14 +148,14 @@ def distance(setmap, p1, p2): return d -def divergence(setmap): +def divergence(setmap) -> float: """ Compute code divergence as defined by Harrell and Kitson i.e. average of pair-wise distances between platform sets """ platforms = extract_platforms(setmap) - d = 0 + d = 0.0 npairs = 0 for p1, p2 in it.combinations(platforms, 2): d += distance(setmap, p1, p2) @@ -166,14 +166,14 @@ def divergence(setmap): return d / float(npairs) -def summary(setmap: defaultdict[str, int], stream: TextIO = sys.stdout): +def summary(setmap: dict[frozenset[str], int], stream: TextIO = sys.stdout): """ Produce a summary report for the platform set, including a breakdown of SLOC per platform subset, code divergence, etc. Parameters ---------- - setmap: defaultdict[str, int] + setmap: dict[frozenset[str], int] The setmap used to compute the summary report. stream: TextIO, default: sys.stdout @@ -214,7 +214,7 @@ def summary(setmap: defaultdict[str, int], stream: TextIO = sys.stdout): def clustering( output_name: str, - setmap: defaultdict[str, int], + setmap: dict[frozenset[str], int], stream: TextIO = sys.stdout, ): """ @@ -225,7 +225,7 @@ def clustering( output_name: str The filename for the dendrogram. - setmap: defaultdict[str, int] + setmap: dict[frozenset[str], int] The setmap used to compute the clustering statistics. stream: TextIO, default: sys.stdout @@ -313,7 +313,7 @@ def find_duplicates(codebase: CodeBase) -> list[set[Path]]: A list of all sets of Paths with identical contents. """ # Search for possible matches using a hash, ignoring symlinks. - possible_matches = {} + possible_matches: dict[str, set] = {} for path in codebase: path = Path(path) if path.is_symlink(): @@ -486,7 +486,7 @@ def is_symlink(self): def _platforms_str( self, all_platforms: set[str], - labels: Iterable[str] = string.ascii_uppercase, + labels: Sequence[str] = string.ascii_uppercase, ) -> str: """ Parameters @@ -494,7 +494,7 @@ def _platforms_str( all_platforms: set[str] The set of all platforms. - labels: Iterable[str], default: string.ascii_uppercase + labels: Sequence[str], default: string.ascii_uppercase The labels to use in place of real platform names. Returns @@ -605,7 +605,7 @@ def __init__(self, rootdir: str | os.PathLike[str]): def insert( self, filename: str | os.PathLike[str], - setmap: defaultdict[str, int], + setmap: dict[frozenset[str], int], ): """ Insert a new file into the tree, creating as many nodes as necessary. @@ -653,7 +653,7 @@ def _print( prefix: str = "", connector: str = "", fancy: bool = True, - levels: int = None, + levels: int | None = None, ): """ Recursive helper function to print all nodes in a FileTree. @@ -740,7 +740,7 @@ def _print( return lines - def write_to(self, stream: TextIO, levels: int = None): + def write_to(self, stream: TextIO, levels: int | None = None): """ Write the FileTree to the specified stream. @@ -766,7 +766,7 @@ def files( *, stream: TextIO = sys.stdout, prune: bool = False, - levels: int = None, + levels: int | None = None, ): """ Produce a file tree representing the code base. @@ -796,7 +796,7 @@ def files( # Build up a tree from the list of files. tree = FileTree(codebase.directories[0]) for f in codebase: - setmap = defaultdict(int) + setmap: dict[frozenset[str], int] = defaultdict(int) if state: association = state.get_map(f) for node in filter( @@ -828,10 +828,10 @@ def files( ] legend += ["[" + " | ".join(header) + "]"] legend += [""] - legend = "\n".join(legend) + legend_string = "\n".join(legend) if not stream.isatty(): - legend = _strip_colors(legend) - print(legend, file=stream) + legend_string = _strip_colors(legend_string) + print(legend_string, file=stream) # Print the tree. tree.write_to(stream, levels=levels) diff --git a/codebasin/util.py b/codebasin/util.py index e12220fa..2da36cbd 100644 --- a/codebasin/util.py +++ b/codebasin/util.py @@ -21,7 +21,10 @@ log = logging.getLogger(__name__) -def ensure_ext(path: os.PathLike[str], extensions: Iterable[str]) -> None: +def ensure_ext( + path: str | os.PathLike[str], + extensions: Iterable[str], +) -> None: """ Ensure that a path has one of the specified extensions. @@ -54,7 +57,7 @@ def ensure_ext(path: os.PathLike[str], extensions: Iterable[str]) -> None: raise ValueError(f"{path} does not have a valid extension: {exts}") -def safe_open_write_binary(fname: os.PathLike[str]) -> typing.BinaryIO: +def safe_open_write_binary(fname: str | os.PathLike[str]) -> typing.BinaryIO: """Open fname for (binary) writing. Truncate if not a symlink.""" fpid = os.open( fname, @@ -64,7 +67,7 @@ def safe_open_write_binary(fname: os.PathLike[str]) -> typing.BinaryIO: return os.fdopen(fpid, "wb") -def valid_path(path: os.PathLike[str]) -> bool: +def valid_path(path: str | os.PathLike[str]) -> bool: """ Check if a given file path is valid. @@ -74,7 +77,7 @@ def valid_path(path: os.PathLike[str]) -> bool: Parameters ---------- - path : os.PathLike[str] + path : str | os.PathLike[str] The file path to be validated. Returns @@ -95,12 +98,12 @@ def valid_path(path: os.PathLike[str]) -> bool: valid = True # Check for null byte character(s) - if "\x00" in path: + if "\x00" in str(path): log.critical("Null byte character in file request.") valid = False # Check for carriage returns or line feed character(s) - if ("\n" in path) or ("\r" in path): + if ("\n" in str(path)) or ("\r" in str(path)): log.critical("Carriage return or line feed character in file request.") valid = False @@ -223,7 +226,7 @@ def _load_json(file_object: typing.TextIO, schema_name: str) -> object: def _load_toml( - file_object: typing.TextIO, + file_object: typing.IO, schema_name: str, ) -> dict[str, typing.Any]: """ @@ -231,7 +234,7 @@ def _load_toml( Parameters ---------- - file_object : typing.TextIO + file_object : typing.IO The file object to load from. schema_name : {'cbiconfig', 'analysis'}