diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..bfd02822 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,40 @@ +name: Tests + +on: + pull_request: + types: [opened, synchronize, reopened, edited] + branches: [main] + push: + branches: [main] + +concurrency: + group: "${{ github.event.pull_request.number }}-${{ github.ref_name }}-${{ github.workflow }}" + cancel-in-progress: true + +jobs: + running-tests: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.11"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - name: Install dependencies + run: | + python -m pip install --upgrade pip ruff + python -m pip install .[all] + - name: Run linter + run: | + ruff check + + - name: Run tests + run: | + python -m pip install pytest pytest-cov + pytest + diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..f2cc5c8b --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,28 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + branches: [main] + types: [opened, synchronize, reopened, edited] + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Lint with Ruff + run: | + ruff check diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 874d63a2..9b11a532 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,5 +1,4 @@ -name: Publish to PyPI (Trusted Publishing) - +name: Publish to PyPI on: release: types: [published] diff --git 
a/.github/workflows/test-publish.yml b/.github/workflows/test-publish.yml index 6fb6ec94..a0ae9c57 100644 --- a/.github/workflows/test-publish.yml +++ b/.github/workflows/test-publish.yml @@ -1,4 +1,4 @@ -name: Publish to PyPI (Trusted Publishing) +name: Publish to Test-PyPI on: release: diff --git a/.gitignore b/.gitignore index ba03882b..c816fe54 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,12 @@ +# IDE/Editor settings +.vscode/* + +# Allow VSCode recommendations +!.vscode/extensions.json + +# Ruff +.ruff_cache + # Models checkpoints @@ -29,9 +38,6 @@ example # Logs error_log.txt -# IDE/Editor settings -.vscode/ - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -162,11 +168,13 @@ dmypy.json # Pyre type checker .pyre/ - +# GUI generated train_data/ train_cache.npz autotune/ - gui-settings.json state.json +BirdNET_analysis_params.csv + +# Build files entitlements.plist \ No newline at end of file diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..1e21add5 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,6 @@ +// See https://go.microsoft.com/fwlink/?LinkId=827846 +{ + "recommendations": [ + "charliermarsh.ruff" + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 94caf39d..0b1e7d31 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,9 @@ [![Docker](https://github.com/birdnet-team/BirdNET-Analyzer/actions/workflows/docker-build.yml/badge.svg)](https://github.com/birdnet-team/BirdNET-Analyzer/actions/workflows/docker-build.yml) [![Reddit](https://img.shields.io/badge/Reddit-FF4500?style=flat&logo=reddit&logoColor=white)](https://www.reddit.com/r/BirdNET_Analyzer/) ![GitHub stars)](https://img.shields.io/github/stars/birdnet-team/BirdNET-Analyzer) + [![GitHub release](https://img.shields.io/github/v/release/birdnet-team/BirdNET-Analyzer)](https://github.com/birdnet-team/BirdNET-Analyzer/releases/latest) +[![PyPI - 
Version](https://img.shields.io/pypi/v/birdnet_analyzer?logo=pypi)](https://pypi.org/project/birdnet-analyzer/) diff --git a/birdnet_analyzer/__init__.py b/birdnet_analyzer/__init__.py index 2575f448..2e23b5f7 100644 --- a/birdnet_analyzer/__init__.py +++ b/birdnet_analyzer/__init__.py @@ -1,8 +1,9 @@ from birdnet_analyzer.analyze import analyze from birdnet_analyzer.embeddings import embeddings -from birdnet_analyzer.train import train from birdnet_analyzer.search import search from birdnet_analyzer.segments import segments from birdnet_analyzer.species import species +from birdnet_analyzer.train import train -__all__ = ["analyze", "train", "embeddings", "search", "segments", "species"] +__version__ = "2.0.0" +__all__ = ["analyze", "embeddings", "search", "segments", "species", "train"] diff --git a/birdnet_analyzer/analyze/__main__.py b/birdnet_analyzer/analyze/__main__.py index 062bfaa9..11b1f956 100644 --- a/birdnet_analyzer/analyze/__main__.py +++ b/birdnet_analyzer/analyze/__main__.py @@ -1,4 +1,3 @@ from birdnet_analyzer.analyze.cli import main - main() diff --git a/birdnet_analyzer/analyze/cli.py b/birdnet_analyzer/analyze/cli.py index 526e568b..7b6e2ed5 100644 --- a/birdnet_analyzer/analyze/cli.py +++ b/birdnet_analyzer/analyze/cli.py @@ -1,5 +1,5 @@ -from birdnet_analyzer.utils import runtime_error_handler from birdnet_analyzer import analyze +from birdnet_analyzer.utils import runtime_error_handler @runtime_error_handler @@ -7,7 +7,7 @@ def main(): import os from multiprocessing import freeze_support - import birdnet_analyzer.cli as cli + from birdnet_analyzer import cli # Freeze support for executable freeze_support() diff --git a/birdnet_analyzer/analyze/core.py b/birdnet_analyzer/analyze/core.py index 6090908b..555a6bb3 100644 --- a/birdnet_analyzer/analyze/core.py +++ b/birdnet_analyzer/analyze/core.py @@ -1,9 +1,9 @@ import os -from typing import List, Literal +from typing import Literal def analyze( - input: str, + audio_input: str, output: str | 
None = None, *, min_conf: float = 0.25, @@ -19,8 +19,7 @@ def analyze( audio_speed: float = 1.0, batch_size: int = 1, combine_results: bool = False, - rtype: Literal["table", "audacity", "kaleidoscope", "csv"] - | List[Literal["table", "audacity", "kaleidoscope", "csv"]] = "table", + rtype: Literal["table", "audacity", "kaleidoscope", "csv"] | list[Literal["table", "audacity", "kaleidoscope", "csv"]] = "table", skip_existing_results: bool = False, sf_thresh: float = 0.03, top_n: int | None = None, @@ -31,7 +30,7 @@ def analyze( """ Analyzes audio files for bird species detection using the BirdNET-Analyzer. Args: - input (str): Path to the input directory or file containing audio data. + audio_input (str): Path to the input directory or file containing audio data. output (str | None, optional): Path to the output directory for results. Defaults to None. min_conf (float, optional): Minimum confidence threshold for detections. Defaults to 0.25. classifier (str | None, optional): Path to a custom classifier file. Defaults to None. 
@@ -73,7 +72,7 @@ def analyze( ensure_model_exists() flist = _set_params( - input=input, + audio_input=audio_input, output=output, min_conf=min_conf, custom_classifier=classifier, @@ -109,8 +108,7 @@ def analyze( # Analyze files if cfg.CPU_THREADS < 2 or len(flist) < 2: - for entry in flist: - result_files.append(analyze_file(entry)) + result_files.extend(analyze_file(f) for f in flist) else: with Pool(cfg.CPU_THREADS) as p: # Map analyzeFile function to each entry in flist @@ -129,7 +127,7 @@ def analyze( def _set_params( - input, + audio_input, output, min_conf, custom_classifier, @@ -154,7 +152,7 @@ def _set_params( labels_file=None, ): import birdnet_analyzer.config as cfg - from birdnet_analyzer.analyze.utils import load_codes # noqa: E402 + from birdnet_analyzer.analyze.utils import load_codes from birdnet_analyzer.species.utils import get_species_list from birdnet_analyzer.utils import collect_audio_files, read_lines @@ -164,7 +162,7 @@ def _set_params( cfg.LOCATION_FILTER_THRESHOLD = sf_thresh cfg.TOP_N = top_n cfg.MERGE_CONSECUTIVE = merge_consecutive - cfg.INPUT_PATH = input + cfg.INPUT_PATH = audio_input cfg.MIN_CONFIDENCE = min_conf cfg.SIGMOID_SENSITIVITY = sensitivity cfg.SIG_OVERLAP = overlap @@ -233,9 +231,7 @@ def _set_params( cfg.SPECIES_LIST_FILE = None cfg.SPECIES_LIST = get_species_list(cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD) - lfile = os.path.join( - cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", "_{}.txt".format(locale)) - ) + lfile = os.path.join(cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", f"_{locale}.txt")) if locale not in ["en"] and os.path.isfile(lfile): cfg.TRANSLATED_LABELS = read_lines(lfile) diff --git a/birdnet_analyzer/analyze/utils.py b/birdnet_analyzer/analyze/utils.py index 99fa8650..56546701 100644 --- a/birdnet_analyzer/analyze/utils.py +++ b/birdnet_analyzer/analyze/utils.py @@ -7,13 +7,10 @@ import numpy as np -import 
birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg -import birdnet_analyzer.model as model -import birdnet_analyzer.utils as utils +from birdnet_analyzer import audio, model, utils -# 0 1 2 3 4 5 6 7 8 9 10 11 -RAVEN_TABLE_HEADER = "Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tCommon Name\tSpecies Code\tConfidence\tBegin Path\tFile Offset (s)\n" +RAVEN_TABLE_HEADER = "Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tCommon Name\tSpecies Code\tConfidence\tBegin Path\tFile Offset (s)\n" # noqa: E501 RTABLE_HEADER = "filepath,start,end,scientific_name,common_name,confidence,lat,lon,week,overlap,sensitivity,min_conf,species_list,model\n" KALEIDOSCOPE_HEADER = ( "INDIR,FOLDER,IN FILE,OFFSET,DURATION,scientific_name,common_name,confidence,lat,lon,week,overlap,sensitivity\n" @@ -58,10 +55,8 @@ def load_codes(): Returns: A dictionary containing the eBird codes. """ - with open(os.path.join(SCRIPT_DIR, cfg.CODES_FILE), "r") as cfile: - codes = json.load(cfile) - - return codes + with open(os.path.join(SCRIPT_DIR, cfg.CODES_FILE)) as cfile: + return json.load(cfile) def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str): @@ -83,8 +78,7 @@ def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_p # Read native sample rate high_freq = audio.get_sample_rate(afile_path) / 2 - if high_freq > int(cfg.SIG_FMAX / cfg.AUDIO_SPEED): - high_freq = int(cfg.SIG_FMAX / cfg.AUDIO_SPEED) + high_freq = min(high_freq, int(cfg.SIG_FMAX / cfg.AUDIO_SPEED)) high_freq = min(high_freq, int(cfg.BANDPASS_FMAX / cfg.AUDIO_SPEED)) low_freq = max(cfg.SIG_FMIN, int(cfg.BANDPASS_FMIN / cfg.AUDIO_SPEED)) @@ -98,13 +92,15 @@ def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_p selection_id += 1 label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] code = cfg.CODES[c[0]] if c[0] in cfg.CODES else c[0] - 
rstring += f"{selection_id}\tSpectrogram 1\t1\t{start}\t{end}\t{low_freq}\t{high_freq}\t{label.split('_', 1)[-1]}\t{code}\t{c[1]:.4f}\t{afile_path}\t{start}\n" + rstring += f"{selection_id}\tSpectrogram 1\t1\t{start}\t{end}\t{low_freq}\t{high_freq}\t{label.split('_', 1)[-1]}\t{code}\t{c[1]:.4f}\t{afile_path}\t{start}\n" # noqa: E501 # Write result string to file out_string += rstring - # If we don't have any valid predictions, we still need to add a line to the selection table in case we want to combine results - # TODO: That's a weird way to do it, but it works for now. It would be better to keep track of file durations during the analysis. + # If we don't have any valid predictions, we still need to add a line to the selection table + # in case we want to combine results + # TODO: That's a weird way to do it, but it works for now. It would be better to keep track + # of file durations during the analysis. if len(out_string) == len(RAVEN_TABLE_HEADER) and cfg.OUTPUT_PATH is not None: selection_id += 1 out_string += ( @@ -280,7 +276,7 @@ def combine_raven_tables(saved_results: list[str]): for rfile in saved_results: if not rfile: continue - with open(rfile, "r", encoding="utf-8") as rf: + with open(rfile, encoding="utf-8") as rf: try: lines = rf.readlines() @@ -305,16 +301,16 @@ def combine_raven_tables(saved_results: list[str]): continue # adjust selection id - line = line.split("\t") - line[0] = str(s_id) + line_elements = line.split("\t") + line_elements[0] = str(s_id) s_id += 1 # adjust time - line[3] = str(float(line[3]) + time_offset) - line[4] = str(float(line[4]) + time_offset) + line_elements[3] = str(float(line_elements[3]) + time_offset) + line_elements[4] = str(float(line_elements[4]) + time_offset) # write line - f.write("\t".join(line)) + f.write("\t".join(line_elements)) # adjust time offset time_offset += f_duration @@ -326,7 +322,7 @@ def combine_raven_tables(saved_results: list[str]): listfilesname = cfg.OUTPUT_RAVEN_FILENAME.rsplit(".", 1)[0] + 
".list.txt" with open(os.path.join(cfg.OUTPUT_PATH, listfilesname), "w", encoding="utf-8") as f: - f.writelines((f + "\n" for f in audiofiles)) + f.writelines(f + "\n" for f in audiofiles) def combine_kaleidoscope_files(saved_results: list[str]): @@ -344,7 +340,7 @@ def combine_kaleidoscope_files(saved_results: list[str]): f.write(KALEIDOSCOPE_HEADER) for rfile in saved_results: - with open(rfile, "r", encoding="utf-8") as rf: + with open(rfile, encoding="utf-8") as rf: try: lines = rf.readlines() @@ -373,7 +369,7 @@ def combine_csv_files(saved_results: list[str]): f.write(CSV_HEADER) for rfile in saved_results: - with open(rfile, "r", encoding="utf-8") as rf: + with open(rfile, encoding="utf-8") as rf: try: lines = rf.readlines() @@ -417,14 +413,15 @@ def combine_results(saved_results: list[dict[str, str]]): combine_csv_files([f["csv"] for f in saved_results if f]) -def merge_consecutive_detections(results: dict[str, list], max_consecutive: int = None): +def merge_consecutive_detections(results: dict[str, list], max_consecutive: int | None = None): """Merges consecutive detections of the same species. Uses the mean of the top-3 highest scoring predictions as confidence score for the merged detection. Args: results: The dictionary with {segment: scores}. - max_consecutive: The maximum number of consecutive detections to merge. If None, merge all consecutive detections. + max_consecutive: The maximum number of consecutive detections to merge. + If None, merge all consecutive detections. Returns: The dictionary with merged detections. 
@@ -517,9 +514,7 @@ def get_raw_audio_from_file(fpath: str, offset, duration): ) # Split into raw audio chunks - chunks = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN) - - return chunks + return audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN) def predict(samples): @@ -557,10 +552,7 @@ def get_result_file_names(fpath: str): rpath = fpath.replace(cfg.INPUT_PATH, "") - if rpath: - rpath = rpath[1:] if rpath[0] in ["/", "\\"] else rpath - else: - rpath = os.path.basename(fpath) + rpath = (rpath[1:] if rpath[0] in ["/", "\\"] else rpath) if rpath else os.path.basename(fpath) file_shorthand = rpath.rsplit(".", 1)[0] @@ -599,10 +591,9 @@ def analyze_file(item): result_file_names = get_result_file_names(fpath) - if cfg.SKIP_EXISTING_RESULTS: - if all(os.path.exists(f) for f in result_file_names.values()): - print(f"Skipping {fpath} as it has already been analyzed", flush=True) - return None # or return path to combine later? TODO + if cfg.SKIP_EXISTING_RESULTS and all(os.path.exists(f) for f in result_file_names.values()): + print(f"Skipping {fpath} as it has already been analyzed", flush=True) + return None # or return path to combine later? 
TODO # Start time start_time = datetime.datetime.now() @@ -668,7 +659,7 @@ def analyze_file(item): if cfg.TOP_N: p_sorted = p_sorted[: cfg.TOP_N] - # TODO hier schon top n oder min conf raussortieren + # TODO: hier schon top n oder min conf raussortieren # Store top 5 results and advance indices results[str(s_start) + "-" + str(s_end)] = p_sorted diff --git a/birdnet_analyzer/audio.py b/birdnet_analyzer/audio.py index 905f84f6..e7b356af 100644 --- a/birdnet_analyzer/audio.py +++ b/birdnet_analyzer/audio.py @@ -3,7 +3,7 @@ import librosa import numpy as np import soundfile as sf -from scipy.signal import firwin, kaiserord, lfilter, find_peaks +from scipy.signal import find_peaks, firwin, kaiserord, lfilter import birdnet_analyzer.config as cfg @@ -184,8 +184,7 @@ def split_signal(sig, rate, seconds, overlap, minlen, amount=None): # Split signal with overlap sig_splits = [] - for i in range(0, 1 + lastchunkpos, stepsize): - sig_splits.append(data[i : i + chunksize]) + sig_splits.extend(data[i : i + chunksize] for i in range(0, lastchunkpos, stepsize)) return sig_splits @@ -212,38 +211,35 @@ def crop_center(sig, rate, seconds): return sig + def smart_crop_signal(sig, rate, sig_length, sig_overlap, sig_minlen): """Smart crop audio signal based on peak detection. - + This function analyzes the audio signal to find peaks in energy/amplitude, which are more likely to contain relevant target signals (e.g., bird calls). Only the audio segments with the highest energy peaks are returned. - + Args: sig: The audio signal. rate: The sample rate of the audio signal. sig_length: The desired length of each snippet in seconds. sig_overlap: The overlap between snippets in seconds. sig_minlen: The minimum length of a snippet in seconds. - + Returns: A list of audio snippets with the highest energy/peaks. 
""" - + # If signal is too short, just return it if len(sig) / rate <= sig_length: return [sig] - - # Calculate the window size in samples - window_size = int(sig_length * rate) - hop_size = int((sig_length - sig_overlap) * rate) - + # Split the signal into overlapping windows splits = split_signal(sig, rate, sig_length, sig_overlap, sig_minlen) - + if len(splits) <= 1: return splits - + # Calculate energy for each window energies = [] for split in splits: @@ -253,28 +249,28 @@ def smart_crop_signal(sig, rate, sig_length, sig_overlap, sig_minlen): peak = np.max(np.abs(split)) # Combine both metrics energies.append(energy * 0.7 + peak * 0.3) # Weighted combination - + # Find peaks in the energy curve # Smooth energies first to avoid small fluctuations - smoothed_energies = np.convolve(energies, np.ones(3)/3, mode='same') + smoothed_energies = np.convolve(energies, np.ones(3) / 3, mode="same") peaks, _ = find_peaks(smoothed_energies, height=np.mean(smoothed_energies), distance=2) - + # If no clear peaks found, fall back to selecting top energy segments if len(peaks) < 2: # Sort segments by energy and take top segments (up to 3 or 1/3 of total, whichever is more) num_segments = max(3, len(splits) // 3) indices = np.argsort(energies)[-num_segments:] return [splits[i] for i in sorted(indices)] - + # Return the audio segments corresponding to the peaks peak_splits = [splits[i] for i in peaks] - + # If we have too many peaks, select the strongest ones if len(peak_splits) > 5: peak_energies = [energies[i] for i in peaks] sorted_indices = np.argsort(peak_energies)[::-1] # Sort in descending order peak_splits = [peak_splits[i] for i in sorted_indices[:5]] # Take top 5 - + return peak_splits @@ -293,7 +289,7 @@ def bandpass(sig, rate, fmin, fmax, order=5): numpy.ndarray: The filtered signal as a float32 array. 
""" # Check if we have to bandpass at all - if fmin == cfg.SIG_FMIN and fmax == cfg.SIG_FMAX or fmin > fmax: + if (fmin == cfg.SIG_FMIN and fmax == cfg.SIG_FMAX) or fmin > fmax: return sig from scipy.signal import butter, lfilter @@ -342,7 +338,7 @@ def bandpass_kaiser_fir(sig, rate, fmin, fmax, width=0.02, stopband_attenuation_ numpy.ndarray: The filtered signal as a float32 numpy array. """ # Check if we have to bandpass at all - if fmin == cfg.SIG_FMIN and fmax == cfg.SIG_FMAX or fmin > fmax: + if (fmin == cfg.SIG_FMIN and fmax == cfg.SIG_FMAX) or fmin > fmax: return sig nyquist = 0.5 * rate diff --git a/birdnet_analyzer/cli.py b/birdnet_analyzer/cli.py index b043b3ac..e801ec7f 100644 --- a/birdnet_analyzer/cli.py +++ b/birdnet_analyzer/cli.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 import argparse import os @@ -30,7 +31,7 @@ **===== ***+== ****+ -""" +""" # noqa: W291 def io_args(): @@ -46,6 +47,7 @@ def io_args(): p.add_argument( "input", metavar="INPUT", + dest="audio_input", help="Path to input file or folder.", ) p.add_argument("-o", "--output", help="Path to output folder. Defaults to the input path.") @@ -399,6 +401,7 @@ def embeddings_parser(): parser.add_argument( "-i", "--input", + dest="audio_input", help="Path to input file or folder.", ) @@ -494,7 +497,7 @@ def segments_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=[audio_speed_args(), threads_args(), min_conf_args()], ) - parser.add_argument("input", metavar="INPUT", help="Path to folder containing audio files.") + parser.add_argument("input", dest="audio_input", metavar="INPUT", help="Path to folder containing audio files.") parser.add_argument("-r", "--results", help="Path to folder containing result files. Defaults to the `input` path.") parser.add_argument( "-o", "--output", help="Output folder path for extracted segments. Defaults to the `input` path." 
@@ -599,11 +602,12 @@ def train_parser(): parser.add_argument( "input", metavar="INPUT", + dest="audio_input", help="Path to training data folder. Subfolder names are used as labels.", - ) + ) parser.add_argument( - "--test_data", - help="Path to test data folder. If not specified, a random validation split will be used.") + "--test_data", help="Path to test data folder. If not specified, a random validation split will be used." + ) parser.add_argument( "--crop_mode", default=cfg.SAMPLE_CROP_MODE, @@ -628,7 +632,7 @@ def train_parser(): type=float, default=cfg.TRAIN_LEARNING_RATE, help="Learning rate.", - ) + ) parser.add_argument( "--focal-loss", action="store_true", @@ -645,7 +649,7 @@ def train_parser(): default=cfg.FOCAL_LOSS_ALPHA, type=float, help="Focal loss alpha parameter (balancing parameter). Controls weight between positive and negative examples.", - ) + ) parser.add_argument( "--hidden_units", type=int, diff --git a/birdnet_analyzer/embeddings/__init__.py b/birdnet_analyzer/embeddings/__init__.py index 529958bc..c7caad13 100644 --- a/birdnet_analyzer/embeddings/__init__.py +++ b/birdnet_analyzer/embeddings/__init__.py @@ -1,4 +1,3 @@ from birdnet_analyzer.embeddings.core import embeddings - __all__ = ["embeddings"] diff --git a/birdnet_analyzer/embeddings/__main__.py b/birdnet_analyzer/embeddings/__main__.py index 4ae12f15..b8631552 100644 --- a/birdnet_analyzer/embeddings/__main__.py +++ b/birdnet_analyzer/embeddings/__main__.py @@ -1,3 +1,3 @@ from birdnet_analyzer.embeddings.cli import main -main() \ No newline at end of file +main() diff --git a/birdnet_analyzer/embeddings/cli.py b/birdnet_analyzer/embeddings/cli.py index dbe8cf8d..636bbf8f 100644 --- a/birdnet_analyzer/embeddings/cli.py +++ b/birdnet_analyzer/embeddings/cli.py @@ -1,11 +1,10 @@ -from birdnet_analyzer.utils import runtime_error_handler - from birdnet_analyzer import embeddings +from birdnet_analyzer.utils import runtime_error_handler @runtime_error_handler def main(): - import 
birdnet_analyzer.cli as cli + from birdnet_analyzer import cli parser = cli.embeddings_parser() args = parser.parse_args() diff --git a/birdnet_analyzer/embeddings/core.py b/birdnet_analyzer/embeddings/core.py index e50d0482..76f1564f 100644 --- a/birdnet_analyzer/embeddings/core.py +++ b/birdnet_analyzer/embeddings/core.py @@ -1,5 +1,5 @@ def embeddings( - input: str, + audio_input: str, database: str, *, overlap: float = 0.0, @@ -15,7 +15,7 @@ def embeddings( representations of audio features. The embeddings can be used for further analysis or comparison. Args: - input (str): Path to the input audio file or directory containing audio files. + audio_input (str): Path to the input audio file or directory containing audio files. database (str): Path to the database where embeddings will be stored. overlap (float, optional): Overlap between consecutive audio segments in seconds. Defaults to 0.0. audio_speed (float, optional): Speed factor for audio processing. Defaults to 1.0. @@ -32,8 +32,8 @@ def embeddings( verify this. 
Example: embeddings( - input="path/to/audio", - database="path/to/database", + "path/to/audio", + "path/to/database", overlap=0.5, audio_speed=1.0, fmin=500, @@ -46,7 +46,7 @@ def embeddings( from birdnet_analyzer.utils import ensure_model_exists ensure_model_exists() - run(input, database, overlap, audio_speed, fmin, fmax, threads, batch_size) + run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batch_size) def get_database(db_path: str): @@ -62,9 +62,8 @@ def get_database(db_path: str): if not os.path.exists(db_path): os.makedirs(os.path.dirname(db_path), exist_ok=True) - db = sqlite_usearch_impl.SQLiteUsearchDB.create( + return sqlite_usearch_impl.SQLiteUsearchDB.create( db_path=db_path, - usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024), # TODO dont hardcode this + usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024), # TODO: dont hardcode this ) - return db return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path) diff --git a/birdnet_analyzer/embeddings/utils.py b/birdnet_analyzer/embeddings/utils.py index ecf5d31c..f8df6837 100644 --- a/birdnet_analyzer/embeddings/utils.py +++ b/birdnet_analyzer/embeddings/utils.py @@ -2,25 +2,20 @@ import datetime import os +from functools import partial +from multiprocessing import Pool import numpy as np +from ml_collections import ConfigDict +from perch_hoplite.db import interface as hoplite +from perch_hoplite.db import sqlite_usearch_impl +from tqdm import tqdm -import birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg -import birdnet_analyzer.model as model -import birdnet_analyzer.utils as utils +from birdnet_analyzer import audio, model, utils from birdnet_analyzer.analyze.utils import get_raw_audio_from_file from birdnet_analyzer.embeddings.core import get_database - -from perch_hoplite.db import sqlite_usearch_impl -from perch_hoplite.db import interface as hoplite -from ml_collections import ConfigDict -from 
functools import partial -from tqdm import tqdm -from multiprocessing import Pool - - DATASET_NAME: str = "birdnet_analyzer_dataset" @@ -44,7 +39,7 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB): print(f"Error: Cannot analyze audio file {fpath}. File corrupt?\n", flush=True) utils.write_error_log(ex) - return None + return # Start time start_time = datetime.datetime.now() @@ -85,9 +80,7 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB): s_start, s_end = timestamps[i] # Check if embedding already exists - existing_embedding = db.get_embeddings_by_source( - DATASET_NAME, source_id, np.array([s_start, s_end]) - ) + existing_embedding = db.get_embeddings_by_source(DATASET_NAME, source_id, np.array([s_start, s_end])) if existing_embedding.size == 0: # Get prediction @@ -114,35 +107,28 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB): return delta_time = (datetime.datetime.now() - start_time).total_seconds() - print("Finished {} in {:.2f} seconds".format(fpath, delta_time), flush=True) + print(f"Finished {fpath} in {delta_time:.2f} seconds", flush=True) def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB): try: settings = db.get_metadata("birdnet_analyzer_settings") - if ( - settings["BANDPASS_FMIN"] != cfg.BANDPASS_FMIN - or settings["BANDPASS_FMAX"] != cfg.BANDPASS_FMAX - or settings["AUDIO_SPEED"] != cfg.AUDIO_SPEED - ): + if settings["BANDPASS_FMIN"] != cfg.BANDPASS_FMIN or settings["BANDPASS_FMAX"] != cfg.BANDPASS_FMAX or settings["AUDIO_SPEED"] != cfg.AUDIO_SPEED: raise ValueError( - "Database settings do not match current configuration. DB Settings are: fmin: {}, fmax: {}, audio_speed: {}".format( - settings["BANDPASS_FMIN"], settings["BANDPASS_FMAX"], settings["AUDIO_SPEED"] - ) + "Database settings do not match current configuration. 
DB Settings are: fmin:" + + f"{settings['BANDPASS_FMIN']}, fmax: {settings['BANDPASS_FMAX']}, audio_speed: {settings['AUDIO_SPEED']}" ) except KeyError: - settings = ConfigDict( - {"BANDPASS_FMIN": cfg.BANDPASS_FMIN, "BANDPASS_FMAX": cfg.BANDPASS_FMAX, "AUDIO_SPEED": cfg.AUDIO_SPEED} - ) + settings = ConfigDict({"BANDPASS_FMIN": cfg.BANDPASS_FMIN, "BANDPASS_FMAX": cfg.BANDPASS_FMAX, "AUDIO_SPEED": cfg.AUDIO_SPEED}) db.insert_metadata("birdnet_analyzer_settings", settings) db.commit() -def run(input, database, overlap, audio_speed, fmin, fmax, threads, batchsize): +def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize): ### Make sure to comment out appropriately if you are not using args. ### # Set input and output path - cfg.INPUT_PATH = input + cfg.INPUT_PATH = audio_input # Parse input files if os.path.isdir(cfg.INPUT_PATH): diff --git a/birdnet_analyzer/evaluation/__init__.py b/birdnet_analyzer/evaluation/__init__.py index 8a34deeb..cfe948cb 100644 --- a/birdnet_analyzer/evaluation/__init__.py +++ b/birdnet_analyzer/evaluation/__init__.py @@ -9,24 +9,25 @@ import argparse import json import os -from typing import Optional, Dict, List, Tuple +from birdnet_analyzer.evaluation.assessment.performance_assessor import ( + PerformanceAssessor, +) from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor -from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor def process_data( annotation_path: str, prediction_path: str, - mapping_path: Optional[str] = None, + mapping_path: str | None = None, sample_duration: float = 3.0, min_overlap: float = 0.5, - recording_duration: Optional[float] = None, - columns_annotations: Optional[Dict[str, str]] = None, - columns_predictions: Optional[Dict[str, str]] = None, - selected_classes: Optional[List[str]] = None, - selected_recordings: Optional[List[str]] = None, - metrics_list: Tuple[str, ...] 
= ("accuracy", "precision", "recall"), + recording_duration: float | None = None, + columns_annotations: dict[str, str] | None = None, + columns_predictions: dict[str, str] | None = None, + selected_classes: list[str] | None = None, + selected_recordings: list[str] | None = None, + metrics_list: tuple[str, ...] = ("accuracy", "precision", "recall"), threshold: float = 0.1, class_wise: bool = False, ): @@ -53,7 +54,7 @@ def process_data( """ # Load class mapping if provided if mapping_path: - with open(mapping_path, "r") as f: + with open(mapping_path) as f: class_mapping = json.load(f) else: class_mapping = None diff --git a/birdnet_analyzer/evaluation/assessment/metrics.py b/birdnet_analyzer/evaluation/assessment/metrics.py index c770f067..7743f05f 100644 --- a/birdnet_analyzer/evaluation/assessment/metrics.py +++ b/birdnet_analyzer/evaluation/assessment/metrics.py @@ -14,7 +14,7 @@ - calculate_auroc: Computes the Area Under the Receiver Operating Characteristic curve (AUROC). """ -from typing import Literal, Optional +from typing import Literal import numpy as np from sklearn.metrics import ( @@ -33,7 +33,7 @@ def calculate_accuracy( task: Literal["binary", "multilabel"], num_classes: int, threshold: float, - averaging_method: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + averaging_method: Literal["micro", "macro", "weighted", "none"] | None = "macro", ) -> np.ndarray: """ Calculate accuracy for the given predictions and labels. @@ -115,7 +115,7 @@ def calculate_recall( labels: np.ndarray, task: Literal["binary", "multilabel"], threshold: float, - averaging_method: Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]] = None, + averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None, ) -> np.ndarray: """ Calculate recall for the given predictions and labels. 
@@ -172,7 +172,7 @@ def calculate_precision( labels: np.ndarray, task: Literal["binary", "multilabel"], threshold: float, - averaging_method: Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]] = None, + averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None, ) -> np.ndarray: """ Calculate precision for the given predictions and labels. @@ -229,7 +229,7 @@ def calculate_f1_score( labels: np.ndarray, task: Literal["binary", "multilabel"], threshold: float, - averaging_method: Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]] = None, + averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None, ) -> np.ndarray: """ Calculate the F1 score for the given predictions and labels. @@ -285,7 +285,7 @@ def calculate_average_precision( predictions: np.ndarray, labels: np.ndarray, task: Literal["binary", "multilabel"], - averaging_method: Optional[Literal["micro", "macro", "weighted", "samples", "none"]] = None, + averaging_method: Literal["micro", "macro", "weighted", "samples", "none"] | None = None, ) -> np.ndarray: """ Calculate the average precision (AP) for the given predictions and labels. 
@@ -313,12 +313,7 @@ def calculate_average_precision( averaging = None if averaging_method == "none" else averaging_method # Compute average precision based on task type - if task == "binary": - y_true = labels.astype(int) - y_scores = predictions - ap = average_precision_score(y_true, y_scores, average=averaging) - - elif task == "multilabel": + if task in ("binary", "multilabel"): y_true = labels.astype(int) y_scores = predictions ap = average_precision_score(y_true, y_scores, average=averaging) @@ -337,7 +332,7 @@ def calculate_auroc( predictions: np.ndarray, labels: np.ndarray, task: Literal["binary", "multilabel"], - averaging_method: Optional[Literal["macro", "weighted", "samples", "none"]] = "macro", + averaging_method: Literal["macro", "weighted", "samples", "none"] | None = "macro", ) -> np.ndarray: """ Calculate the Area Under the Receiver Operating Characteristic curve (AUROC). @@ -382,9 +377,7 @@ def calculate_auroc( except ValueError as e: # Handle edge cases where AUROC cannot be computed - if "Only one class present in y_true" in str(e): - auroc = np.nan - elif "Number of classes in y_true" in str(e): + if "Only one class present in y_true" in str(e) or "Number of classes in y_true" in str(e): auroc = np.nan else: raise diff --git a/birdnet_analyzer/evaluation/assessment/performance_assessor.py b/birdnet_analyzer/evaluation/assessment/performance_assessor.py index c7b7a60c..20708d4c 100644 --- a/birdnet_analyzer/evaluation/assessment/performance_assessor.py +++ b/birdnet_analyzer/evaluation/assessment/performance_assessor.py @@ -6,15 +6,14 @@ as well as utilities for generating related plots. 
""" -from typing import Literal, Optional, Tuple +from typing import Literal import matplotlib.pyplot as plt import numpy as np import pandas as pd -from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix -from birdnet_analyzer.evaluation.assessment import metrics -from birdnet_analyzer.evaluation.assessment import plotting +from birdnet_analyzer.evaluation.assessment import metrics, plotting class PerformanceAssessor: @@ -27,9 +26,9 @@ def __init__( self, num_classes: int, threshold: float = 0.5, - classes: Optional[Tuple[str, ...]] = None, + classes: tuple[str, ...] | None = None, task: Literal["binary", "multilabel"] = "multilabel", - metrics_list: Tuple[str, ...] = ( + metrics_list: tuple[str, ...] = ( "recall", "precision", "f1", @@ -73,7 +72,7 @@ def __init__( raise ValueError("task must be 'binary' or 'multilabel'.") # Validate the metrics list - valid_metrics = {"accuracy", "recall", "precision", "f1", "ap", "auroc"} + valid_metrics = ["accuracy", "recall", "precision", "f1", "ap", "auroc"] if not metrics_list: raise ValueError("metrics_list cannot be empty.") if not all(metric in valid_metrics for metric in metrics_list): @@ -123,7 +122,8 @@ def calculate_metrics( raise ValueError("predictions and labels must be 2-dimensional arrays.") if predictions.shape[1] != self.num_classes: raise ValueError( - f"The number of columns in predictions ({predictions.shape[1]}) must match num_classes ({self.num_classes})." + f"The number of columns in predictions ({predictions.shape[1]}) " + + f"must match num_classes ({self.num_classes})." 
) # Determine the averaging method for metrics @@ -192,10 +192,11 @@ def calculate_metrics( metrics_results["Accuracy"] = np.atleast_1d(result) # Define column names for the DataFrame - if per_class_metrics: - columns = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)] - else: - columns = ["Overall"] + columns = ( + (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]) + if per_class_metrics + else ["Overall"] + ) # Create a DataFrame to organize metric results metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()} @@ -225,14 +226,11 @@ def plot_metrics( metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics) # Choose the plotting method based on whether per-class metrics are required - if per_class_metrics: - # Plot metrics per class - fig = plotting.plot_metrics_per_class(metrics_df, self.colors) - else: - # Plot overall metrics - fig = plotting.plot_overall_metrics(metrics_df, self.colors) - - return fig + return ( + plotting.plot_metrics_per_class(metrics_df, self.colors) + if per_class_metrics + else plotting.plot_overall_metrics(metrics_df, self.colors) + ) def plot_metrics_all_thresholds( self, @@ -349,7 +347,8 @@ def plot_confusion_matrix( raise ValueError("predictions and labels must be 2-dimensional arrays.") if predictions.shape[1] != self.num_classes: raise ValueError( - f"The number of columns in predictions ({predictions.shape[1]}) must match num_classes ({self.num_classes})." + f"The number of columns in predictions ({predictions.shape[1]}) " + + f"must match num_classes ({self.num_classes})." 
) if self.task == "binary": @@ -369,7 +368,7 @@ def plot_confusion_matrix( return fig - elif self.task == "multilabel": + if self.task == "multilabel": # Binarize predictions for multilabel classification y_pred = (predictions >= self.threshold).astype(int) y_true = labels.astype(int) @@ -392,7 +391,7 @@ def plot_confusion_matrix( axes = axes.flatten() # Plot each confusion matrix - for idx, (conf_mat, class_name) in enumerate(zip(conf_mats, class_names)): + for idx, (conf_mat, class_name) in enumerate(zip(conf_mats, class_names, strict=True)): disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"]) disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f") axes[idx].set_title(f"{class_name}") @@ -407,5 +406,4 @@ def plot_confusion_matrix( return fig - else: - raise ValueError(f"Unsupported task type: {self.task}") + raise ValueError(f"Unsupported task type: {self.task}") diff --git a/birdnet_analyzer/evaluation/assessment/plotting.py b/birdnet_analyzer/evaluation/assessment/plotting.py index 27f82036..ab5ecd6e 100644 --- a/birdnet_analyzer/evaluation/assessment/plotting.py +++ b/birdnet_analyzer/evaluation/assessment/plotting.py @@ -13,7 +13,7 @@ - plot_confusion_matrices: Visualizes confusion matrices for binary, multiclass, or multilabel tasks. """ -from typing import Dict, List, Literal +from typing import Literal import matplotlib.pyplot as plt import numpy as np @@ -21,7 +21,7 @@ import seaborn as sns -def plot_overall_metrics(metrics_df: pd.DataFrame, colors: List[str]) -> plt.Figure: +def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure: """ Plots a bar chart for overall performance metrics. 
@@ -52,7 +52,7 @@ def plot_overall_metrics(metrics_df: pd.DataFrame, colors: List[str]) -> plt.Fig # Extract metric names and values metrics = metrics_df.index # Metric names - values = metrics_df["Overall"].values # Metric values + values = metrics_df["Overall"].to_numpy() # Metric values # Plot bar chart fig = plt.figure(figsize=(10, 6)) @@ -69,7 +69,7 @@ def plot_overall_metrics(metrics_df: pd.DataFrame, colors: List[str]) -> plt.Fig return fig -def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: List[str]) -> plt.Figure: +def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure: """ Plots metric values per class, with each metric represented by a distinct color and line. @@ -127,9 +127,9 @@ def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: List[str]) -> plt.F def plot_metrics_across_thresholds( thresholds: np.ndarray, - metric_values_dict: Dict[str, np.ndarray], - metrics_to_plot: List[str], - colors: List[str], + metric_values_dict: dict[str, np.ndarray], + metrics_to_plot: list[str], + colors: list[str], ) -> plt.Figure: """ Plots metrics across different thresholds. @@ -195,10 +195,10 @@ def plot_metrics_across_thresholds( def plot_metrics_across_thresholds_per_class( thresholds: np.ndarray, - metric_values_dict_per_class: Dict[str, Dict[str, np.ndarray]], - metrics_to_plot: List[str], - class_names: List[str], - colors: List[str], + metric_values_dict_per_class: dict[str, dict[str, np.ndarray]], + metrics_to_plot: list[str], + class_names: list[str], + colors: list[str], ) -> plt.Figure: """ Plots metrics across different thresholds per class. 
@@ -247,10 +247,7 @@ def plot_metrics_across_thresholds_per_class( fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4)) # Flatten axes for easy indexing - if num_classes == 1: - axes = [axes] - else: - axes = axes.flatten() + axes = [axes] if num_classes == 1 else axes.flatten() # Line styles for distinction line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))] @@ -269,7 +266,8 @@ def plot_metrics_across_thresholds_per_class( metric_values = metric_values_dict[metric_name] if len(metric_values) != len(thresholds): raise ValueError( - f"Length of metric '{metric_name}' values for class '{class_name}' does not match length of thresholds." + f"Length of metric '{metric_name}' values for class '{class_name}' " + + "does not match length of thresholds." ) ax.plot( thresholds, @@ -300,7 +298,7 @@ def plot_metrics_across_thresholds_per_class( def plot_confusion_matrices( conf_mat: np.ndarray, task: Literal["binary", "multiclass", "multilabel"], - class_names: List[str], + class_names: list[str], ) -> plt.Figure: """ Plots confusion matrices for each class in a single figure with multiple subplots. 
@@ -360,10 +358,7 @@ def plot_confusion_matrices( fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2)) # Flatten axes for easy indexing - if num_labels == 1: - axes = [axes] - else: - axes = axes.flatten() + axes = [axes] if num_labels == 1 else axes.flatten() # Plot each class's confusion matrix for i in range(num_labels): diff --git a/birdnet_analyzer/evaluation/preprocessing/data_processor.py b/birdnet_analyzer/evaluation/preprocessing/data_processor.py index ae66c682..5b549b7e 100644 --- a/birdnet_analyzer/evaluation/preprocessing/data_processor.py +++ b/birdnet_analyzer/evaluation/preprocessing/data_processor.py @@ -7,7 +7,7 @@ import os import warnings -from typing import Dict, List, Optional, Tuple +from typing import ClassVar import numpy as np import pandas as pd @@ -28,7 +28,7 @@ class DataProcessor: """ # Default column mappings for predictions and annotations - DEFAULT_COLUMNS_PREDICTIONS = { + DEFAULT_COLUMNS_PREDICTIONS: ClassVar[dict[str, str]] = { "Start Time": "Start Time", "End Time": "End Time", "Class": "Class", @@ -37,7 +37,7 @@ class DataProcessor: "Confidence": "Confidence", } - DEFAULT_COLUMNS_ANNOTATIONS = { + DEFAULT_COLUMNS_ANNOTATIONS: ClassVar[dict[str, str]] = { "Start Time": "Start Time", "End Time": "End Time", "Class": "Class", @@ -49,14 +49,14 @@ def __init__( self, prediction_directory_path: str, annotation_directory_path: str, - prediction_file_name: Optional[str] = None, - annotation_file_name: Optional[str] = None, - class_mapping: Optional[Dict[str, str]] = None, + prediction_file_name: str | None = None, + annotation_file_name: str | None = None, + class_mapping: dict[str, str] | None = None, sample_duration: int = 3, min_overlap: float = 0.5, - columns_predictions: Optional[Dict[str, str]] = None, - columns_annotations: Optional[Dict[str, str]] = None, - recording_duration: Optional[float] = None, + columns_predictions: dict[str, str] | None = None, + columns_annotations: dict[str, str] | None = None, + 
recording_duration: float | None = None, ) -> None: """ Initializes the DataProcessor by loading prediction and annotation data. @@ -66,12 +66,15 @@ def __init__( annotation_directory_path (str): Path to the folder containing annotation files. prediction_file_name (Optional[str]): Name of the prediction file to process. annotation_file_name (Optional[str]): Name of the annotation file to process. - class_mapping (Optional[Dict[str, str]]): Optional dictionary mapping raw class names to standardized class names. + class_mapping (Optional[Dict[str, str]]): Optional dictionary mapping raw class + names to standardized class names. sample_duration (int, optional): Length of each data sample in seconds. Defaults to 3. - min_overlap (float, optional): Minimum overlap required between prediction and annotation to consider a match. + min_overlap (float, optional): Minimum overlap required between prediction and + annotation to consider a match. columns_predictions (Optional[Dict[str, str]], optional): Column name mappings for prediction files. columns_annotations (Optional[Dict[str, str]], optional): Column name mappings for annotation files. - recording_duration (Optional[float], optional): User-specified recording duration in seconds. Defaults to None. + recording_duration (Optional[float], optional): User-specified recording duration in seconds. + Defaults to None. Raises: ValueError: If any parameter is invalid (e.g., negative sample duration). 
@@ -79,30 +82,26 @@ def __init__( # Initialize instance variables self.sample_duration: int = sample_duration self.min_overlap: float = min_overlap - self.class_mapping: Optional[Dict[str, str]] = class_mapping + self.class_mapping: dict[str, str] | None = class_mapping # Use provided column mappings or defaults - self.columns_predictions: Dict[str, str] = ( - columns_predictions if columns_predictions is not None else self.DEFAULT_COLUMNS_PREDICTIONS.copy() - ) - self.columns_annotations: Dict[str, str] = ( - columns_annotations if columns_annotations is not None else self.DEFAULT_COLUMNS_ANNOTATIONS.copy() - ) + self.columns_predictions: dict[str, str] = columns_predictions if columns_predictions is not None else self.DEFAULT_COLUMNS_PREDICTIONS.copy() + self.columns_annotations: dict[str, str] = columns_annotations if columns_annotations is not None else self.DEFAULT_COLUMNS_ANNOTATIONS.copy() - self.recording_duration: Optional[float] = recording_duration + self.recording_duration: float | None = recording_duration # Paths and filenames self.prediction_directory_path: str = prediction_directory_path - self.prediction_file_name: Optional[str] = prediction_file_name + self.prediction_file_name: str | None = prediction_file_name self.annotation_directory_path: str = annotation_directory_path - self.annotation_file_name: Optional[str] = annotation_file_name + self.annotation_file_name: str | None = annotation_file_name # DataFrames for predictions and annotations self.predictions_df: pd.DataFrame = pd.DataFrame() self.annotations_df: pd.DataFrame = pd.DataFrame() # Placeholder for unique classes across predictions and annotations - self.classes: Tuple[str, ...] = () + self.classes: tuple[str, ...] 
= () # Placeholder for samples DataFrame and tensors self.samples_df: pd.DataFrame = pd.DataFrame() @@ -153,18 +152,10 @@ def _validate_columns(self) -> None: required_columns = ["Start Time", "End Time", "Class"] # Check for missing or None columns in predictions - missing_pred_columns = [ - col - for col in required_columns - if col not in self.columns_predictions or self.columns_predictions[col] is None - ] + missing_pred_columns = [col for col in required_columns if col not in self.columns_predictions or self.columns_predictions[col] is None] # Check for missing or None columns in annotations - missing_annot_columns = [ - col - for col in required_columns - if col not in self.columns_annotations or self.columns_annotations[col] is None - ] + missing_annot_columns = [col for col in required_columns if col not in self.columns_annotations or self.columns_annotations[col] is None] if missing_pred_columns: raise ValueError(f"Missing or None prediction columns: {', '.join(missing_pred_columns)}") @@ -201,15 +192,14 @@ def load_data(self) -> None: # Apply class mapping to predictions if provided if self.class_mapping: class_col_pred = self.get_column_name("Class", prediction=True) - self.predictions_df[class_col_pred] = self.predictions_df[class_col_pred].apply( - lambda x: self.class_mapping.get(x, x) - ) + self.predictions_df[class_col_pred] = self.predictions_df[class_col_pred].apply(lambda x: self.class_mapping.get(x, x)) else: # Case: Specific files are provided for predictions and annotations. # Ensure filenames correspond to the same recording (heuristic check). if not self.prediction_file_name.startswith(os.path.splitext(self.annotation_file_name)[0]): warnings.warn( - "Prediction file name and annotation file name do not fully match, but proceeding anyway." 
+ "Prediction file name and annotation file name do not fully match, but proceeding anyway.", + stacklevel=2, ) # Construct full file paths @@ -231,9 +221,7 @@ def load_data(self) -> None: # Apply class mapping to predictions if provided if self.class_mapping: class_col_pred = self.get_column_name("Class", prediction=True) - self.predictions_df[class_col_pred] = self.predictions_df[class_col_pred].apply( - lambda x: self.class_mapping.get(x, x) - ) + self.predictions_df[class_col_pred] = self.predictions_df[class_col_pred].apply(lambda x: self.class_mapping.get(x, x)) # Consolidate all unique classes from predictions and annotations class_col_pred = self.get_column_name("Class", prediction=True) @@ -266,13 +254,12 @@ def _prepare_dataframe(self, df: pd.DataFrame, prediction: bool) -> pd.DataFrame if recording_col in df.columns: # Extract recording filename using the 'Recording' column df["recording_filename"] = extract_recording_filename(df[recording_col]) + elif "source_file" in df.columns: + # Fall back to extracting from the 'source_file' column + df["recording_filename"] = extract_recording_filename_from_filename(df["source_file"]) else: - if "source_file" in df.columns: - # Fall back to extracting from the 'source_file' column - df["recording_filename"] = extract_recording_filename_from_filename(df["source_file"]) - else: - # Assign a default empty string if no relevant columns exist - df["recording_filename"] = "" + # Assign a default empty string if no relevant columns exist + df["recording_filename"] = "" return df @@ -287,9 +274,7 @@ def process_data(self) -> None: self.samples_df = pd.DataFrame() # Initialize the samples DataFrame # Get the unique set of recording filenames from both predictions and annotations - recording_filenames = set(self.predictions_df["recording_filename"].unique()).union( - set(self.annotations_df["recording_filename"].unique()) - ) + recording_filenames = 
set(self.predictions_df["recording_filename"].unique()).union(set(self.annotations_df["recording_filename"].unique())) # Process each recording for recording_filename in recording_filenames: @@ -362,11 +347,11 @@ def determine_file_duration(self, pred_df: pd.DataFrame, annot_df: pd.DataFrame) file_duration_col_annot = self.get_column_name("Duration", prediction=False) # Try to get duration from 'Duration' column in pred_df - if file_duration_col_pred in pred_df.columns and pred_df[file_duration_col_pred].notnull().any(): + if file_duration_col_pred in pred_df.columns and pred_df[file_duration_col_pred].notna().any(): duration = max(duration, pred_df[file_duration_col_pred].dropna().max()) # Try to get duration from 'Duration' column in annot_df - if file_duration_col_annot in annot_df.columns and annot_df[file_duration_col_annot].notnull().any(): + if file_duration_col_annot in annot_df.columns and annot_df[file_duration_col_annot].notna().any(): duration = max(duration, annot_df[file_duration_col_annot].dropna().max()) # If no duration is found, use the maximum 'End Time' value @@ -459,10 +444,7 @@ def update_samples_with_predictions(self, pred_df: pd.DataFrame, samples_df: pd. 
confidence = row.get(confidence_col, 0.0) # Identify samples that overlap with the prediction based on min_overlap - sample_indices = samples_df[ - (samples_df["start_time"] <= end_time - self.min_overlap) - & (samples_df["end_time"] >= begin_time + self.min_overlap) - ].index + sample_indices = samples_df[(samples_df["start_time"] <= end_time - self.min_overlap) & (samples_df["end_time"] >= begin_time + self.min_overlap)].index # Update the confidence scores for the overlapping samples for i in sample_indices: @@ -497,10 +479,7 @@ def update_samples_with_annotations(self, annot_df: pd.DataFrame, samples_df: pd end_time = row[end_time_col] # Identify samples that overlap with the annotation based on min_overlap - sample_indices = samples_df[ - (samples_df["start_time"] <= end_time - self.min_overlap) - & (samples_df["end_time"] >= begin_time + self.min_overlap) - ].index + sample_indices = samples_df[(samples_df["start_time"] <= end_time - self.min_overlap) & (samples_df["end_time"] >= begin_time + self.min_overlap)].index # Set annotation value to 1 for the overlapping samples for i in sample_indices: @@ -525,12 +504,12 @@ def create_tensors(self) -> None: # Check for NaN values in annotation columns annotation_columns = [f"{cls}_annotation" for cls in self.classes] - if self.samples_df[annotation_columns].isnull().values.any(): + if self.samples_df[annotation_columns].isna().to_numpy().any(): raise ValueError("NaN values found in annotation columns.") # Check for NaN values in confidence columns confidence_columns = [f"{cls}_confidence" for cls in self.classes] - if self.samples_df[confidence_columns].isnull().values.any(): + if self.samples_df[confidence_columns].isna().to_numpy().any(): raise ValueError("NaN values found in confidence columns.") # Convert confidence scores and annotations into numpy arrays (tensors) @@ -584,9 +563,9 @@ def get_sample_data(self) -> pd.DataFrame: def get_filtered_tensors( self, - selected_classes: Optional[List[str]] = None, - 
selected_recordings: Optional[List[str]] = None, - ) -> Tuple[np.ndarray, np.ndarray, Tuple[str]]: + selected_classes: list[str] | None = None, + selected_recordings: list[str] | None = None, + ) -> tuple[np.ndarray, np.ndarray, tuple[str]]: """ Filters the prediction and label tensors based on selected classes and recordings. @@ -617,9 +596,7 @@ def get_filtered_tensors( raise ValueError("samples_df must contain a 'filename' column.") # Determine the classes to filter by - classes = ( - self.classes if selected_classes is None else tuple(cls for cls in selected_classes if cls in self.classes) - ) + classes = self.classes if selected_classes is None else tuple(cls for cls in selected_classes if cls in self.classes) if not classes: raise ValueError("No valid classes selected.") diff --git a/birdnet_analyzer/evaluation/preprocessing/utils.py b/birdnet_analyzer/evaluation/preprocessing/utils.py index c0feb29f..d1a90cb5 100644 --- a/birdnet_analyzer/evaluation/preprocessing/utils.py +++ b/birdnet_analyzer/evaluation/preprocessing/utils.py @@ -9,7 +9,7 @@ """ import os -from typing import List + import pandas as pd @@ -65,11 +65,11 @@ def read_and_concatenate_files_in_directory(directory_path: str) -> pd.DataFrame Raises: ValueError: If the columns in the files are inconsistent. 
""" - df_list: List[pd.DataFrame] = [] # List to hold individual DataFrames + df_list: list[pd.DataFrame] = [] # List to hold individual DataFrames columns_set = None # To ensure consistency in column names # Iterate through each file in the directory - for filename in os.listdir(directory_path): + for filename in sorted(os.listdir(directory_path)): if filename.endswith(".txt"): filepath = os.path.join(directory_path, filename) # Construct the full file path diff --git a/birdnet_analyzer/gui/__init__.py b/birdnet_analyzer/gui/__init__.py index cbb40720..0218c3a1 100644 --- a/birdnet_analyzer/gui/__init__.py +++ b/birdnet_analyzer/gui/__init__.py @@ -1,13 +1,9 @@ def main(): import birdnet_analyzer.gui.multi_file as mfa - import birdnet_analyzer.gui.review as review import birdnet_analyzer.gui.segments as gs import birdnet_analyzer.gui.single_file as sfa - import birdnet_analyzer.gui.species as species - import birdnet_analyzer.gui.train as train import birdnet_analyzer.gui.utils as gu - import birdnet_analyzer.gui.embeddings as embeddings - import birdnet_analyzer.gui.evaluation as evaluation + from birdnet_analyzer.gui import embeddings, evaluation, review, species, train gu.open_window( [ diff --git a/birdnet_analyzer/gui/analysis.py b/birdnet_analyzer/gui/analysis.py index 2d8c383d..06f02efd 100644 --- a/birdnet_analyzer/gui/analysis.py +++ b/birdnet_analyzer/gui/analysis.py @@ -5,12 +5,14 @@ import gradio as gr import birdnet_analyzer.config as cfg -import birdnet_analyzer.gui.utils as gu import birdnet_analyzer.gui.localization as loc -import birdnet_analyzer.model as model - - -from birdnet_analyzer.analyze.utils import analyze_file, combine_results, save_analysis_params +import birdnet_analyzer.gui.utils as gu +from birdnet_analyzer import model +from birdnet_analyzer.analyze.utils import ( + analyze_file, + combine_results, + save_analysis_params, +) SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) ORIGINAL_LABELS_FILE = str(Path(SCRIPT_DIR).parent / 
cfg.LABELS_FILE) @@ -103,7 +105,7 @@ def run_analysis( week = -1 if use_yearlong else week flist = _set_params( - input=input_dir if input_dir else input_path, + audio_input=input_dir if input_dir else input_path, min_conf=confidence, custom_classifier=custom_classifier, sensitivity=min(1.25, max(0.75, float(sensitivity))), @@ -143,8 +145,7 @@ def run_analysis( # Analyze files if cfg.CPU_THREADS < 2: - for entry in flist: - result_list.append(analyze_file_wrapper(entry)) + result_list.extend(analyze_file_wrapper(entry) for entry in flist) else: with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor: futures = (executor.submit(analyze_file_wrapper, arg) for arg in flist) diff --git a/birdnet_analyzer/gui/embeddings.py b/birdnet_analyzer/gui/embeddings.py index c6506bbf..9f381923 100644 --- a/birdnet_analyzer/gui/embeddings.py +++ b/birdnet_analyzer/gui/embeddings.py @@ -14,7 +14,7 @@ def play_audio(audio_infos): - import birdnet_analyzer.audio as audio + from birdnet_analyzer import audio arr, sr = audio.open_audio_file( audio_infos[0], @@ -98,13 +98,13 @@ def run_embeddings( threads, batch_size, ) - except: + except Exception as e: db.db.close() # Transform audiospeed from slider to float audio_speed = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed)) if fmin is None or fmax is None or fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax: - raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]") + raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]") from e run(input_path, db_path, overlap, audio_speed, fmin, fmax, threads, batch_size) @@ -142,14 +142,14 @@ def run_search(db_path, query_path, max_samples, score_fn, crop_mode, crop_overl return chunks, 0, gr.Button(interactive=True), {} -def run_export(export_state): - import birdnet_analyzer.audio as audio +def run_export(export_state: dict): + from 
birdnet_analyzer import audio if len(export_state.items()) > 0: export_folder = gu.select_folder(state_key="embeddings-search-export-folder") if export_folder: - for index, file in export_state.items(): + for file in export_state.values(): filebasename = os.path.basename(file[0]) filebasename = os.path.splitext(filebasename)[0] dest = os.path.join(export_folder, f"{file[4]:.5f}_{filebasename}_{file[1]}_{file[1] + file[2]}.wav") @@ -335,8 +335,7 @@ def check_settings(dir_name, db_name): def _build_search_tab(): - import birdnet_analyzer.audio as audio - import birdnet_analyzer.utils as utils + from birdnet_analyzer import audio, utils with gr.Tab(loc.localize("embeddings-search-tab-title")): results_state = gr.State([]) @@ -470,7 +469,7 @@ def render_results(results, page, db_path, exports): with gr.Row(): play_btn = gr.Button("â–¶") play_btn.click(play_audio, inputs=plot_audio_state, outputs=hidden_audio) - checkbox = gr.Checkbox(label="Export", value=(index in exports.keys())) + checkbox = gr.Checkbox(label="Export", value=(index in exports)) checkbox.change( update_export_state, inputs=[plot_audio_state, checkbox, export_state], @@ -563,8 +562,8 @@ def update_query_spectrogram(audiofilepath, db_selection, crop_mode, crop_overla spec = utils.spectrogram_from_audio(sig, rate, fig_size=(10, 4)) return spec, [], {} - else: - return None, [], {} + + return None, [], {} crop_mode.change( update_query_spectrogram, diff --git a/birdnet_analyzer/gui/evaluation.py b/birdnet_analyzer/gui/evaluation.py index 68c3cff5..b165cff5 100644 --- a/birdnet_analyzer/gui/evaluation.py +++ b/birdnet_analyzer/gui/evaluation.py @@ -10,8 +10,10 @@ import birdnet_analyzer.gui.localization as loc import birdnet_analyzer.gui.utils as gu -from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor from birdnet_analyzer.evaluation import process_data +from birdnet_analyzer.evaluation.assessment.performance_assessor import ( + PerformanceAssessor, +) from 
birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor @@ -134,7 +136,7 @@ def get_columns_from_uploaded_files(files): print(f"Error reading file {file_obj}: {e}") gr.Warning(f"{loc.localize('eval-tab-warning-error-reading-file')} {file_obj}") - return sorted(list(columns)) + return sorted(columns) def save_uploaded_files(files): if not files: @@ -217,7 +219,7 @@ def initialize_processor( mapping_path = mapping_file_obj if mapping_file_obj else None if mapping_path: - with open(mapping_path, "r") as f: + with open(mapping_path) as f: class_mapping = json.load(f) else: class_mapping = None @@ -241,11 +243,7 @@ def initialize_processor( return avail_classes, avail_recordings, proc, annotation_dir, prediction_dir except KeyError as e: print(f"Column missing in files: {e}") - raise gr.Error( - f"{loc.localize('eval-tab-error-missing-col')}: " - + str(e) - + f". {loc.localize('eval-tab-error-missing-col-info')}" - ) from e + raise gr.Error(f"{loc.localize('eval-tab-error-missing-col')}: " + str(e) + f". {loc.localize('eval-tab-error-missing-col-info')}") from e except Exception as e: print(f"Error initializing processor: {e}") @@ -311,16 +309,8 @@ def update_selections( state = ProcessorState(proc, annotation_dir, prediction_dir) # If no current selection exists, default to all available classes/recordings; # otherwise, preserve any selections that are still valid. 
- new_classes = ( - avail_classes - if not current_classes - else [c for c in current_classes if c in avail_classes] or avail_classes - ) - new_recordings = ( - avail_recordings - if not current_recordings - else [r for r in current_recordings if r in avail_recordings] or avail_recordings - ) + new_classes = avail_classes if not current_classes else [c for c in current_classes if c in avail_classes] or avail_classes + new_recordings = avail_recordings if not current_recordings else [r for r in current_recordings if r in avail_recordings] or avail_recordings return ( gr.update(choices=avail_classes, value=new_classes), @@ -360,7 +350,7 @@ def get_selection_tables(directory): # Update column dropdowns when files are uploaded. def update_annotation_columns(uploaded_files): cols = get_columns_from_uploaded_files(uploaded_files) - cols = [""] + cols + cols = ["", *cols] updates = [] for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]: @@ -372,7 +362,7 @@ def update_annotation_columns(uploaded_files): def update_prediction_columns(uploaded_files): cols = get_columns_from_uploaded_files(uploaded_files) - cols = [""] + cols + cols = ["", *cols] updates = [] for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]: @@ -389,7 +379,7 @@ def select_directory_on_empty(): # Nishant - Function modified for For Folder s if folder: files = get_selection_tables(folder) files_to_display = files[:100] + [["..."]] if len(files) > 100 else files - return [files, files_to_display, gr.update(visible=True)] + on_select(files) + return [files, files_to_display, gr.update(visible=True), *on_select(files)] return ["", [[loc.localize("eval-tab-no-files-found")]]] @@ -415,115 +405,112 @@ def select_directory_on_empty(): # Nishant - Function modified for For Folder s ) # ----------------------- Annotations Columns Box ----------------------- - with gr.Group(visible=False) as annotation_group: - with 
gr.Accordion(loc.localize("eval-tab-annotation-col-accordion-label"), open=True): - with gr.Row(): - annotation_columns: dict[str, gr.Dropdown] = {} + with ( + gr.Group(visible=False) as annotation_group, + gr.Accordion(loc.localize("eval-tab-annotation-col-accordion-label"), open=True), + gr.Row(), + ): + annotation_columns: dict[str, gr.Dropdown] = {} - for col in ["Start Time", "End Time", "Class", "Recording", "Duration"]: - annotation_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col]) + for col in ["Start Time", "End Time", "Class", "Recording", "Duration"]: + annotation_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col]) # ----------------------- Predictions Columns Box ----------------------- - with gr.Group(visible=False) as prediction_group: - with gr.Accordion(loc.localize("eval-tab-prediction-col-accordion-label"), open=True): - with gr.Row(): - prediction_columns: dict[str, gr.Dropdown] = {} + with ( + gr.Group(visible=False) as prediction_group, + gr.Accordion(loc.localize("eval-tab-prediction-col-accordion-label"), open=True), + gr.Row(), + ): + prediction_columns: dict[str, gr.Dropdown] = {} - for col in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]: - prediction_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col]) + for col in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]: + prediction_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col]) # ----------------------- Class Mapping Box ----------------------- with gr.Group(visible=False) as mapping_group: - with gr.Accordion(loc.localize("eval-tab-class-mapping-accordion-label"), open=False): - with gr.Row(): - mapping_file = gr.File( - label=loc.localize("eval-tab-upload-mapping-file-label"), - file_count="single", - file_types=[".json"], - ) - download_mapping_button = gr.DownloadButton( - 
label=loc.localize("eval-tab-mapping-file-template-download-button-label") - ) + with gr.Accordion(loc.localize("eval-tab-class-mapping-accordion-label"), open=False), gr.Row(): + mapping_file = gr.File( + label=loc.localize("eval-tab-upload-mapping-file-label"), + file_count="single", + file_types=[".json"], + ) + download_mapping_button = gr.DownloadButton(label=loc.localize("eval-tab-mapping-file-template-download-button-label")) download_mapping_button.click(fn=download_class_mapping_template) # ----------------------- Classes and Recordings Selection Box ----------------------- - with gr.Group(visible=False) as class_recording_group: - with gr.Accordion(loc.localize("eval-tab-select-classes-recordings-accordion-label"), open=False): - with gr.Row(): - with gr.Column(): - select_classes_checkboxgroup = gr.CheckboxGroup( - choices=[], - value=[], - label=loc.localize("eval-tab-select-classes-checkboxgroup-label"), - info=loc.localize("eval-tab-select-classes-checkboxgroup-info"), - interactive=True, - elem_classes="custom-checkbox-group", - ) - - with gr.Column(): - select_recordings_checkboxgroup = gr.CheckboxGroup( - choices=[], - value=[], - label=loc.localize("eval-tab-select-recordings-checkboxgroup-label"), - info=loc.localize("eval-tab-select-recordings-checkboxgroup-info"), - interactive=True, - elem_classes="custom-checkbox-group", - ) + with ( + gr.Group(visible=False) as class_recording_group, + gr.Accordion(loc.localize("eval-tab-select-classes-recordings-accordion-label"), open=False), + gr.Row(), + ): + with gr.Column(): + select_classes_checkboxgroup = gr.CheckboxGroup( + choices=[], + value=[], + label=loc.localize("eval-tab-select-classes-checkboxgroup-label"), + info=loc.localize("eval-tab-select-classes-checkboxgroup-info"), + interactive=True, + elem_classes="custom-checkbox-group", + ) + + with gr.Column(): + select_recordings_checkboxgroup = gr.CheckboxGroup( + choices=[], + value=[], + 
label=loc.localize("eval-tab-select-recordings-checkboxgroup-label"), + info=loc.localize("eval-tab-select-recordings-checkboxgroup-info"), + interactive=True, + elem_classes="custom-checkbox-group", + ) # ----------------------- Parameters Box ----------------------- - with gr.Group(): - with gr.Accordion(loc.localize("eval-tab-parameters-accordion-label"), open=False): - with gr.Row(): - sample_duration = gr.Number( - value=3, - label=loc.localize("eval-tab-sample-duration-number-label"), - precision=0, - info=loc.localize("eval-tab-sample-duration-number-info"), - ) - recording_duration = gr.Textbox( - label=loc.localize("eval-tab-recording-duration-textbox-label"), - placeholder=loc.localize("eval-tab-recording-duration-textbox-placeholder"), - info=loc.localize("eval-tab-recording-duration-textbox-info"), - ) - min_overlap = gr.Number( - value=0.5, - label=loc.localize("eval-tab-min-overlap-number-label"), - info=loc.localize("eval-tab-min-overlap-number-info"), - ) - threshold = gr.Slider( - minimum=0.01, - maximum=0.99, - value=0.1, - label=loc.localize("eval-tab-threshold-number-label"), - info=loc.localize("eval-tab-threshold-number-info"), - ) - class_wise = gr.Checkbox( - label=loc.localize("eval-tab-classwise-checkbox-label"), - value=False, - info=loc.localize("eval-tab-classwise-checkbox-info"), - ) + with gr.Group(), gr.Accordion(loc.localize("eval-tab-parameters-accordion-label"), open=False), gr.Row(): + sample_duration = gr.Number( + value=3, + label=loc.localize("eval-tab-sample-duration-number-label"), + precision=0, + info=loc.localize("eval-tab-sample-duration-number-info"), + ) + recording_duration = gr.Textbox( + label=loc.localize("eval-tab-recording-duration-textbox-label"), + placeholder=loc.localize("eval-tab-recording-duration-textbox-placeholder"), + info=loc.localize("eval-tab-recording-duration-textbox-info"), + ) + min_overlap = gr.Number( + value=0.5, + label=loc.localize("eval-tab-min-overlap-number-label"), + 
info=loc.localize("eval-tab-min-overlap-number-info"), + ) + threshold = gr.Slider( + minimum=0.01, + maximum=0.99, + value=0.1, + label=loc.localize("eval-tab-threshold-number-label"), + info=loc.localize("eval-tab-threshold-number-info"), + ) + class_wise = gr.Checkbox( + label=loc.localize("eval-tab-classwise-checkbox-label"), + value=False, + info=loc.localize("eval-tab-classwise-checkbox-info"), + ) # ----------------------- Metrics Box ----------------------- - with gr.Group(): - with gr.Accordion(loc.localize("eval-tab-metrics-accordian-label"), open=False): - with gr.Row(): - metric_info = { - "AUROC": loc.localize("eval-tab-auroc-checkbox-info"), - "Precision": loc.localize("eval-tab-precision-checkbox-info"), - "Recall": loc.localize("eval-tab-recall-checkbox-info"), - "F1 Score": loc.localize("eval-tab-f1-score-checkbox-info"), - "Average Precision (AP)": loc.localize("eval-tab-ap-checkbox-info"), - "Accuracy": loc.localize("eval-tab-accuracy-checkbox-info"), - } - metrics_checkboxes = {} - - for metric_name, description in metric_info.items(): - metrics_checkboxes[metric_name.lower()] = gr.Checkbox( - label=metric_name, value=True, info=description - ) - - # ----------------------- Actions Box ----------------------- + with gr.Group(), gr.Accordion(loc.localize("eval-tab-metrics-accordian-label"), open=False), gr.Row(): + metric_info = { + "AUROC": loc.localize("eval-tab-auroc-checkbox-info"), + "Precision": loc.localize("eval-tab-precision-checkbox-info"), + "Recall": loc.localize("eval-tab-recall-checkbox-info"), + "F1 Score": loc.localize("eval-tab-f1-score-checkbox-info"), + "Average Precision (AP)": loc.localize("eval-tab-ap-checkbox-info"), + "Accuracy": loc.localize("eval-tab-accuracy-checkbox-info"), + } + metrics_checkboxes = {} + + for metric_name, description in metric_info.items(): + metrics_checkboxes[metric_name.lower()] = gr.Checkbox(label=metric_name, value=True, info=description) + + # ----------------------- Actions Box 
----------------------- calculate_button = gr.Button(loc.localize("eval-tab-calculate-metrics-button-label"), variant="huggingface") @@ -531,9 +518,7 @@ def select_directory_on_empty(): # Nishant - Function modified for For Folder s with gr.Row(): plot_metrics_button = gr.Button(loc.localize("eval-tab-plot-metrics-button-label")) plot_confusion_button = gr.Button(loc.localize("eval-tab-plot-confusion-matrix-button-label")) - plot_metrics_all_thresholds_button = gr.Button( - loc.localize("eval-tab-plot-metrics-all-thresholds-button-label") - ) + plot_metrics_all_thresholds_button = gr.Button(loc.localize("eval-tab-plot-metrics-all-thresholds-button-label")) with gr.Row(): download_results_button = gr.DownloadButton(loc.localize("eval-tab-result-table-download-button-label")) @@ -607,7 +592,7 @@ def calculate_metrics( ): selected_metrics = [] - for value, (m_lower, _) in zip(metrics_checkbox_values, metrics_checkboxes.items()): + for value, (m_lower, _) in zip(metrics_checkbox_values, metrics_checkboxes.items(), strict=True): if value: selected_metrics.append(m_lower) @@ -711,8 +696,8 @@ def calculate_metrics( select_classes_checkboxgroup, select_recordings_checkboxgroup, processor_state, - ] - + [checkbox for checkbox in metrics_checkboxes.values()], + *list(metrics_checkboxes.values()), + ], outputs=[ metric_table, action_col, @@ -769,10 +754,7 @@ def plot_confusion_matrix(pa: PerformanceAssessor, predictions, labels): prediction_select_directory_btn.click( get_selection_func("eval-predictions-dir", update_prediction_columns), outputs=[prediction_files_state, prediction_directory_input, prediction_group] - + [ - prediction_columns[label] - for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"] - ], + + [prediction_columns[label] for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]], show_progress=True, ) diff --git a/birdnet_analyzer/gui/localization.py b/birdnet_analyzer/gui/localization.py index 
fb9a30f8..b27d7a2e 100644 --- a/birdnet_analyzer/gui/localization.py +++ b/birdnet_analyzer/gui/localization.py @@ -1,7 +1,8 @@ +# ruff: noqa: PLW0603 import json import os -import birdnet_analyzer.gui.settings as settings +from birdnet_analyzer.gui import settings SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) LANGUAGE_DIR = os.path.join(os.path.dirname(SCRIPT_DIR), "lang") @@ -20,20 +21,25 @@ def load_local_state(): settings.ensure_settings_file() try: - TARGET_LANGUAGE = json.load(open(settings.GUI_SETTINGS_PATH, encoding="utf-8"))["language-id"] + with open(settings.GUI_SETTINGS_PATH, encoding="utf-8") as f: + settings_data = json.load(f) + + if "language-id" in settings_data: + TARGET_LANGUAGE = settings_data["language-id"] except FileNotFoundError: print(f"gui-settings.json not found. Using fallback language {settings.FALLBACK_LANGUAGE}.") try: - with open(f"{LANGUAGE_DIR}/{TARGET_LANGUAGE}.json", "r", encoding="utf-8") as f: + with open(f"{LANGUAGE_DIR}/{TARGET_LANGUAGE}.json", encoding="utf-8") as f: LANGUAGE_LOOKUP = json.load(f) except FileNotFoundError: print( - f"Language file for {TARGET_LANGUAGE} not found in {LANGUAGE_DIR}. Using fallback language {settings.FALLBACK_LANGUAGE}." + f"Language file for {TARGET_LANGUAGE} not found in {LANGUAGE_DIR}. " + + f"Using fallback language {settings.FALLBACK_LANGUAGE}." ) if TARGET_LANGUAGE != settings.FALLBACK_LANGUAGE: - with open(f"{LANGUAGE_DIR}/{settings.FALLBACK_LANGUAGE}.json", "r") as f: + with open(f"{LANGUAGE_DIR}/{settings.FALLBACK_LANGUAGE}.json") as f: fallback: dict = json.load(f) for key, value in fallback.items(): @@ -49,7 +55,8 @@ def localize(key: str) -> str: key (str): The key to be localized. Returns: - str: The localized string corresponding to the given key. If the key is not found in the localization lookup, the original key is returned. + str: The localized string corresponding to the given key. + If the key is not found in the localization lookup, the original key is returned.
""" return LANGUAGE_LOOKUP.get(key, key) diff --git a/birdnet_analyzer/gui/multi_file.py b/birdnet_analyzer/gui/multi_file.py index 2f311982..1e8bfe8b 100644 --- a/birdnet_analyzer/gui/multi_file.py +++ b/birdnet_analyzer/gui/multi_file.py @@ -1,3 +1,4 @@ + # ruff: noqa: I001 import gradio as gr import birdnet_analyzer.config as cfg @@ -75,7 +76,7 @@ def run_batch_analysis( custom_classifier_file, output_type, combine_tables, - "en" if not locale else locale, + locale if locale else "en", batch_size if batch_size and batch_size > 0 else 1, threads if threads and threads > 0 else 4, input_dir, @@ -158,29 +159,27 @@ def select_directory_wrapper(): # Nishant - Function modified for For Folder se map_plot, ) = gu.species_lists() - with gr.Accordion(loc.localize("multi-tab-output-accordion-label"), open=True): - with gr.Group(): - output_type_radio = gr.CheckboxGroup( - list(OUTPUT_TYPE_MAP.items()), - value="table", - label=loc.localize("multi-tab-output-radio-label"), - info=loc.localize("multi-tab-output-radio-info"), + with gr.Accordion(loc.localize("multi-tab-output-accordion-label"), open=True), gr.Group(): + output_type_radio = gr.CheckboxGroup( + list(OUTPUT_TYPE_MAP.items()), + value="table", + label=loc.localize("multi-tab-output-radio-label"), + info=loc.localize("multi-tab-output-radio-info"), + ) + + with gr.Row(): + combine_tables_checkbox = gr.Checkbox( + False, + label=loc.localize("multi-tab-output-combine-tables-checkbox-label"), + info=loc.localize("multi-tab-output-combine-tables-checkbox-info"), ) - with gr.Row(): - with gr.Column(): - combine_tables_checkbox = gr.Checkbox( - False, - label=loc.localize("multi-tab-output-combine-tables-checkbox-label"), - info=loc.localize("multi-tab-output-combine-tables-checkbox-info"), - ) - - with gr.Row(): - skip_existing_checkbox = gr.Checkbox( - False, - label=loc.localize("multi-tab-skip-existing-checkbox-label"), - info=loc.localize("multi-tab-skip-existing-checkbox-info"), - ) + with gr.Row(): + 
skip_existing_checkbox = gr.Checkbox( + False, + label=loc.localize("multi-tab-skip-existing-checkbox-label"), + info=loc.localize("multi-tab-skip-existing-checkbox-info"), + ) with gr.Row(): batch_size_number = gr.Number( diff --git a/birdnet_analyzer/gui/review.py b/birdnet_analyzer/gui/review.py index 10aaa0ef..29608374 100644 --- a/birdnet_analyzer/gui/review.py +++ b/birdnet_analyzer/gui/review.py @@ -7,9 +7,9 @@ import gradio as gr import birdnet_analyzer.config as cfg -import birdnet_analyzer.gui.utils as gu import birdnet_analyzer.gui.localization as loc -import birdnet_analyzer.utils as utils +import birdnet_analyzer.gui.utils as gu +from birdnet_analyzer import utils POSITIVE_LABEL_DIR = "Positive" NEGATIVE_LABEL_DIR = "Negative" @@ -24,9 +24,9 @@ def collect_segments(directory, shuffle=False): entry.path for entry in os.scandir(directory) if ( - entry.is_file() and - not entry.name.startswith(".") and - entry.name.rsplit(".", 1)[-1] in cfg.ALLOWED_FILETYPES + entry.is_file() + and not entry.name.startswith(".") + and entry.name.rsplit(".", 1)[-1] in cfg.ALLOWED_FILETYPES ) ] if os.path.isdir(directory) @@ -94,7 +94,7 @@ def create_log_plot(positives, negatives, fig_num=None): ] p_colors = ["blue", "purple", "orange", "green"] - for target_p, p_color, threshold in zip(target_ps, p_colors, thresholds): + for target_p, p_color, threshold in zip(target_ps, p_colors, thresholds, strict=True): if threshold <= 1: ax.vlines( threshold, @@ -150,49 +150,42 @@ def create_log_plot(positives, negatives, fig_num=None): elem_id="segments-results-grid", ) - with gr.Column() as review_item_col: - with gr.Row(): - with gr.Column(): - with gr.Group(): - spectrogram_image = gr.Plot( - label=loc.localize("review-tab-spectrogram-plot-label"), show_label=False - ) - # with gr.Row(): - spectrogram_dl_btn = gr.Button("Download spectrogram", size="sm") - - with gr.Column(): - positive_btn = gr.Button( - loc.localize("review-tab-pos-button-label"), - elem_id="positive-button", - 
variant="huggingface", - icon=os.path.join(SCRIPT_DIR, "assets/arrow_up.svg"), + with gr.Column() as review_item_col, gr.Row(): + with gr.Column(), gr.Group(): + spectrogram_image = gr.Plot( + label=loc.localize("review-tab-spectrogram-plot-label"), show_label=False + ) + spectrogram_dl_btn = gr.Button("Download spectrogram", size="sm") + + with gr.Column(): + positive_btn = gr.Button( + loc.localize("review-tab-pos-button-label"), + elem_id="positive-button", + variant="huggingface", + icon=os.path.join(SCRIPT_DIR, "assets/arrow_up.svg"), + ) + negative_btn = gr.Button( + loc.localize("review-tab-neg-button-label"), + elem_id="negative-button", + variant="huggingface", + icon=os.path.join(SCRIPT_DIR, "assets/arrow_down.svg"), + ) + + with gr.Row(): + undo_btn = gr.Button( + loc.localize("review-tab-undo-button-label"), + elem_id="undo-button", + icon=os.path.join(SCRIPT_DIR, "assets/arrow_left.svg"), ) - negative_btn = gr.Button( - loc.localize("review-tab-neg-button-label"), - elem_id="negative-button", - variant="huggingface", - icon=os.path.join(SCRIPT_DIR, "assets/arrow_down.svg"), + skip_btn = gr.Button( + loc.localize("review-tab-skip-button-label"), + elem_id="skip-button", + icon=os.path.join(SCRIPT_DIR, "assets/arrow_right.svg"), ) - with gr.Row(): - undo_btn = gr.Button( - loc.localize("review-tab-undo-button-label"), - elem_id="undo-button", - icon=os.path.join(SCRIPT_DIR, "assets/arrow_left.svg"), - ) - skip_btn = gr.Button( - loc.localize("review-tab-skip-button-label"), - elem_id="skip-button", - icon=os.path.join(SCRIPT_DIR, "assets/arrow_right.svg"), - ) - - with gr.Group(): - review_audio = gr.Audio( - type="filepath", sources=[], show_download_button=False, autoplay=True - ) - autoplay_checkbox = gr.Checkbox( - True, label=loc.localize("review-tab-autoplay-checkbox-label") - ) + with gr.Group(): + review_audio = gr.Audio(type="filepath", sources=[], show_download_button=False, autoplay=True) + autoplay_checkbox = gr.Checkbox(True, 
label=loc.localize("review-tab-autoplay-checkbox-label")) no_samles_label = gr.Label(loc.localize("review-tab-no-files-label"), visible=False, show_label=False) with gr.Group(): @@ -239,12 +232,12 @@ def update_values(next_review_state, skip_plot=False): return update_dict - def next_review(next_review_state: dict, target_dir: str = None): + def next_review(next_review_state: dict, target_dir: str | None = None): try: current_file = next_review_state["files"][0] - except IndexError: + except IndexError as e: if next_review_state["input_directory"]: - raise gr.Error(loc.localize("review-tab-no-files-error")) + raise gr.Error(loc.localize("review-tab-no-files-error")) from e return {review_state: next_review_state} @@ -276,8 +269,8 @@ def next_review(next_review_state: dict, target_dir: str = None): def select_subdir(new_value: str, next_review_state: dict): if new_value != next_review_state["current_species"]: return update_review(next_review_state, selected_species=new_value) - else: - return {review_state: next_review_state} + + return {review_state: next_review_state} def start_review(next_review_state): dir_name = gu.select_folder(state_key="review-input-dir") @@ -287,15 +280,14 @@ def start_review(next_review_state): specieslist = [ e.name for e in os.scandir(next_review_state["input_directory"]) - if e.is_dir() and e.name != POSITIVE_LABEL_DIR and e.name != NEGATIVE_LABEL_DIR + if e.is_dir() and e.name not in (POSITIVE_LABEL_DIR, NEGATIVE_LABEL_DIR) ] next_review_state["species_list"] = specieslist return update_review(next_review_state) - else: - return {review_state: next_review_state} + return {review_state: next_review_state} def try_confidence(filename): try: @@ -308,7 +300,7 @@ def try_confidence(filename): except ValueError: return 0 - def update_review(next_review_state: dict, selected_species: str = None): + def update_review(next_review_state: dict, selected_species: str | None = None): next_review_state["history"] = [] next_review_state["skipped"] = 
[] @@ -488,7 +480,7 @@ def download_plot(plot, filename=""): inputs=review_state, outputs=review_btn_output, show_progress=True, - show_progress_on=review_audio + show_progress_on=review_audio, ) negative_btn.click( @@ -496,7 +488,7 @@ def download_plot(plot, filename=""): inputs=review_state, outputs=review_btn_output, show_progress=True, - show_progress_on=review_audio + show_progress_on=review_audio, ) skip_btn.click( @@ -504,7 +496,7 @@ def download_plot(plot, filename=""): inputs=review_state, outputs=review_btn_output, show_progress=True, - show_progress_on=review_audio + show_progress_on=review_audio, ) undo_btn.click( @@ -512,7 +504,7 @@ def download_plot(plot, filename=""): inputs=review_state, outputs=review_btn_output, show_progress=True, - show_progress_on=review_audio + show_progress_on=review_audio, ) select_directory_btn.click( diff --git a/birdnet_analyzer/gui/segments.py b/birdnet_analyzer/gui/segments.py index d029e97f..cb36c628 100644 --- a/birdnet_analyzer/gui/segments.py +++ b/birdnet_analyzer/gui/segments.py @@ -5,11 +5,11 @@ import gradio as gr import birdnet_analyzer.config as cfg -import birdnet_analyzer.gui.utils as gu import birdnet_analyzer.gui.localization as loc - +import birdnet_analyzer.gui.utils as gu from birdnet_analyzer.segments.utils import extract_segments + def extract_segments_wrapper(entry): return (entry[0][0], extract_segments(entry)) @@ -18,7 +18,7 @@ def extract_segments_wrapper(entry): def _extract_segments( audio_dir, result_dir, output_dir, min_conf, num_seq, audio_speed, seq_length, threads, progress=gr.Progress() ): - from birdnet_analyzer.segments.utils import parse_folders, parse_files + from birdnet_analyzer.segments.utils import parse_files, parse_folders gu.validate(audio_dir, loc.localize("validation-no-audio-directory-selected")) diff --git a/birdnet_analyzer/gui/settings.py b/birdnet_analyzer/gui/settings.py index f3b904f1..32b7a9f5 100644 --- a/birdnet_analyzer/gui/settings.py +++ 
b/birdnet_analyzer/gui/settings.py @@ -1,10 +1,9 @@ +import json import os from pathlib import Path -import json import birdnet_analyzer.gui.utils as gu -import birdnet_analyzer.utils as utils - +from birdnet_analyzer import utils FALLBACK_LANGUAGE = "en" SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) @@ -24,7 +23,7 @@ def get_state_dict() -> dict: dict: The state dictionary loaded from the JSON file, or an empty dictionary if the file does not exist or an error occurs. """ try: - with open(STATE_SETTINGS_PATH, "r", encoding="utf-8") as f: + with open(STATE_SETTINGS_PATH, encoding="utf-8") as f: return json.load(f) except FileNotFoundError: try: @@ -98,7 +97,7 @@ def get_setting(key, default=None): ensure_settings_file() try: - with open(GUI_SETTINGS_PATH, "r", encoding="utf-8") as f: + with open(GUI_SETTINGS_PATH, encoding="utf-8") as f: settings_dict: dict = json.load(f) return settings_dict.get(key, default) diff --git a/birdnet_analyzer/gui/single_file.py b/birdnet_analyzer/gui/single_file.py index b835fc2c..af74fd98 100644 --- a/birdnet_analyzer/gui/single_file.py +++ b/birdnet_analyzer/gui/single_file.py @@ -2,11 +2,10 @@ import gradio as gr -import birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg import birdnet_analyzer.gui.localization as loc import birdnet_analyzer.gui.utils as gu -import birdnet_analyzer.utils as utils +from birdnet_analyzer import audio, utils @gu.gui_runtime_error_handler @@ -66,7 +65,7 @@ def run_single_file_analysis( custom_classifier_file, "csv", None, - "en" if not locale else locale, + locale if locale else "en", 1, 4, None, @@ -79,7 +78,7 @@ def run_single_file_analysis( raise gr.Error(loc.localize("single-tab-analyze-file-error")) # read the result file to return the data to be displayed. 
- with open(result_filepath, "r", encoding="utf-8") as f: + with open(result_filepath, encoding="utf-8") as f: reader = csv.reader(f) data = list(reader) data = [lc[0:-1] for lc in data[1:]] # remove last column (file path) and first row (header) @@ -176,8 +175,8 @@ def get_audio_path(i, generate_spectrogram): else gr.Plot(visible=False), gr.Button(interactive=True), ) - except: - raise gr.Error(loc.localize("single-tab-generate-spectrogram-error")) + except Exception as e: + raise gr.Error(loc.localize("single-tab-generate-spectrogram-error")) from e else: return None, None, gr.Plot(visible=False), gr.update(interactive=False) @@ -187,8 +186,8 @@ def try_generate_spectrogram(audio_path, generate_spectrogram): return gr.Plot( visible=True, value=utils.spectrogram_from_file(audio_path["path"], fig_size=(20, 4)) ) - except: - raise gr.Error(loc.localize("single-tab-generate-spectrogram-error")) + except Exception as e: + raise gr.Error(loc.localize("single-tab-generate-spectrogram-error")) from e else: return gr.Plot() @@ -231,8 +230,7 @@ def try_generate_spectrogram(audio_path, generate_spectrogram): def time_to_seconds(time_str): try: hours, minutes, seconds = time_str.split(":") - total_seconds = int(hours) * 3600 + int(minutes) * 60 + float(seconds) - return total_seconds + return int(hours) * 3600 + int(minutes) * 60 + float(seconds) except ValueError as e: raise ValueError("Input must be in the format hh:mm:ss or hh:mm:ss.ssssss with numeric values.") from e diff --git a/birdnet_analyzer/gui/species.py b/birdnet_analyzer/gui/species.py index 7b259414..77a389dc 100644 --- a/birdnet_analyzer/gui/species.py +++ b/birdnet_analyzer/gui/species.py @@ -3,9 +3,9 @@ import gradio as gr import birdnet_analyzer.config as cfg -import birdnet_analyzer.gui.utils as gu import birdnet_analyzer.gui.localization as loc -import birdnet_analyzer.gui.settings as settings +import birdnet_analyzer.gui.utils as gu +from birdnet_analyzer.gui import settings 
@gu.gui_runtime_error_handler diff --git a/birdnet_analyzer/gui/train.py b/birdnet_analyzer/gui/train.py index fca80bbb..42cb2eb8 100644 --- a/birdnet_analyzer/gui/train.py +++ b/birdnet_analyzer/gui/train.py @@ -8,7 +8,7 @@ import birdnet_analyzer.config as cfg import birdnet_analyzer.gui.localization as loc import birdnet_analyzer.gui.utils as gu -import birdnet_analyzer.utils as utils +from birdnet_analyzer import utils _GRID_MAX_HEIGHT = 240 @@ -214,9 +214,9 @@ def trial_progression(trial): history, metrics = history_result except Exception as e: if e.args and len(e.args) > 1: - raise gr.Error(loc.localize(e.args[1])) - else: - raise gr.Error(f"{e}") + raise gr.Error(loc.localize(e.args[1])) from e + + raise gr.Error(f"{e}") from e if len(history.epoch) < epochs: gr.Info(loc.localize("training-tab-early-stoppage-msg")) @@ -457,20 +457,19 @@ def on_crop_select(new_crop_mode): info=loc.localize("training-tab-autotune-checkbox-info"), ) - with gr.Column(visible=False) as autotune_params: - with gr.Row(): - autotune_trials = gr.Number( - cfg.AUTOTUNE_TRIALS, - label=loc.localize("training-tab-autotune-trials-number-label"), - info=loc.localize("training-tab-autotune-trials-number-info"), - minimum=1, - ) - autotune_executions_per_trials = gr.Number( - cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL, - minimum=1, - label=loc.localize("training-tab-autotune-executions-number-label"), - info=loc.localize("training-tab-autotune-executions-number-info"), - ) + with gr.Column(visible=False) as autotune_params, gr.Row(): + autotune_trials = gr.Number( + cfg.AUTOTUNE_TRIALS, + label=loc.localize("training-tab-autotune-trials-number-label"), + info=loc.localize("training-tab-autotune-trials-number-info"), + minimum=1, + ) + autotune_executions_per_trials = gr.Number( + cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL, + minimum=1, + label=loc.localize("training-tab-autotune-executions-number-label"), + info=loc.localize("training-tab-autotune-executions-number-info"), + ) with gr.Column() as 
custom_params: with gr.Row(): @@ -554,26 +553,25 @@ show_label=True, ) - with gr.Row(visible=False) as focal_loss_params: - with gr.Column(): - focal_loss_gamma = gr.Slider( - minimum=0.5, - maximum=5.0, - value=cfg.FOCAL_LOSS_GAMMA, - step=0.1, - label=loc.localize("training-tab-focal-loss-gamma-slider-label"), - info=loc.localize("training-tab-focal-loss-gamma-slider-info"), - interactive=True, - ) - focal_loss_alpha = gr.Slider( - minimum=0.1, - maximum=0.9, - value=cfg.FOCAL_LOSS_ALPHA, - step=0.05, - label=loc.localize("training-tab-focal-loss-alpha-slider-label"), - info=loc.localize("training-tab-focal-loss-alpha-slider-info"), - interactive=True, - ) + with gr.Row(visible=False) as focal_loss_params, gr.Column(): + focal_loss_gamma = gr.Slider( + minimum=0.5, + maximum=5.0, + value=cfg.FOCAL_LOSS_GAMMA, + step=0.1, + label=loc.localize("training-tab-focal-loss-gamma-slider-label"), + info=loc.localize("training-tab-focal-loss-gamma-slider-info"), + interactive=True, + ) + focal_loss_alpha = gr.Slider( + minimum=0.1, + maximum=0.9, + value=cfg.FOCAL_LOSS_ALPHA, + step=0.05, + label=loc.localize("training-tab-focal-loss-alpha-slider-label"), + info=loc.localize("training-tab-focal-loss-alpha-slider-info"), + interactive=True, + ) def on_focal_loss_change(value): return gr.Row(visible=value) @@ -653,9 +651,9 @@ def train_and_show_metrics(*args): ) return history, gr.Dataframe(visible=True, value=table_data) - else: - # No metrics available, just return history and hide table - return history, gr.Dataframe(visible=False) + + # No metrics available, just return history and hide table + return history, gr.Dataframe(visible=False) start_training_button.click( train_and_show_metrics, diff --git a/birdnet_analyzer/gui/utils.py b/birdnet_analyzer/gui/utils.py index b66d3082..316780b8 100644 --- a/birdnet_analyzer/gui/utils.py +++ b/birdnet_analyzer/gui/utils.py @@ -1,3 +1,4 @@ +# ruff: noqa: PLW0603 import multiprocessing import os
import sys @@ -9,8 +10,7 @@ import webview import birdnet_analyzer.config as cfg -import birdnet_analyzer.utils as utils - +from birdnet_analyzer import utils if utils.FROZEN: # divert stdout & stderr to logs.txt file since we have no console when deployed @@ -27,11 +27,11 @@ APPDIR.mkdir(parents=True, exist_ok=True) - sys.stderr = sys.stdout = open(str(APPDIR / "logs.txt"), "a") + sys.stderr = sys.stdout = open(str(APPDIR / "logs.txt"), "a") # noqa: SIM115 cfg.ERROR_LOG_FILE = str(APPDIR / os.path.basename(cfg.ERROR_LOG_FILE)) -import birdnet_analyzer.gui.settings as settings # noqa: E402 -import birdnet_analyzer.gui.localization as loc # noqa: E402 +import birdnet_analyzer.gui.localization as loc +from birdnet_analyzer.gui import settings loc.load_local_state() @@ -207,10 +207,11 @@ def build_header(): with gr.Row(): gr.Markdown( f""" -
- -

BirdNET Analyzer

-
+
+ +

BirdNET Analyzer

+
""" ) @@ -219,15 +220,17 @@ def build_footer(): with gr.Row(): gr.Markdown( f""" -
-
-
GUI version: {os.environ["GUI_VERSION"] if utils.FROZEN else "main"}
-
Model version: {cfg.MODEL_VERSION}
-
-
K. Lisa Yang Center for Conservation Bioacoustics
Chemnitz University of Technology
-
{loc.localize("footer-help")}:
birdnet.cornell.edu/analyzer
-
- """ +
+
+
GUI version: {os.environ["GUI_VERSION"] if utils.FROZEN else "main"}
+
Model version: {cfg.MODEL_VERSION}
+
+
K. Lisa Yang Center for Conservation Bioacoustics
Chemnitz University of Technology
+
{loc.localize("footer-help")}:
birdnet.cornell.edu/analyzer
+
""" ) @@ -286,7 +289,7 @@ def on_theme_change(value): def on_tab_select(value: gr.SelectData): if value.selected and os.path.exists(cfg.ERROR_LOG_FILE): - with open(cfg.ERROR_LOG_FILE, "r", encoding="utf-8") as f: + with open(cfg.ERROR_LOG_FILE, encoding="utf-8") as f: lines = f.readlines() last_100_lines = lines[-100:] return "".join(last_100_lines) @@ -306,7 +309,8 @@ def sample_sliders(opened=True): Returns: A tuple with the created elements: - (Slider (min confidence), Slider (sensitivity), Slider (overlap), Slider (audio speed), Number (fmin), Number (fmax)) + (Slider (min confidence), Slider (sensitivity), Slider (overlap), + Slider (audio speed), Number (fmin), Number (fmax)) """ with gr.Accordion(loc.localize("inference-settings-accordion-label"), open=opened): with gr.Group(): @@ -436,24 +440,23 @@ def plot_map_scatter_mapbox(lat, lon, zoom=4): def species_list_coordinates(show_map=False): with gr.Row(equal_height=True): - with gr.Column(scale=1): - with gr.Group(): - lat_number = gr.Slider( - minimum=-90, - maximum=90, - value=0, - step=1, - label=loc.localize("species-list-coordinates-lat-number-label"), - info=loc.localize("species-list-coordinates-lat-number-info"), - ) - lon_number = gr.Slider( - minimum=-180, - maximum=180, - value=0, - step=1, - label=loc.localize("species-list-coordinates-lon-number-label"), - info=loc.localize("species-list-coordinates-lon-number-info"), - ) + with gr.Column(scale=1), gr.Group(): + lat_number = gr.Slider( + minimum=-90, + maximum=90, + value=0, + step=1, + label=loc.localize("species-list-coordinates-lat-number-label"), + info=loc.localize("species-list-coordinates-lat-number-info"), + ) + lon_number = gr.Slider( + minimum=-180, + maximum=180, + value=0, + step=1, + label=loc.localize("species-list-coordinates-lon-number-label"), + info=loc.localize("species-list-coordinates-lon-number-info"), + ) map_plot = gr.Plot(plot_map_scatter_mapbox(0, 0), show_label=False, scale=2, visible=show_map) @@ -561,14 +564,14 @@ 
def show_species_choice(choice: str): gr.Column(visible=False), gr.Column(visible=False), ] - elif choice == _PREDICT_SPECIES: + if choice == _PREDICT_SPECIES: return [ gr.Row(visible=True), gr.File(visible=False), gr.Column(visible=False), gr.Column(visible=False), ] - elif choice == _CUSTOM_CLASSIFIER: + if choice == _CUSTOM_CLASSIFIER: return [ gr.Row(visible=False), gr.File(visible=False), @@ -592,72 +595,72 @@ def species_lists(opened=True): Returns: A tuple with the created elements: - (Radio (choice), File (custom species list), Slider (lat), Slider (lon), Slider (week), Slider (threshold), Checkbox (yearlong?), State (custom classifier)) + (Radio (choice), File (custom species list), Slider (lat), Slider (lon), + Slider (week), Slider (threshold), Checkbox (yearlong?), State (custom classifier)) """ - with gr.Accordion(loc.localize("species-list-accordion-label"), open=opened): - with gr.Row(): - species_list_radio = gr.Radio( - [_CUSTOM_SPECIES, _PREDICT_SPECIES, _CUSTOM_CLASSIFIER, _ALL_SPECIES], - value=_ALL_SPECIES, - label=loc.localize("species-list-radio-label"), - info=loc.localize("species-list-radio-info"), - elem_classes="d-block", - ) - - with gr.Column(visible=False) as position_row: - lat_number, lon_number, week_number, sf_thresh_number, yearlong_checkbox, map_plot = ( - species_list_coordinates() - ) + with gr.Accordion(loc.localize("species-list-accordion-label"), open=opened), gr.Row(): + species_list_radio = gr.Radio( + [_CUSTOM_SPECIES, _PREDICT_SPECIES, _CUSTOM_CLASSIFIER, _ALL_SPECIES], + value=_ALL_SPECIES, + label=loc.localize("species-list-radio-label"), + info=loc.localize("species-list-radio-info"), + elem_classes="d-block", + ) - species_file_input = gr.File( - file_types=[".txt"], visible=False, label=loc.localize("species-list-custom-list-file-label") + with gr.Column(visible=False) as position_row: + lat_number, lon_number, week_number, sf_thresh_number, yearlong_checkbox, map_plot = ( + species_list_coordinates() ) - empty_col 
= gr.Column() - with gr.Column(visible=False) as custom_classifier_selector: - classifier_selection_button = gr.Button( - loc.localize("species-list-custom-classifier-selection-button-label") - ) - classifier_file_input = gr.Files(file_types=[".tflite"], visible=False, interactive=False) - selected_classifier_state = gr.State() + species_file_input = gr.File( + file_types=[".txt"], visible=False, label=loc.localize("species-list-custom-list-file-label") + ) + empty_col = gr.Column() - def on_custom_classifier_selection_click(): - file = select_file(("TFLite classifier (*.tflite)",), state_key="custom_classifier_file") + with gr.Column(visible=False) as custom_classifier_selector: + classifier_selection_button = gr.Button( + loc.localize("species-list-custom-classifier-selection-button-label") + ) + classifier_file_input = gr.Files(file_types=[".tflite"], visible=False, interactive=False) + selected_classifier_state = gr.State() - if file: - labels = os.path.splitext(file)[0] + "_Labels.txt" + def on_custom_classifier_selection_click(): + file = select_file(("TFLite classifier (*.tflite)",), state_key="custom_classifier_file") - if not os.path.isfile(labels): - labels = file.replace("Model_FP32.tflite", "Labels.txt") + if file: + labels = os.path.splitext(file)[0] + "_Labels.txt" - return file, gr.File(value=[file, labels], visible=True) + if not os.path.isfile(labels): + labels = file.replace("Model_FP32.tflite", "Labels.txt") - return None, None + return file, gr.File(value=[file, labels], visible=True) - classifier_selection_button.click( - on_custom_classifier_selection_click, - outputs=[selected_classifier_state, classifier_file_input], - show_progress=False, - ) + return None, None - species_list_radio.change( - show_species_choice, - inputs=[species_list_radio], - outputs=[position_row, species_file_input, custom_classifier_selector, empty_col], + classifier_selection_button.click( + on_custom_classifier_selection_click, + outputs=[selected_classifier_state, 
classifier_file_input], show_progress=False, ) - return ( - species_list_radio, - species_file_input, - lat_number, - lon_number, - week_number, - sf_thresh_number, - yearlong_checkbox, - selected_classifier_state, - map_plot, - ) + species_list_radio.change( + show_species_choice, + inputs=[species_list_radio], + outputs=[position_row, species_file_input, custom_classifier_selector, empty_col], + show_progress=False, + ) + + return ( + species_list_radio, + species_file_input, + lat_number, + lon_number, + week_number, + sf_thresh_number, + yearlong_checkbox, + selected_classifier_state, + map_plot, + ) def _get_network_shortcuts(): @@ -748,9 +751,8 @@ def open_window(builder: list[Callable] | Callable): if callable(builder): map_plots.append(builder()) - elif isinstance(builder, (tuple, set, list)): - for build in builder: - map_plots.append(build()) + elif isinstance(builder, tuple | set | list): + map_plots.extend(build() for build in builder) build_settings() build_footer() diff --git a/birdnet_analyzer/model.py b/birdnet_analyzer/model.py index 76b82611..ebaa0387 100644 --- a/birdnet_analyzer/model.py +++ b/birdnet_analyzer/model.py @@ -1,3 +1,4 @@ +# ruff: noqa: PLW0603 """Contains functions to use the BirdNET models.""" import os @@ -7,10 +8,11 @@ import numpy as np import birdnet_analyzer.config as cfg -import birdnet_analyzer.utils as utils +from birdnet_analyzer import utils SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -34,10 +36,11 @@ C_PBMODEL = None EMPTY_CLASS_EXCEPTION_REF = None + def get_empty_class_exception(): import keras_tuner.errors - global EMPTY_CLASS_EXCEPTION_REF + global EMPTY_CLASS_EXCEPTION_REF if EMPTY_CLASS_EXCEPTION_REF: return EMPTY_CLASS_EXCEPTION_REF @@ -95,9 +98,7 @@ def mixup(x, y, augmentation_ratio=0.25, alpha=0.2): Returns: Augmented data. 
""" - - # Set numpy random seed - np.random.seed(cfg.RANDOM_SEED) + rng = np.random.default_rng(cfg.RANDOM_SEED) # Get indices of all positive samples positive_indices = np.unique(np.where(y[:, :] == 1)[0]) @@ -110,24 +111,24 @@ def mixup(x, y, augmentation_ratio=0.25, alpha=0.2): for _ in range(num_samples_to_augment): # Randomly choose one instance from the positive samples - index = np.random.choice(positive_indices) + index = rng.choice(positive_indices) # Choose another one, when the chosen one was already mixed up while index in mixed_up_indices: - index = np.random.choice(positive_indices) + index = rng.choice(positive_indices) x1, y1 = x[index], y[index] # Randomly choose a different instance from the dataset - second_index = np.random.choice(positive_indices) + second_index = rng.choice(positive_indices) # Choose again, when the same or an already mixed up sample was selected while second_index == index or second_index in mixed_up_indices: - second_index = np.random.choice(positive_indices) + second_index = rng.choice(positive_indices) x2, y2 = x[second_index], y[second_index] # Generate a random mixing coefficient (lambda) - lambda_ = np.random.beta(alpha, alpha) + lambda_ = rng.beta(alpha, alpha) # Mix the embeddings and labels mixed_x = lambda_ * x1 + (1 - lambda_) * x2 @@ -159,9 +160,7 @@ def random_split(x, y, val_ratio=0.2): Returns: A tuple of (x_train, y_train, x_val, y_val). 
""" - - # Set numpy random seed - np.random.seed(cfg.RANDOM_SEED) + rng = np.random.default_rng(cfg.RANDOM_SEED) # Get number of classes num_classes = y.shape[1] @@ -183,7 +182,7 @@ def random_split(x, y, val_ratio=0.2): num_samples_val = max(0, num_samples - num_samples_train) # Randomly choose samples for training and validation - np.random.shuffle(positive_indices) + rng.shuffle(positive_indices) train_indices = positive_indices[:num_samples_train] val_indices = positive_indices[num_samples_train : num_samples_train + num_samples_val] @@ -202,7 +201,7 @@ def random_split(x, y, val_ratio=0.2): num_samples = len(non_event_indices) num_samples_train = max(1, int(num_samples * (1 - val_ratio))) num_samples_val = max(0, num_samples - num_samples_train) - np.random.shuffle(non_event_indices) + rng.shuffle(non_event_indices) train_indices = non_event_indices[:num_samples_train] val_indices = non_event_indices[num_samples_train : num_samples_train + num_samples_val] x_train.append(x[train_indices]) @@ -218,12 +217,12 @@ def random_split(x, y, val_ratio=0.2): # Shuffle data indices = np.arange(len(x_train)) - np.random.shuffle(indices) + rng.shuffle(indices) x_train = x_train[indices] y_train = y_train[indices] indices = np.arange(len(x_val)) - np.random.shuffle(indices) + rng.shuffle(indices) x_val = x_val[indices] y_val = y_val[indices] @@ -244,9 +243,7 @@ def random_multilabel_split(x, y, val_ratio=0.2): A tuple of (x_train, y_train, x_val, y_val). 
""" - - # Set numpy random seed - np.random.seed(cfg.RANDOM_SEED) + rng = np.random.default_rng(cfg.RANDOM_SEED) # Find all combinations of labels class_combinations = np.unique(y, axis=0) @@ -270,7 +267,7 @@ def random_multilabel_split(x, y, val_ratio=0.2): num_samples_train = max(1, int(num_samples * (1 - val_ratio))) num_samples_val = max(0, num_samples - num_samples_train) # Randomly choose samples for training and validation - np.random.shuffle(indices) + rng.shuffle(indices) train_indices = indices[:num_samples_train] val_indices = indices[num_samples_train : num_samples_train + num_samples_val] # Append samples to training and validation data @@ -287,12 +284,12 @@ def random_multilabel_split(x, y, val_ratio=0.2): # Shuffle data indices = np.arange(len(x_train)) - np.random.shuffle(indices) + rng.shuffle(indices) x_train = x_train[indices] y_train = y_train[indices] indices = np.arange(len(x_val)) - np.random.shuffle(indices) + rng.shuffle(indices) x_val = x_val[indices] y_val = y_val[indices] @@ -311,19 +308,17 @@ def upsample_core(x: np.ndarray, y: np.ndarray, min_samples: int, apply: callabl Returns: tuple: A tuple containing the upsampled feature matrix and target labels. 
""" + rng = np.random.default_rng(cfg.RANDOM_SEED) y_temp = [] x_temp = [] if cfg.BINARY_CLASSIFICATION: # Determine if 1 or 0 is the minority class - if y.sum(axis=0) < len(y) - y.sum(axis=0): - minority_label = 1 - else: - minority_label = 0 + minority_label = 1 if y.sum(axis=0) < len(y) - y.sum(axis=0) else 0 while np.where(y == minority_label)[0].shape[0] + len(y_temp) < min_samples: # Randomly choose a sample from the minority class - random_index = np.random.choice(np.where(y == minority_label)[0], size=size) + random_index = rng.choice(np.where(y == minority_label)[0], size=size) # Apply SMOTE x_app, y_app = apply(x, y, random_index) @@ -334,7 +329,7 @@ def upsample_core(x: np.ndarray, y: np.ndarray, min_samples: int, apply: callabl while y[:, i].sum() + len(y_temp) < min_samples: try: # Randomly choose a sample from the minority class - random_index = np.random.choice(np.where(y[:, i] == 1)[0], size=size) + random_index = rng.choice(np.where(y[:, i] == 1)[0], size=size) except ValueError as e: raise get_empty_class_exception()(index=i) from e @@ -362,13 +357,14 @@ def upsampling(x: np.ndarray, y: np.ndarray, ratio=0.5, mode="repeat"): """ # Set numpy random seed - np.random.seed(cfg.RANDOM_SEED) + rng = np.random.default_rng(cfg.RANDOM_SEED) # Determine min number of samples - if cfg.BINARY_CLASSIFICATION: - min_samples = int(max(y.sum(axis=0), len(y) - y.sum(axis=0)) * ratio) - else: - min_samples = int(np.max(y.sum(axis=0)) * ratio) + min_samples = ( + int(max(y.sum(axis=0), len(y) - y.sum(axis=0)) * ratio) + if cfg.BINARY_CLASSIFICATION + else int(np.max(y.sum(axis=0)) * ratio) + ) x_temp = [] y_temp = [] @@ -397,7 +393,7 @@ def applyMean(x, y, random_indices): # select two random samples and calculate the linear combination def applyLinearCombination(x, y, random_indices): # Calculate the linear combination of the two samples - alpha = np.random.uniform(0, 1) + alpha = rng.uniform(0, 1) new_sample = alpha * x[random_indices[0]] + (1 - alpha) * 
x[random_indices[1]] # Append the new sample and label to a temp list @@ -413,13 +409,13 @@ def applySmote(x, y, random_index, k=5): indices = np.argsort(distances)[1 : k + 1] # Randomly choose one of the neighbors - random_neighbor = np.random.choice(indices) + random_neighbor = rng.choice(indices) # Calculate the difference vector diff = x[random_neighbor] - x[random_index[0]] # Randomly choose a weight between 0 and 1 - weight = np.random.uniform(0, 1) + weight = rng.uniform(0, 1) # Calculate the new sample new_sample = x[random_index[0]] + weight * diff @@ -436,7 +432,7 @@ def applySmote(x, y, random_index, k=5): # Shuffle data indices = np.arange(len(x)) - np.random.shuffle(indices) + rng.shuffle(indices) x = x[indices] y = y[indices] @@ -530,10 +526,7 @@ def load_model(class_output=True): INPUT_LAYER_INDEX = input_details[0]["index"] # Get classification output or feature embeddings - if class_output: - OUTPUT_LAYER_INDEX = output_details[0]["index"] - else: - OUTPUT_LAYER_INDEX = output_details[0]["index"] - 1 + OUTPUT_LAYER_INDEX = output_details[0]["index"] if class_output else output_details[0]["index"] - 1 else: # Load protobuf model @@ -623,10 +616,10 @@ def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0) # Input layer model.add(keras.layers.InputLayer(input_shape=(input_size,))) - + # Batch normalization on input to standardize embeddings model.add(keras.layers.BatchNormalization()) - + # Optional L2 regularization for all dense layers regularizer = keras.regularizers.l2(1e-5) @@ -635,13 +628,14 @@ def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0) # Dropout layer before hidden layer if dropout > 0: model.add(keras.layers.Dropout(dropout)) - + # Add a hidden layer with L2 regularization - model.add(keras.layers.Dense(hidden_units, - activation="relu", - kernel_regularizer=regularizer, - kernel_initializer='he_normal')) - + model.add( + keras.layers.Dense( + hidden_units, activation="relu", 
kernel_regularizer=regularizer, kernel_initializer="he_normal" + ) + ) + # Add another batch normalization after the hidden layer model.add(keras.layers.BatchNormalization()) @@ -650,9 +644,7 @@ def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0) model.add(keras.layers.Dropout(dropout)) # Classification layer with L2 regularization - model.add(keras.layers.Dense(num_labels, - kernel_regularizer=regularizer, - kernel_initializer='glorot_uniform')) + model.add(keras.layers.Dense(num_labels, kernel_regularizer=regularizer, kernel_initializer="glorot_uniform")) # Activation layer model.add(keras.layers.Activation("sigmoid")) @@ -718,11 +710,11 @@ def on_epoch_end(self, epoch, logs=None): self.on_epoch_end_fn(epoch, logs) # Set random seed - np.random.seed(cfg.RANDOM_SEED) + rng = np.random.default_rng(cfg.RANDOM_SEED) # Shuffle data idx = np.arange(x_train.shape[0]) - np.random.shuffle(idx) + rng.shuffle(idx) x_train = x_train[idx] y_train = y_train[idx] @@ -732,10 +724,10 @@ def on_epoch_end(self, epoch, logs=None): x_train, y_train, x_val, y_val = random_split(x_train, y_train, val_split) else: x_train, y_train, x_val, y_val = random_multilabel_split(x_train, y_train, val_split) - else: + else: x_val = x_test y_val = y_test - + print( f"Training on {x_train.shape[0]} samples, validating on {x_val.shape[0]} samples.", flush=True, @@ -757,7 +749,7 @@ def on_epoch_end(self, epoch, logs=None): # Early stopping with patience depending on dataset size patience = min(10, max(5, int(epochs / 10))) min_delta = 0.001 - + callbacks = [ # EarlyStopping with restore_best_weights keras.callbacks.EarlyStopping( @@ -774,29 +766,26 @@ def on_epoch_end(self, epoch, logs=None): # Learning rate schedule - use cosine decay with warmup warmup_epochs = min(5, int(epochs * 0.1)) - total_steps = epochs * x_train.shape[0] / batch_size - warmup_steps = warmup_epochs * x_train.shape[0] / batch_size - + def lr_schedule(epoch, lr): if epoch < warmup_epochs: # Linear 
warmup return learning_rate * (epoch + 1) / warmup_epochs - else: - # Cosine decay - progress = (epoch - warmup_epochs) / (epochs - warmup_epochs) - return learning_rate * (0.1 + 0.9 * (1 + np.cos(np.pi * progress)) / 2) - + + # Cosine decay + progress = (epoch - warmup_epochs) / (epochs - warmup_epochs) + return learning_rate * (0.1 + 0.9 * (1 + np.cos(np.pi * progress)) / 2) + # Add LR scheduler callback callbacks.append(keras.callbacks.LearningRateScheduler(lr_schedule)) - + optimizer_cls = keras.optimizers.legacy.Adam if sys.platform == "darwin" else keras.optimizers.Adam + def _focal_loss(y_true, y_pred): + return focal_loss(y_true, y_pred, gamma=cfg.FOCAL_LOSS_GAMMA, alpha=cfg.FOCAL_LOSS_ALPHA) + # Choose the loss function based on config - loss_function = custom_loss - if train_with_focal_loss: - loss_function = lambda y_true, y_pred: focal_loss( - y_true, y_pred, gamma=cfg.FOCAL_LOSS_GAMMA, alpha=cfg.FOCAL_LOSS_ALPHA - ) + loss_function = _focal_loss if train_with_focal_loss else custom_loss # Compile model with appropriate metrics for classification task classifier.compile( @@ -871,7 +860,9 @@ def save_linear_classifier(classifier, model_path: str, labels: list[str], mode= # Save model as tflite converter = tf.lite.TFLiteConverter.from_keras_model(combined_model) tflite_model = converter.convert() - open(model_path, "wb").write(tflite_model) + + with open(model_path, "wb") as f: + f.write(tflite_model) if mode == "append": labels = [*utils.read_lines(os.path.join(SCRIPT_DIR, cfg.LABELS_FILE)), *labels] @@ -884,7 +875,7 @@ def save_linear_classifier(classifier, model_path: str, labels: list[str], mode= save_model_params(model_path.replace(".tflite", "_Params.csv")) -def save_raven_model(classifier, model_path, labels: list[str], mode="replace"): +def save_raven_model(classifier, model_path: str, labels: list[str], mode="replace"): """ Save a TensorFlow model with a custom classifier and associated metadata for use with BirdNET. 
@@ -945,7 +936,7 @@ def basic(self, inputs): # Save signature model os.makedirs(os.path.dirname(model_path), exist_ok=True) - model_path = model_path[:-7] if model_path.endswith(".tflite") else model_path + model_path = model_path.removesuffix(".tflite") tf.saved_model.save(smodel, model_path, signatures=signatures) if mode == "append": @@ -959,7 +950,7 @@ def basic(self, inputs): with open(os.path.join(labels_dir, "label_names.csv"), "w", newline="") as labelsfile: labelwriter = csv.writer(labelsfile) - labelwriter.writerows(zip(labelIds, labels)) + labelwriter.writerows(zip(labelIds, labels, strict=True)) # Save class names file classes_dir = os.path.join(model_path, "classes") @@ -1017,8 +1008,6 @@ def predict_filter(lat, lon, week): Returns: A list of probabilities for all species. """ - global M_INTERPRETER - # Does interpreter exist? if M_INTERPRETER is None: load_meta_model() @@ -1053,52 +1042,51 @@ def explore(lat: float, lon: float, week: int): l_filter = np.where(l_filter >= cfg.LOCATION_FILTER_THRESHOLD, l_filter, 0) # Zip with labels - l_filter = list(zip(l_filter, cfg.LABELS)) + l_filter = list(zip(l_filter, cfg.LABELS, strict=True)) # Sort by filter value - l_filter = sorted(l_filter, key=lambda x: x[0], reverse=True) - - return l_filter + return sorted(l_filter, key=lambda x: x[0], reverse=True) def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, epsilon=1e-7): """ Focal loss for better handling of class imbalance. - + This loss function gives more weight to hard examples and down-weights easy examples. Particularly helpful for imbalanced datasets where some classes have few samples. - + Args: y_true: Ground truth labels. y_pred: Predicted probabilities. gamma: Focusing parameter. Higher values mean more focus on hard examples. alpha: Balance parameter. Controls weight of positive vs negative examples. epsilon: Small constant to prevent log(0). - + Returns: Focal loss value. 
""" import tensorflow.keras.backend as K - + # Apply sigmoid if not already applied y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon) - + # Calculate cross entropy cross_entropy = -y_true * K.log(y_pred) - (1 - y_true) * K.log(1 - y_pred) - + # Calculate focal weight p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred) focal_weight = K.pow(1 - p_t, gamma) - + # Apply alpha balancing alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha) - + # Calculate focal loss focal_loss = alpha_factor * focal_weight * cross_entropy - + # Sum over all classes return K.sum(focal_loss, axis=-1) + def custom_loss(y_true, y_pred, epsilon=1e-7): import tensorflow.keras.backend as K @@ -1109,9 +1097,7 @@ def custom_loss(y_true, y_pred, epsilon=1e-7): negative_loss = -K.sum((1 - y_true) * K.log(K.clip(1 - y_pred, epsilon, 1.0 - epsilon)), axis=-1) # Combine both loss terms - total_loss = positive_loss + negative_loss - - return total_loss + return positive_loss + negative_loss def flat_sigmoid(x, sensitivity=-1, bias=1.0): @@ -1155,8 +1141,6 @@ def predict(sample): if cfg.CUSTOM_CLASSIFIER is not None: return predict_with_custom_classifier(sample) - global INTERPRETER - # Does interpreter or keras model exist? if INTERPRETER is None and PBMODEL is None: load_model() @@ -1169,15 +1153,10 @@ def predict(sample): # Make a prediction (Audio only for now) INTERPRETER.set_tensor(INPUT_LAYER_INDEX, np.array(sample, dtype="float32")) INTERPRETER.invoke() - prediction = INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX) - - return prediction - - else: - # Make a prediction (Audio only for now) - prediction = PBMODEL.basic(sample)["scores"] + return INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX) - return prediction + # Make a prediction (Audio only for now) + return PBMODEL.basic(sample)["scores"] def predict_with_custom_classifier(sample): @@ -1189,10 +1168,6 @@ def predict_with_custom_classifier(sample): Returns: The prediction scores for the sample. 
""" - global C_INTERPRETER - global C_INPUT_SIZE - global C_PBMODEL - # Does interpreter exist? if C_INTERPRETER is None and C_PBMODEL is None: load_custom_classifier() @@ -1207,13 +1182,10 @@ def predict_with_custom_classifier(sample): # Make a prediction C_INTERPRETER.set_tensor(C_INPUT_LAYER_INDEX, np.array(vector, dtype="float32")) C_INTERPRETER.invoke() - prediction = C_INTERPRETER.get_tensor(C_OUTPUT_LAYER_INDEX) - return prediction - else: - prediction = C_PBMODEL.basic(sample)["scores"] + return C_INTERPRETER.get_tensor(C_OUTPUT_LAYER_INDEX) - return prediction + return C_PBMODEL.basic(sample)["scores"] def embeddings(sample): @@ -1225,8 +1197,6 @@ def embeddings(sample): Returns: The embeddings. """ - global INTERPRETER - # Does interpreter exist? if INTERPRETER is None: load_model(False) @@ -1238,6 +1208,5 @@ def embeddings(sample): # Extract feature embeddings INTERPRETER.set_tensor(INPUT_LAYER_INDEX, np.array(sample, dtype="float32")) INTERPRETER.invoke() - features = INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX) - return features + return INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX) diff --git a/birdnet_analyzer/network/client.py b/birdnet_analyzer/network/client.py index 8af97d24..ae6b8b97 100644 --- a/birdnet_analyzer/network/client.py +++ b/birdnet_analyzer/network/client.py @@ -29,20 +29,19 @@ def send_request(host: str, port: int, fpath: str, mdata: str) -> dict: print(f"Requesting analysis for {fpath}") - # Make payload - multipart_form_data = {"audio": (fpath.rsplit(os.sep, 1)[-1], open(fpath, "rb")), "meta": (None, mdata)} + with open(fpath, "rb") as f: + # Make payload + multipart_form_data = {"audio": (fpath.rsplit(os.sep, 1)[-1], f), "meta": (None, mdata)} - # Send request - start_time = time.time() - response = requests.post(url, files=multipart_form_data) - end_time = time.time() - - print("Response: {}, Time: {:.4f}s".format(response.text, end_time - start_time), flush=True) + # Send request + start_time = time.time() + response = 
requests.post(url, files=multipart_form_data) + end_time = time.time() - # Convert to dict - data = json.loads(response.text) + print(f"Response: {response.text}, Time: {end_time - start_time:.4f}s", flush=True) - return data + # Convert to dict + return json.loads(response.text) def _save_result(data, fpath): @@ -62,7 +61,7 @@ def _save_result(data, fpath): if __name__ == "__main__": - import birdnet_analyzer.cli as cli + from birdnet_analyzer import cli # Freeze support for executable freeze_support() diff --git a/birdnet_analyzer/network/server.py b/birdnet_analyzer/network/server.py index f8566241..34f7316b 100644 --- a/birdnet_analyzer/network/server.py +++ b/birdnet_analyzer/network/server.py @@ -1,11 +1,10 @@ import os -from multiprocessing import freeze_support import shutil import tempfile +from multiprocessing import freeze_support import birdnet_analyzer.config as cfg -import birdnet_analyzer.cli as cli -import birdnet_analyzer.utils as utils +from birdnet_analyzer import cli, utils def start_server(host="0.0.0.0", port=8080, spath="uploads/", threads=1, locale="en"): @@ -38,7 +37,7 @@ def start_server(host="0.0.0.0", port=8080, spath="uploads/", threads=1, locale= # Load translated labels lfile = os.path.join( - cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", "_{}.txt".format(locale)) + cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", f"_{locale}.txt") ) if locale not in ["en"] and os.path.isfile(lfile): diff --git a/birdnet_analyzer/network/utils.py b/birdnet_analyzer/network/utils.py index 35784213..a10791d8 100644 --- a/birdnet_analyzer/network/utils.py +++ b/birdnet_analyzer/network/utils.py @@ -10,10 +10,8 @@ import bottle -import birdnet_analyzer.analyze as analyze import birdnet_analyzer.config as cfg -import birdnet_analyzer.species as species -import birdnet_analyzer.utils as utils +from birdnet_analyzer import analyze, species, utils def result_pooling(lines: list[str], num_results=5, 
pmode="avg"): @@ -137,7 +135,7 @@ def handle_request(): cfg.LOCATION_FILTER_THRESHOLD = max(0.01, min(0.99, float(mdata.get("sf_thresh", 0.03)))) # Set species list - if not cfg.LATITUDE == -1 and not cfg.LONGITUDE == -1: + if cfg.LATITUDE != -1 and cfg.LONGITUDE != -1: cfg.SPECIES_LIST_FILE = None cfg.SPECIES_LIST = species.get_species_list( cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD @@ -177,8 +175,7 @@ def handle_request(): return json.dumps(data) - else: - return json.dumps({"msg": "Error during analysis."}) + return json.dumps({"msg": "Error during analysis."}) except Exception as e: # Write error log diff --git a/birdnet_analyzer/search/__main__.py b/birdnet_analyzer/search/__main__.py index f5a61cb5..0e8f70ad 100644 --- a/birdnet_analyzer/search/__main__.py +++ b/birdnet_analyzer/search/__main__.py @@ -1,3 +1,3 @@ from birdnet_analyzer.search.cli import main -main() \ No newline at end of file +main() diff --git a/birdnet_analyzer/search/cli.py b/birdnet_analyzer/search/cli.py index 4bbeeee7..747fc1c9 100644 --- a/birdnet_analyzer/search/cli.py +++ b/birdnet_analyzer/search/cli.py @@ -1,10 +1,9 @@ -import birdnet_analyzer.utils as utils +from birdnet_analyzer import utils @utils.runtime_error_handler def main(): - import birdnet_analyzer.cli as cli - from birdnet_analyzer import search + from birdnet_analyzer import cli, search parser = cli.search_parser() args = parser.parse_args() diff --git a/birdnet_analyzer/search/core.py b/birdnet_analyzer/search/core.py index 9ed8a8c2..a45523a5 100644 --- a/birdnet_analyzer/search/core.py +++ b/birdnet_analyzer/search/core.py @@ -36,8 +36,8 @@ def search( """ import os - import birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg + from birdnet_analyzer import audio from birdnet_analyzer.search.utils import get_search_results # Create output folder @@ -49,8 +49,8 @@ def search( try: settings = db.get_metadata("birdnet_analyzer_settings") - except: - raise ValueError("No 
settings present in database.") + except KeyError as e: + raise ValueError("No settings present in database.") from e fmin = settings["BANDPASS_FMIN"] fmax = settings["BANDPASS_FMAX"] @@ -60,7 +60,7 @@ def search( results = get_search_results(queryfile, db, n_results, audio_speed, fmin, fmax, score_function, crop_mode, overlap) # Save the results - for i, r in enumerate(results): + for r in results: embedding_source = db.get_embedding_source(r.embedding_id) file = embedding_source.source_id filebasename = os.path.basename(file) diff --git a/birdnet_analyzer/search/utils.py b/birdnet_analyzer/search/utils.py index 35ec4ebe..b60c046c 100644 --- a/birdnet_analyzer/search/utils.py +++ b/birdnet_analyzer/search/utils.py @@ -3,9 +3,8 @@ from perch_hoplite.db.search_results import SearchResult from scipy.spatial.distance import euclidean -import birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg -import birdnet_analyzer.model as model +from birdnet_analyzer import audio, model def cosine_sim(a, b): @@ -52,8 +51,8 @@ def get_query_embedding(queryfile_path): samples = sig_splits data = np.array(samples, dtype="float32") - query = model.embeddings(data) - return query + + return model.embeddings(data) def get_search_results( @@ -80,10 +79,7 @@ def get_search_results( raise ValueError("Invalid score function. 
Choose 'cosine', 'euclidean' or 'dot'.") db_embeddings_count = db.count_embeddings() - - if n_results > db_embeddings_count - 1: - n_results = db_embeddings_count - 1 - + n_results = min(n_results, db_embeddings_count - 1) scores_by_embedding_id = {} for embedding in query_embeddings: diff --git a/birdnet_analyzer/segments/__main__.py b/birdnet_analyzer/segments/__main__.py index dddb4e4d..4bb36590 100644 --- a/birdnet_analyzer/segments/__main__.py +++ b/birdnet_analyzer/segments/__main__.py @@ -1,3 +1,3 @@ from birdnet_analyzer.segments.cli import main -main() \ No newline at end of file +main() diff --git a/birdnet_analyzer/segments/cli.py b/birdnet_analyzer/segments/cli.py index 181fe65e..3707fdad 100644 --- a/birdnet_analyzer/segments/cli.py +++ b/birdnet_analyzer/segments/cli.py @@ -3,8 +3,7 @@ @runtime_error_handler def main(): - import birdnet_analyzer.cli as cli - from birdnet_analyzer import segments + from birdnet_analyzer import cli, segments # Parse arguments parser = cli.segments_parser() diff --git a/birdnet_analyzer/segments/core.py b/birdnet_analyzer/segments/core.py index ba127e91..f3b084b3 100644 --- a/birdnet_analyzer/segments/core.py +++ b/birdnet_analyzer/segments/core.py @@ -1,5 +1,5 @@ def segments( - input: str, + audio_input: str, output: str | None = None, results: str | None = None, *, @@ -12,7 +12,7 @@ def segments( """ Processes audio files to extract segments based on detection results. Args: - input (str): Path to the input folder containing audio files. + audio_input (str): Path to the input folder containing audio files. output (str | None, optional): Path to the output folder where segments will be saved. If not provided, the input folder will be used as the output folder. Defaults to None. results (str | None, optional): Path to the folder containing detection result files. 
@@ -36,10 +36,13 @@ def segments( from multiprocessing import Pool import birdnet_analyzer.config as cfg + from birdnet_analyzer.segments.utils import ( + extract_segments, + parse_files, + parse_folders, + ) - from birdnet_analyzer.segments.utils import extract_segments, parse_folders, parse_files # noqa: E402 - - cfg.INPUT_PATH = input + cfg.INPUT_PATH = audio_input if not output: cfg.OUTPUT_PATH = cfg.INPUT_PATH @@ -49,7 +52,7 @@ def segments( results = results if results else cfg.INPUT_PATH # Parse audio and result folders - cfg.FILE_LIST = parse_folders(input, results) + cfg.FILE_LIST = parse_folders(audio_input, results) # Set number of threads cfg.CPU_THREADS = threads diff --git a/birdnet_analyzer/segments/utils.py b/birdnet_analyzer/segments/utils.py index e2fa1e8b..6e240288 100644 --- a/birdnet_analyzer/segments/utils.py +++ b/birdnet_analyzer/segments/utils.py @@ -7,12 +7,11 @@ import numpy as np -import birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg -import birdnet_analyzer.utils as utils +from birdnet_analyzer import audio, utils # Set numpy random seed -np.random.seed(cfg.RANDOM_SEED) +RNG = np.random.default_rng(cfg.RANDOM_SEED) SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) @@ -27,14 +26,14 @@ def detect_rtype(line: str): """ if line.lower().startswith("selection"): return "table" - # elif line.lower().startswith("filepath"): - # return "r" - elif line.lower().startswith("indir"): + + if line.lower().startswith("indir"): return "kaleidoscope" - elif line.lower().startswith("start (s)"): + + if line.lower().startswith("start (s)"): return "csv" - else: - return "audacity" + + return "audacity" def get_header_mapping(line: str) -> dict: @@ -49,22 +48,14 @@ def get_header_mapping(line: str) -> dict: """ rtype = detect_rtype(line) - if rtype == "table" or rtype == "audacity": - sep = "\t" - else: - sep = "," + sep = "\t" if rtype in ("table", "audacity") else "," cols = line.split(sep) - mapping = {} + return {col: i for 
i, col in enumerate(cols)} - for i, col in enumerate(cols): - mapping[col] = i - return mapping - - -def parse_folders(apath: str, rpath: str, allowed_result_filetypes: list[str] = ["txt", "csv"]) -> list[dict]: +def parse_folders(apath: str, rpath: str, allowed_result_filetypes: tuple[str] = ("txt", "csv")) -> list[dict]: """Read audio and result files. Reads all audio files and BirdNET output inside directory recursively. @@ -72,7 +63,7 @@ def parse_folders(apath: str, rpath: str, allowed_result_filetypes: list[str] = Args: apath (str): Path to search for audio files. rpath (str): Path to search for result files. - allowed_result_filetypes (list[str]): List of extensions for the result files. + allowed_result_filetypes (tuple[str]): List of extensions for the result files. Returns: list[dict]: A list of {"audio": path_to_audio, "result": path_to_result }. @@ -169,7 +160,7 @@ def parse_files(flist: list[dict], max_segments=100): # Shuffle segments for each species and limit to max_segments for s in species_segments: - np.random.shuffle(species_segments[s]) + RNG.shuffle(species_segments[s]) species_segments[s] = species_segments[s][:max_segments] # Make dict of segments per audio file @@ -187,9 +178,7 @@ def parse_files(flist: list[dict], max_segments=100): print(f"Found {seg_cnt} segments in {len(segments)} audio files.") # Convert to list - flist = [tuple(e) for e in segments.items()] - - return flist + return [tuple(e) for e in segments.items()] def find_segments_from_combined(rfile: str) -> list[dict]: @@ -352,7 +341,7 @@ def extract_segments(item: tuple[tuple[str, list[dict]], float, dict[str]]): print(f"Error: Cannot open audio file {afile}", flush=True) utils.write_error_log(ex) - return + return None # Extract segments for seg_cnt, seg in enumerate(segments, 1): diff --git a/birdnet_analyzer/species/__main__.py b/birdnet_analyzer/species/__main__.py index 744e7159..d203ebbf 100644 --- a/birdnet_analyzer/species/__main__.py +++ 
b/birdnet_analyzer/species/__main__.py @@ -1,3 +1,3 @@ from birdnet_analyzer.species.cli import main -main() \ No newline at end of file +main() diff --git a/birdnet_analyzer/species/cli.py b/birdnet_analyzer/species/cli.py index e3d543e8..cda85dcb 100644 --- a/birdnet_analyzer/species/cli.py +++ b/birdnet_analyzer/species/cli.py @@ -3,8 +3,7 @@ @runtime_error_handler def main(): - import birdnet_analyzer.cli as cli - from birdnet_analyzer import species + from birdnet_analyzer import cli, species # Parse arguments parser = cli.species_parser() diff --git a/birdnet_analyzer/species/utils.py b/birdnet_analyzer/species/utils.py index b68320e3..45f40f2b 100644 --- a/birdnet_analyzer/species/utils.py +++ b/birdnet_analyzer/species/utils.py @@ -6,8 +6,7 @@ import os import birdnet_analyzer.config as cfg -import birdnet_analyzer.model as model -import birdnet_analyzer.utils as utils +from birdnet_analyzer import model, utils def get_species_list(lat: float, lon: float, week: int, threshold=0.05, sort=False) -> list[str]: @@ -64,7 +63,7 @@ def run(output_path, lat, lon, week, threshold, sortby): # Get species list species_list = get_species_list( - cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD, False if sortby == "freq" else True + cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD, sortby != "freq" ) print(f"Done. 
{len(species_list)} species on list.", flush=True) diff --git a/birdnet_analyzer/train/__main__.py b/birdnet_analyzer/train/__main__.py index 25861754..000c33d7 100644 --- a/birdnet_analyzer/train/__main__.py +++ b/birdnet_analyzer/train/__main__.py @@ -1,3 +1,3 @@ from birdnet_analyzer.train.cli import main -main() \ No newline at end of file +main() diff --git a/birdnet_analyzer/train/cli.py b/birdnet_analyzer/train/cli.py index d9db66df..061c654d 100644 --- a/birdnet_analyzer/train/cli.py +++ b/birdnet_analyzer/train/cli.py @@ -3,8 +3,7 @@ @runtime_error_handler def main(): - import birdnet_analyzer.cli as cli - from birdnet_analyzer import train + from birdnet_analyzer import cli, train # Parse arguments parser = cli.train_parser() diff --git a/birdnet_analyzer/train/core.py b/birdnet_analyzer/train/core.py index 7fd3d81d..18dcd96e 100644 --- a/birdnet_analyzer/train/core.py +++ b/birdnet_analyzer/train/core.py @@ -2,9 +2,9 @@ def train( - input: str, + audio_input: str, output: str = "checkpoints/custom/Custom_Classifier", - test_data: str = None, + test_data: str | None = None, *, crop_mode: Literal["center", "first", "segments"] = "center", overlap: float = 0.0, @@ -36,7 +36,7 @@ def train( """ Trains a custom classifier model using the BirdNET-Analyzer framework. Args: - input (str): Path to the training data directory. + audio_input (str): Path to the training data directory. test_data (str, optional): Path to the test data directory. Defaults to None. If not specified, a validation split will be used. output (str, optional): Path to save the trained model. Defaults to "checkpoints/custom/Custom_Classifier". crop_mode (Literal["center", "first", "segments", "smart"], optional): Mode for cropping audio samples. Defaults to "center". 
@@ -68,14 +68,14 @@ def train( Returns: None """ - from birdnet_analyzer.train.utils import train_model import birdnet_analyzer.config as cfg + from birdnet_analyzer.train.utils import train_model from birdnet_analyzer.utils import ensure_model_exists ensure_model_exists() # Config - cfg.TRAIN_DATA_PATH = input + cfg.TRAIN_DATA_PATH = audio_input cfg.TEST_DATA_PATH = test_data cfg.SAMPLE_CROP_MODE = crop_mode cfg.SIG_OVERLAP = overlap diff --git a/birdnet_analyzer/train/utils.py b/birdnet_analyzer/train/utils.py index 8865a21a..9083bbfa 100644 --- a/birdnet_analyzer/train/utils.py +++ b/birdnet_analyzer/train/utils.py @@ -11,10 +11,8 @@ import numpy as np import tqdm -import birdnet_analyzer.audio as audio import birdnet_analyzer.config as cfg -import birdnet_analyzer.model as model -import birdnet_analyzer.utils as utils +from birdnet_analyzer import audio, model, utils def save_sample_counts(labels, y_train): @@ -102,6 +100,7 @@ def _load_audio_file(f, label_vector, config): return x_train, y_train + def _load_training_data(cache_mode=None, cache_file="", progress_callback=None): """Loads the data for training. 
@@ -126,15 +125,15 @@ def _load_training_data(cache_mode=None, cache_file="", progress_callback=None): utils.load_from_cache(cache_file) ) return x_train, y_train, x_test, y_test, labels - else: - print(f"\t...cache file not found: {cache_file}", flush=True) + + print(f"\t...cache file not found: {cache_file}", flush=True) # Print train and test data path as confirmation print(f"\t...train data path: {cfg.TRAIN_DATA_PATH}", flush=True) print(f"\t...test data path: {cfg.TEST_DATA_PATH}", flush=True) # Get list of subfolders as labels - train_folders = list(sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH))) + train_folders = sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH)) # Read all individual labels from the folder names labels = [] @@ -146,7 +145,7 @@ def _load_training_data(cache_mode=None, cache_file="", progress_callback=None): labels.append(label) # Sort labels - labels = list(sorted(labels)) + labels = sorted(labels) # Get valid labels valid_labels = [ @@ -191,7 +190,7 @@ def _load_training_data(cache_mode=None, cache_file="", progress_callback=None): def load_data(data_path, allowed_folders): x = [] y = [] - folders = list(sorted(utils.list_subdirectories(data_path))) + folders = sorted(utils.list_subdirectories(data_path)) for folder in folders: if folder not in allowed_folders: @@ -254,7 +253,7 @@ def load_data(data_path, allowed_folders): x_train, y_train = load_data(cfg.TRAIN_DATA_PATH, train_folders) if cfg.TEST_DATA_PATH and cfg.TEST_DATA_PATH != cfg.TRAIN_DATA_PATH: - test_folders = list(sorted(utils.list_subdirectories(cfg.TEST_DATA_PATH))) + test_folders = sorted(utils.list_subdirectories(cfg.TEST_DATA_PATH)) allowed_test_folders = [ folder for folder in test_folders if folder in train_folders and not folder.startswith("-") ] @@ -279,14 +278,14 @@ def load_data(data_path, allowed_folders): def normalize_embeddings(embeddings): """ Normalize embeddings to improve training stability and performance. 
- + This applies L2 normalization to each embedding vector, which can help - with convergence and model performance, especially when training on + with convergence and model performance, especially when training on embeddings from different sources or domains. - + Args: embeddings: numpy array of embedding vectors - + Returns: Normalized embeddings array """ @@ -295,8 +294,7 @@ def normalize_embeddings(embeddings): # Avoid division by zero norms[norms == 0] = 1.0 # Normalize each embedding vector - normalized = embeddings / norms - return normalized + return embeddings / norms def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, autotune_directory="autotune"): @@ -314,7 +312,9 @@ def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, # Load training data print("Loading training data...", flush=True) - x_train, y_train, x_test, y_test, labels = _load_training_data(cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE, on_data_load_end) + x_train, y_train, x_test, y_test, labels = _load_training_data( + cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE, on_data_load_end + ) print(f"...Done. 
Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True) if len(x_test) > 0: print(f"...Loaded {x_test.shape[0]} test samples.", flush=True) @@ -447,18 +447,18 @@ def run_trial(self, trial, *args, **kwargs): ), train_with_focal_loss=hp.Boolean("focal_loss", default=cfg.TRAIN_WITH_FOCAL_LOSS), focal_loss_gamma=hp.Choice( - "focal_loss_gamma", - [0.5, 1.0, 2.0, 3.0, 4.0], + "focal_loss_gamma", + [0.5, 1.0, 2.0, 3.0, 4.0], default=cfg.FOCAL_LOSS_GAMMA, parent_name="focal_loss", - parent_values=[True] + parent_values=[True], ), focal_loss_alpha=hp.Choice( - "focal_loss_alpha", - [0.1, 0.25, 0.5, 0.75, 0.9], + "focal_loss_alpha", + [0.1, 0.25, 0.5, 0.75, 0.9], default=cfg.FOCAL_LOSS_ALPHA, parent_name="focal_loss", - parent_values=[True] + parent_values=[True], ), ) @@ -482,7 +482,7 @@ def run_trial(self, trial, *args, **kwargs): # Return the negative AUPRC for minimization (keras-tuner minimizes by default) return [-h for h in histories] - + # Create the tuner instance tuner = BirdNetTuner( x_train=x_train, @@ -589,132 +589,162 @@ def run_trial(self, trial, *args, **kwargs): if len(x_test) > 0: print("\nEvaluating model on test data...", flush=True) metrics = evaluate_model(classifier, x_test, y_test, labels) - + # Save evaluation results to file if metrics: import csv + eval_file_path = cfg.CUSTOM_CLASSIFIER + "_evaluation.csv" - with open(eval_file_path, 'w', newline='') as f: + with open(eval_file_path, "w", newline="") as f: writer = csv.writer(f) - + # Define all the metrics as columns, including both default and optimized threshold metrics - header = ['Class', - 'Precision (0.5)', 'Recall (0.5)', 'F1 Score (0.5)', - 'Precision (opt)', 'Recall (opt)', 'F1 Score (opt)', - 'AUPRC', 'AUROC', 'Optimal Threshold', - 'True Positives', 'False Positives', 'True Negatives', 'False Negatives', - 'Samples', 'Percentage (%)'] + header = [ + "Class", + "Precision (0.5)", + "Recall (0.5)", + "F1 Score (0.5)", + "Precision (opt)", + "Recall (opt)", 
+ "F1 Score (opt)", + "AUPRC", + "AUROC", + "Optimal Threshold", + "True Positives", + "False Positives", + "True Negatives", + "False Negatives", + "Samples", + "Percentage (%)", + ] writer.writerow(header) - + # Write macro-averaged metrics (overall scores) first - writer.writerow([ - 'OVERALL (Macro-avg)', - f"{metrics['macro_precision_default']:.4f}", - f"{metrics['macro_recall_default']:.4f}", - f"{metrics['macro_f1_default']:.4f}", - f"{metrics['macro_precision_opt']:.4f}", - f"{metrics['macro_recall_opt']:.4f}", - f"{metrics['macro_f1_opt']:.4f}", - f"{metrics['macro_auprc']:.4f}", - f"{metrics['macro_auroc']:.4f}", - '', '', '', '', '', '', '' # Empty cells for Threshold, TP, FP, TN, FN, Samples, Percentage - ]) - + writer.writerow( + [ + "OVERALL (Macro-avg)", + f"{metrics['macro_precision_default']:.4f}", + f"{metrics['macro_recall_default']:.4f}", + f"{metrics['macro_f1_default']:.4f}", + f"{metrics['macro_precision_opt']:.4f}", + f"{metrics['macro_recall_opt']:.4f}", + f"{metrics['macro_f1_opt']:.4f}", + f"{metrics['macro_auprc']:.4f}", + f"{metrics['macro_auroc']:.4f}", + "", + "", + "", + "", + "", + "", + "", # Empty cells for Threshold, TP, FP, TN, FN, Samples, Percentage + ] + ) + # Write per-class metrics (one row per species) - for class_name, class_metrics in metrics['class_metrics'].items(): - distribution = metrics['class_distribution'].get(class_name, {'count': 0, 'percentage': 0.0}) - writer.writerow([ - class_name, - f"{class_metrics['precision_default']:.4f}", - f"{class_metrics['recall_default']:.4f}", - f"{class_metrics['f1_default']:.4f}", - f"{class_metrics['precision_opt']:.4f}", - f"{class_metrics['recall_opt']:.4f}", - f"{class_metrics['f1_opt']:.4f}", - f"{class_metrics['auprc']:.4f}", - f"{class_metrics['auroc']:.4f}", - f"{class_metrics['threshold']:.2f}", - class_metrics['tp'], - class_metrics['fp'], - class_metrics['tn'], - class_metrics['fn'], - distribution['count'], - f"{distribution['percentage']:.2f}" - ]) - + for 
class_name, class_metrics in metrics["class_metrics"].items(): + distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0}) + writer.writerow( + [ + class_name, + f"{class_metrics['precision_default']:.4f}", + f"{class_metrics['recall_default']:.4f}", + f"{class_metrics['f1_default']:.4f}", + f"{class_metrics['precision_opt']:.4f}", + f"{class_metrics['recall_opt']:.4f}", + f"{class_metrics['f1_opt']:.4f}", + f"{class_metrics['auprc']:.4f}", + f"{class_metrics['auroc']:.4f}", + f"{class_metrics['threshold']:.2f}", + class_metrics["tp"], + class_metrics["fp"], + class_metrics["tn"], + class_metrics["fn"], + distribution["count"], + f"{distribution['percentage']:.2f}", + ] + ) + print(f"Evaluation results saved to {eval_file_path}", flush=True) else: print("\nNo separate test data provided for evaluation. Using validation metrics.", flush=True) - print(f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}, Best Loss: {best_val_loss} (epoch {best_epoch+1}/{len(history.epoch)})", flush=True) + print( + f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}, Best Loss: {best_val_loss} (epoch {best_epoch + 1}/{len(history.epoch)})", + flush=True, + ) return history, metrics + def find_optimal_threshold(y_true, y_pred_prob): """ Find the optimal classification threshold using the F1 score. - + For imbalanced datasets, the default threshold of 0.5 may not be optimal. This function finds the threshold that maximizes the F1 score for each class. 
- + Args: y_true: Ground truth labels y_pred_prob: Predicted probabilities - + Returns: The optimal threshold value """ from sklearn.metrics import f1_score - + # Try different thresholds and find the one that gives the best F1 score best_threshold = 0.5 best_f1 = 0.0 - + for threshold in np.arange(0.1, 0.9, 0.05): y_pred = (y_pred_prob >= threshold).astype(int) f1 = f1_score(y_true, y_pred) - + if f1 > best_f1: best_f1 = f1 best_threshold = threshold - + return best_threshold def evaluate_model(classifier, x_test, y_test, labels, threshold=None): """ Evaluates the trained model on test data and prints detailed metrics. - + Args: classifier: The trained model x_test: Test features (embeddings) y_test: Test labels labels: List of label names threshold: Classification threshold (if None, will find optimal threshold for each class) - + Returns: Dictionary with evaluation metrics """ from sklearn.metrics import ( - precision_score, recall_score, f1_score, - confusion_matrix, classification_report, - average_precision_score, roc_auc_score + average_precision_score, + confusion_matrix, + f1_score, + precision_score, + recall_score, + roc_auc_score, ) - + # Skip evaluation if test set is empty if len(x_test) == 0: print("No test data available for evaluation.") return {} - + # Make predictions y_pred_prob = classifier.predict(x_test) - + # Calculate metrics for each class metrics = {} - + print("\nModel Evaluation:") print("=================") - + # Calculate metrics for each class precisions_default = [] recalls_default = [] @@ -726,69 +756,69 @@ def evaluate_model(classifier, x_test, y_test, labels, threshold=None): aurocs = [] class_metrics = {} optimal_thresholds = {} - + # Print the metric calculation method that's being used print("\nNote: The AUPRC and AUROC metrics calculated during post-training evaluation may differ") print("from training history values due to different calculation methods:") print(" - Training history uses Keras metrics calculated over 
batches") print(" - Evaluation uses scikit-learn metrics calculated over the entire dataset") - + for i in range(y_test.shape[1]): try: # Calculate metrics with default threshold (0.5) y_pred_default = (y_pred_prob[:, i] >= 0.5).astype(int) - + class_precision_default = precision_score(y_test[:, i], y_pred_default) class_recall_default = recall_score(y_test[:, i], y_pred_default) class_f1_default = f1_score(y_test[:, i], y_pred_default) - + precisions_default.append(class_precision_default) recalls_default.append(class_recall_default) f1s_default.append(class_f1_default) - + # Find optimal threshold for this class if needed if threshold is None: class_threshold = find_optimal_threshold(y_test[:, i], y_pred_prob[:, i]) optimal_thresholds[labels[i]] = class_threshold else: class_threshold = threshold - + # Calculate metrics with optimized threshold y_pred_opt = (y_pred_prob[:, i] >= class_threshold).astype(int) - + class_precision_opt = precision_score(y_test[:, i], y_pred_opt) class_recall_opt = recall_score(y_test[:, i], y_pred_opt) class_f1_opt = f1_score(y_test[:, i], y_pred_opt) class_auprc = average_precision_score(y_test[:, i], y_pred_prob[:, i]) class_auroc = roc_auc_score(y_test[:, i], y_pred_prob[:, i]) - + precisions_opt.append(class_precision_opt) recalls_opt.append(class_recall_opt) f1s_opt.append(class_f1_opt) auprcs.append(class_auprc) aurocs.append(class_auroc) - + # Confusion matrix with optimized threshold tn, fp, fn, tp = confusion_matrix(y_test[:, i], y_pred_opt).ravel() - + class_metrics[labels[i]] = { - 'precision_default': class_precision_default, - 'recall_default': class_recall_default, - 'f1_default': class_f1_default, - 'precision_opt': class_precision_opt, - 'recall_opt': class_recall_opt, - 'f1_opt': class_f1_opt, - 'auprc': class_auprc, - 'auroc': class_auroc, - 'tp': tp, - 'fp': fp, - 'tn': tn, - 'fn': fn, - 'threshold': class_threshold + "precision_default": class_precision_default, + "recall_default": class_recall_default, + 
"f1_default": class_f1_default, + "precision_opt": class_precision_opt, + "recall_opt": class_recall_opt, + "f1_opt": class_f1_opt, + "auprc": class_auprc, + "auroc": class_auroc, + "tp": tp, + "fp": fp, + "tn": tn, + "fn": fn, + "threshold": class_threshold, } - + print(f"\nClass: {labels[i]}") - print(f" Default threshold (0.5):") + print(" Default threshold (0.5):") print(f" Precision: {class_precision_default:.4f}") print(f" Recall: {class_recall_default:.4f}") print(f" F1 Score: {class_f1_default:.4f}") @@ -798,50 +828,50 @@ def evaluate_model(classifier, x_test, y_test, labels, threshold=None): print(f" F1 Score: {class_f1_opt:.4f}") print(f" AUPRC: {class_auprc:.4f}") print(f" AUROC: {class_auroc:.4f}") - print(f" Confusion matrix (optimized threshold):") + print(" Confusion matrix (optimized threshold):") print(f" True Positives: {tp}") print(f" False Positives: {fp}") print(f" True Negatives: {tn}") print(f" False Negatives: {fn}") - + except Exception as e: print(f"Error calculating metrics for class {labels[i]}: {e}") - + # Calculate macro-averaged metrics for both default and optimized thresholds - metrics['macro_precision_default'] = np.mean(precisions_default) - metrics['macro_recall_default'] = np.mean(recalls_default) - metrics['macro_f1_default'] = np.mean(f1s_default) - metrics['macro_precision_opt'] = np.mean(precisions_opt) - metrics['macro_recall_opt'] = np.mean(recalls_opt) - metrics['macro_f1_opt'] = np.mean(f1s_opt) - metrics['macro_auprc'] = np.mean(auprcs) - metrics['macro_auroc'] = np.mean(aurocs) - metrics['class_metrics'] = class_metrics - metrics['optimal_thresholds'] = optimal_thresholds - + metrics["macro_precision_default"] = np.mean(precisions_default) + metrics["macro_recall_default"] = np.mean(recalls_default) + metrics["macro_f1_default"] = np.mean(f1s_default) + metrics["macro_precision_opt"] = np.mean(precisions_opt) + metrics["macro_recall_opt"] = np.mean(recalls_opt) + metrics["macro_f1_opt"] = np.mean(f1s_opt) + 
metrics["macro_auprc"] = np.mean(auprcs) + metrics["macro_auroc"] = np.mean(aurocs) + metrics["class_metrics"] = class_metrics + metrics["optimal_thresholds"] = optimal_thresholds + print("\nMacro-averaged metrics:") - print(f" Default threshold (0.5):") + print(" Default threshold (0.5):") print(f" Precision: {metrics['macro_precision_default']:.4f}") print(f" Recall: {metrics['macro_recall_default']:.4f}") print(f" F1 Score: {metrics['macro_f1_default']:.4f}") - print(f" Optimized thresholds:") + print(" Optimized thresholds:") print(f" Precision: {metrics['macro_precision_opt']:.4f}") print(f" Recall: {metrics['macro_recall_opt']:.4f}") print(f" F1 Score: {metrics['macro_f1_opt']:.4f}") print(f" AUPRC: {metrics['macro_auprc']:.4f}") print(f" AUROC: {metrics['macro_auroc']:.4f}") - + # Calculate class distribution in test set class_counts = y_test.sum(axis=0) total_samples = len(y_test) class_distribution = {} - + print("\nClass distribution in test set:") for i, count in enumerate(class_counts): percentage = count / total_samples * 100 - class_distribution[labels[i]] = {'count': int(count), 'percentage': percentage} + class_distribution[labels[i]] = {"count": int(count), "percentage": percentage} print(f" {labels[i]}: {int(count)} samples ({percentage:.2f}%)") - - metrics['class_distribution'] = class_distribution - + + metrics["class_distribution"] = class_distribution + return metrics diff --git a/birdnet_analyzer/translate.py b/birdnet_analyzer/translate.py index d7b5e74d..1ff075db 100644 --- a/birdnet_analyzer/translate.py +++ b/birdnet_analyzer/translate.py @@ -4,14 +4,43 @@ Uses the requests to the eBird-API. 
""" + import json import os import urllib.request import birdnet_analyzer.config as cfg -import birdnet_analyzer.utils as utils - -LOCALES = ['af', 'ar', 'cs', 'da', 'de', 'en_uk', 'es', 'fi', 'fr', 'hu', 'it', 'ja', 'ko', 'nl', 'no', 'pl', 'pt_BR', 'pt_PT', 'ro', 'ru', 'sk', 'sl', 'sv', 'th', 'tr', 'uk', 'zh'] +from birdnet_analyzer import utils + +LOCALES = [ + "af", + "ar", + "cs", + "da", + "de", + "en_uk", + "es", + "fi", + "fr", + "hu", + "it", + "ja", + "ko", + "nl", + "no", + "pl", + "pt_BR", + "pt_PT", + "ro", + "ru", + "sk", + "sl", + "sv", + "th", + "tr", + "uk", + "zh", +] """ Locales for 26 common languages (according to GitHub Copilot) """ API_TOKEN = "yourAPIToken" @@ -45,7 +74,7 @@ def translate(locale: str): Args: locale: Two character string of a language. - + Returns: The translated list of labels. """ diff --git a/birdnet_analyzer/utils.py b/birdnet_analyzer/utils.py index 8ebe2436..d8923235 100644 --- a/birdnet_analyzer/utils.py +++ b/birdnet_analyzer/utils.py @@ -1,8 +1,8 @@ """Module containing common function.""" -import sys import itertools import os +import sys import traceback from pathlib import Path @@ -54,7 +54,7 @@ def spectrogram_from_file(path, fig_num=None, fig_size=None, offset=0, duration= Returns: matplotlib.figure.Figure: The generated spectrogram figure. """ - import birdnet_analyzer.audio as audio + from birdnet_analyzer import audio # s, sr = librosa.load(path, offset=offset, duration=duration) s, sr = audio.open_audio_file(path, offset=offset, duration=duration, fmin=fmin, fmax=fmax, speed=speed) @@ -103,7 +103,7 @@ def spectrogram_from_audio(s, sr, fig_num=None, fig_size=None): return librosa.display.specshow(S_db, ax=ax, n_fft=1024, hop_length=512).figure -def collect_audio_files(path: str, max_files: int = None): +def collect_audio_files(path: str, max_files: int | None = None): """Collects all audio files in the given directory. 
Args: @@ -140,9 +140,11 @@ def collect_all_files(path: str, filetypes: list[str], pattern: str = ""): files = [] for root, _, flist in os.walk(path): - for f in flist: - if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in filetypes and (pattern in f or not pattern): - files.append(os.path.join(root, f)) + files.extend( + os.path.join(root, f) + for f in flist + if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in filetypes and (pattern in f or not pattern) + ) return sorted(files) @@ -227,22 +229,25 @@ def load_from_cache(path): data = np.load(path, allow_pickle=True) # Check if cache contains needed preprocessing parameters - if "fmin" in data and "fmax" in data and "audio_speed" in data and "crop_mode" in data and "overlap" in data: - # Check if preprocessing parameters match current settings - if ( + if ( + "fmin" in data + and "fmax" in data + and "audio_speed" in data + and "crop_mode" in data + and "overlap" in data + and ( # Check if preprocessing parameters match current settings data["fmin"] != cfg.BANDPASS_FMIN or data["fmax"] != cfg.BANDPASS_FMAX or data["audio_speed"] != cfg.AUDIO_SPEED or data["crop_mode"] != cfg.SAMPLE_CROP_MODE or data["overlap"] != cfg.SIG_OVERLAP - ): - print("\t...WARNING: Cache preprocessing parameters don't match current settings!", flush=True) - print(f"\t Cache: fmin={data['fmin']}, fmax={data['fmax']}, speed={data['audio_speed']}", flush=True) - print(f"\t Cache: crop_mode={data['crop_mode']}, overlap={data['overlap']}", flush=True) - print( - f"\t Current: fmin={cfg.BANDPASS_FMIN}, fmax={cfg.BANDPASS_FMAX}, speed={cfg.AUDIO_SPEED}", flush=True - ) - print(f"\t Current: crop_mode={cfg.SAMPLE_CROP_MODE}, overlap={cfg.SIG_OVERLAP}", flush=True) + ) + ): + print("\t...WARNING: Cache preprocessing parameters don't match current settings!", flush=True) + print(f"\t Cache: fmin={data['fmin']}, fmax={data['fmax']}, speed={data['audio_speed']}", flush=True) + print(f"\t Cache: crop_mode={data['crop_mode']}, 
overlap={data['overlap']}", flush=True) + print(f"\t Current: fmin={cfg.BANDPASS_FMIN}, fmax={cfg.BANDPASS_FMAX}, speed={cfg.AUDIO_SPEED}", flush=True) + print(f"\t Current: crop_mode={cfg.SAMPLE_CROP_MODE}, overlap={cfg.SIG_OVERLAP}", flush=True) # Extract and return data x_train = data["x_train"] @@ -400,11 +405,13 @@ def ensure_model_exists(): total_size = int(response.headers.get("content-length", 0)) block_size = 1024 - with tqdm(total=total_size, unit="iB", unit_scale=True, desc="Downloading model") as tqdm_bar: - with open(download_path, "wb") as file: - for data in response.iter_content(block_size): - tqdm_bar.update(len(data)) - file.write(data) + with ( + tqdm(total=total_size, unit="iB", unit_scale=True, desc="Downloading model") as tqdm_bar, + open(download_path, "wb") as file, + ): + for data in response.iter_content(block_size): + tqdm_bar.update(len(data)) + file.write(data) if response.status_code != 200 or (total_size not in (0, tqdm_bar.n)): raise ValueError(f"Failed to download the file. 
Status code: {response.status_code}") diff --git a/pyproject.toml b/pyproject.toml index f3bd4185..f5590c66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,3 +82,59 @@ birdnet_analyzer = [ "labels/**/*", "gui/assets/**/*", ] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["birdnet_analyzer"] + +[tool.ruff] +exclude = ["conf.py"] +line-length = 165 + +[tool.ruff.lint] +select = [ + "F", + "B", + "A", + "C4", + "T10", + "EXE", + "PIE", + "PYI", + "PT", + "Q", + "RSE", + "RET", + "SIM", + "TID", + "TD", + "TC", + #"PTH", + "FLY", + "I", + "NPY", + "PD", + #"N", + "PERF", + "E", + "W", + #"D", + "PL", + "UP", + "FURB", + "RUF", +] +ignore = [ + "B008", + "TD003", + "TD002", + "PD901", + "SIM108", + "E722", + "PLR2004", + "PLR0913", + "PLR0915", + "PLR0912", + "PLC0206", + "RUF015", +] diff --git a/tests/assessment/__init__.py b/tests/analyze/__init__.py similarity index 100% rename from tests/assessment/__init__.py rename to tests/analyze/__init__.py diff --git a/tests/analyze/test_analyze.py b/tests/analyze/test_analyze.py new file mode 100644 index 00000000..4be1e844 --- /dev/null +++ b/tests/analyze/test_analyze.py @@ -0,0 +1,258 @@ +import os +import shutil +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +import birdnet_analyzer.config as cfg +from birdnet_analyzer.analyze.core import analyze + + +@pytest.fixture +def setup_test_environment(): + """Create a temporary test environment with audio files.""" + # Create temp directory + test_dir = tempfile.mkdtemp() + input_dir = os.path.join(test_dir, "input") + output_dir = os.path.join(test_dir, "output") + + # Create directories + os.makedirs(input_dir) + os.makedirs(output_dir) + + # Create dummy audio files + test_file1 = os.path.join(input_dir, "test1.wav") + test_file2 = os.path.join(input_dir, "test2.wav") + with open(test_file1, "wb") as f: + f.write(b"dummy audio data") + with open(test_file2, "wb") as f: + f.write(b"more dummy audio data") + + # 
Store original config values + original_config = { + attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith("_") and not callable(getattr(cfg, attr)) + } + + yield { + "test_dir": test_dir, + "input_dir": input_dir, + "output_dir": output_dir, + "test_file1": test_file1, + "test_file2": test_file2, + } + + # Clean up + shutil.rmtree(test_dir) + + # Restore original config + for attr, value in original_config.items(): + setattr(cfg, attr, value) + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("birdnet_analyzer.analyze.utils.analyze_file") +@patch("birdnet_analyzer.analyze.utils.save_analysis_params") +def test_analyze_single_file( + mock_save_params, mock_analyze_file, mock_set_params, mock_ensure_model, setup_test_environment +): + """Test analyzing a single audio file.""" + env = setup_test_environment + + # Configure mocks + mock_set_params.return_value = [(env["test_file1"], {"param1": "value1"})] + mock_analyze_file.return_value = f"{env['test_file1']}_results.txt" + + # Set config values + cfg.FILE_LIST = [env["test_file1"]] + cfg.LABELS = ["Species1", "Species2"] + cfg.SPECIES_LIST = None + cfg.CPU_THREADS = 1 + cfg.COMBINE_RESULTS = False + cfg.OUTPUT_PATH = env["output_dir"] + + # Call function under test + analyze(env["test_file1"], env["output_dir"], min_conf=0.5) + + # Verify behavior + mock_ensure_model.assert_called_once() + mock_set_params.assert_called_once() + mock_analyze_file.assert_called_once_with((env["test_file1"], {"param1": "value1"})) + mock_save_params.assert_called_once() + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("multiprocessing.Pool") +@patch("birdnet_analyzer.analyze.utils.save_analysis_params") +def test_analyze_directory_multiprocess( + mock_save_params, mock_pool, mock_set_params, mock_ensure_model, setup_test_environment +): + """Test analyzing multiple files with 
multiprocessing.""" + env = setup_test_environment + + # Configure mocks + file_params = [(env["test_file1"], {"param1": "value1"}), (env["test_file2"], {"param1": "value1"})] + mock_set_params.return_value = file_params + + pool_instance = MagicMock() + mock_pool.return_value.__enter__.return_value = pool_instance + + async_result = MagicMock() + async_result.get.return_value = [f"{env['test_file1']}_results.txt", f"{env['test_file2']}_results.txt"] + pool_instance.map_async.return_value = async_result + + # Set config values + cfg.FILE_LIST = [env["test_file1"], env["test_file2"]] + cfg.LABELS = ["Species1", "Species2"] + cfg.SPECIES_LIST = None + cfg.CPU_THREADS = 2 + cfg.COMBINE_RESULTS = False + cfg.OUTPUT_PATH = env["output_dir"] + + # Call function under test + analyze(env["input_dir"], env["output_dir"], threads=2) + + # Verify behavior + mock_ensure_model.assert_called_once() + mock_set_params.assert_called_once() + mock_pool.assert_called_once_with(2) + pool_instance.map_async.assert_called_once() + mock_save_params.assert_called_once() + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("birdnet_analyzer.analyze.utils.analyze_file") +@patch("birdnet_analyzer.analyze.utils.save_analysis_params") +@patch("birdnet_analyzer.analyze.utils.combine_results") +def test_analyze_with_combined_results( + mock_combine_results, + mock_save_params, + mock_analyze_file, + mock_set_params, + mock_ensure_model, + setup_test_environment, +): + """Test analyzing files with combined results.""" + env = setup_test_environment + + # Configure mocks + result_file = f"{env['test_file1']}_results.txt" + mock_set_params.return_value = [(env["test_file1"], {"param1": "value1"})] + mock_analyze_file.return_value = result_file + + # Set config values + cfg.FILE_LIST = [env["test_file1"]] + cfg.LABELS = ["Species1", "Species2"] + cfg.SPECIES_LIST = None + cfg.CPU_THREADS = 1 + cfg.COMBINE_RESULTS = True + 
cfg.OUTPUT_PATH = env["output_dir"] + + # Call function under test + analyze(env["test_file1"], env["output_dir"], combine_results=True) + + # Verify behavior + mock_ensure_model.assert_called_once() + mock_set_params.assert_called_once() + mock_analyze_file.assert_called_once_with((env["test_file1"], {"param1": "value1"})) + mock_combine_results.assert_called_once_with([result_file]) + mock_save_params.assert_called_once() + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("birdnet_analyzer.analyze.utils.analyze_file") +def test_analyze_with_location_filtering(mock_analyze_file, mock_set_params, mock_ensure_model, setup_test_environment): + """Test analyzing with location-based filtering.""" + env = setup_test_environment + + # Configure mocks + mock_set_params.return_value = [(env["test_file1"], {"param1": "value1"})] + mock_analyze_file.return_value = f"{env['test_file1']}_results.txt" + + # Call function under test + analyze(env["test_file1"], env["output_dir"], lat=42.5, lon=-76.45, week=20) + + # Verify parameter passing + mock_set_params.assert_called_once() + _, kwargs = mock_set_params.call_args + assert kwargs["lat"] == 42.5 + assert kwargs["lon"] == -76.45 + assert kwargs["week"] == 20 + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("birdnet_analyzer.analyze.utils.analyze_file") +def test_analyze_with_custom_classifier(mock_analyze_file, mock_set_params, mock_ensure_model, setup_test_environment): + """Test analyzing with a custom classifier.""" + env = setup_test_environment + + # Create dummy classifier file + custom_classifier = os.path.join(env["test_dir"], "custom_model.tflite") + with open(custom_classifier, "wb") as f: + f.write(b"dummy model data") + + # Configure mocks + mock_set_params.return_value = [(env["test_file1"], {"param1": "value1"})] + mock_analyze_file.return_value = 
f"{env['test_file1']}_results.txt" + + # Call function under test + analyze(env["test_file1"], env["output_dir"], classifier=custom_classifier) + + # Verify parameter passing + mock_set_params.assert_called_once() + _, kwargs = mock_set_params.call_args + assert kwargs["custom_classifier"] == custom_classifier + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("birdnet_analyzer.analyze.utils.analyze_file") +def test_analyze_with_multiple_result_types( + mock_analyze_file, mock_set_params, mock_ensure_model, setup_test_environment +): + """Test analyzing with multiple output result types.""" + env = setup_test_environment + + # Configure mocks + mock_set_params.return_value = [(env["test_file1"], {"param1": "value1"})] + mock_analyze_file.return_value = f"{env['test_file1']}_results.txt" + + # Call function under test + analyze(env["test_file1"], env["output_dir"], rtype=["table", "csv", "audacity"]) + + # Verify parameter passing + mock_set_params.assert_called_once() + _, kwargs = mock_set_params.call_args + assert kwargs["rtype"] == ["table", "csv", "audacity"] + + +@patch("birdnet_analyzer.utils.ensure_model_exists") +@patch("birdnet_analyzer.analyze.core._set_params") +@patch("birdnet_analyzer.analyze.utils.analyze_file") +def test_analyze_with_custom_species_list( + mock_analyze_file, mock_set_params, mock_ensure_model, setup_test_environment +): + """Test analyzing with a custom species list.""" + env = setup_test_environment + + # Create dummy species list file + species_list = os.path.join(env["test_dir"], "species.txt") + with open(species_list, "w") as f: + f.write("Species1\nSpecies2\n") + + # Configure mocks + mock_set_params.return_value = [(env["test_file1"], {"param1": "value1"})] + mock_analyze_file.return_value = f"{env['test_file1']}_results.txt" + + # Call function under test + analyze(env["test_file1"], env["output_dir"], slist=species_list) + + # Verify parameter passing + 
mock_set_params.assert_called_once() + _, kwargs = mock_set_params.call_args + assert kwargs["slist"] == species_list diff --git a/tests/preprocessing/__init__.py b/tests/evaluation/__init__.py similarity index 100% rename from tests/preprocessing/__init__.py rename to tests/evaluation/__init__.py diff --git a/tests/evaluation/assessment/__init__.py b/tests/evaluation/assessment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/assessment/test_metrics.py b/tests/evaluation/assessment/test_metrics.py similarity index 92% rename from tests/assessment/test_metrics.py rename to tests/evaluation/assessment/test_metrics.py index bf8a5d55..05fb88b7 100644 --- a/tests/assessment/test_metrics.py +++ b/tests/evaluation/assessment/test_metrics.py @@ -1,23 +1,25 @@ -import pytest -import numpy as np +import re -from birdnet_analyzer.evaluation.assessment.metrics import ( - calculate_accuracy, - calculate_recall, - calculate_precision, - calculate_f1_score, - calculate_average_precision, - calculate_auroc, -) +import numpy as np +import pytest from sklearn.metrics import ( accuracy_score, + average_precision_score, + f1_score, precision_score, recall_score, - f1_score, - average_precision_score, roc_auc_score, ) +from birdnet_analyzer.evaluation.assessment.metrics import ( + calculate_accuracy, + calculate_auroc, + calculate_average_precision, + calculate_f1_score, + calculate_precision, + calculate_recall, +) + class TestCalculateAccuracy: def test_binary_classification_perfect(self): @@ -104,7 +106,7 @@ def test_multilabel_classification_imperfect(self): def test_incorrect_shapes(self): predictions = np.array([0.9, 0.2, 0.8]) labels = np.array([1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must have the same shape.")): calculate_accuracy( predictions, labels, @@ -116,7 +118,7 @@ def test_incorrect_shapes(self): def test_invalid_threshold(self): predictions = np.array([0.9, 0.2, 0.8, 
0.1]) labels = np.array([1, 0, 1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Invalid threshold: 1.5. Must be between 0 and 1.")): calculate_accuracy( predictions, labels, @@ -140,7 +142,7 @@ def test_non_array_inputs(self): def test_empty_arrays(self): predictions = np.array([]) labels = np.array([]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must not be empty.")): calculate_accuracy( predictions, labels, @@ -249,15 +251,13 @@ def test_multilabel_classification_imperfect(self): threshold=0.5, averaging_method="macro", ) - expected_recall = recall_score( - labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0 - ) + expected_recall = recall_score(labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0) assert np.isclose(result, expected_recall) def test_incorrect_shapes(self): predictions = np.array([0.9, 0.2, 0.8]) labels = np.array([1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must have the same shape.")): calculate_recall( predictions, labels, @@ -268,7 +268,7 @@ def test_incorrect_shapes(self): def test_invalid_threshold(self): predictions = np.array([0.9, 0.2, 0.8, 0.1]) labels = np.array([1, 0, 1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Invalid threshold: -0.1. 
Must be between 0 and 1.")): calculate_recall( predictions, labels, @@ -290,7 +290,7 @@ def test_non_array_inputs(self): def test_empty_arrays(self): predictions = np.array([]) labels = np.array([]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must not be empty.")): calculate_recall( predictions, labels, @@ -308,9 +308,7 @@ def test_multilabel_classification_macro_average(self): threshold=0.5, averaging_method="macro", ) - expected_recall = recall_score( - labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0 - ) + expected_recall = recall_score(labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0) assert np.isclose(result, expected_recall) def test_binary_classification_no_positive_predictions(self): @@ -322,9 +320,7 @@ def test_binary_classification_no_positive_predictions(self): task="binary", threshold=0.5, ) - expected_recall = recall_score( - labels, (predictions >= 0.5).astype(int), zero_division=0 - ) + expected_recall = recall_score(labels, (predictions >= 0.5).astype(int), zero_division=0) assert np.isclose(result, expected_recall) def test_binary_classification_no_positive_labels(self): @@ -423,15 +419,13 @@ def test_multilabel_classification_imperfect(self): threshold=0.5, averaging_method="macro", ) - expected_precision = precision_score( - labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0 - ) + expected_precision = precision_score(labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0) assert np.isclose(result, expected_precision, atol=1e-4) def test_incorrect_shapes(self): predictions = np.array([0.9, 0.2, 0.8]) labels = np.array([1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must have the same shape.")): calculate_precision( predictions, labels, @@ -442,7 +436,7 @@ def test_incorrect_shapes(self): def test_invalid_threshold(self): 
predictions = np.array([0.9, 0.2, 0.8, 0.1]) labels = np.array([1, 0, 1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Invalid threshold: -0.1. Must be between 0 and 1.")): calculate_precision( predictions, labels, @@ -464,7 +458,7 @@ def test_non_array_inputs(self): def test_empty_arrays(self): predictions = np.array([]) labels = np.array([]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must not be empty.")): calculate_precision( predictions, labels, @@ -499,9 +493,7 @@ def test_binary_classification_no_positive_predictions(self): task="binary", threshold=0.5, ) - expected_precision = precision_score( - labels, (predictions >= 0.5).astype(int), zero_division=0 - ) + expected_precision = precision_score(labels, (predictions >= 0.5).astype(int), zero_division=0) assert np.isclose(result, expected_precision) def test_binary_classification_no_positive_labels(self): @@ -584,15 +576,13 @@ def test_multilabel_classification_imperfect(self): threshold=0.5, averaging_method="macro", ) - expected_f1 = f1_score( - labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0 - ) + expected_f1 = f1_score(labels, (predictions >= 0.5).astype(int), average="macro", zero_division=0) assert np.isclose(result, expected_f1, atol=1e-4) def test_incorrect_shapes(self): predictions = np.array([0.9, 0.2, 0.8]) labels = np.array([1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must have the same shape.")): calculate_f1_score( predictions, labels, @@ -603,7 +593,7 @@ def test_incorrect_shapes(self): def test_invalid_threshold(self): predictions = np.array([0.9, 0.2, 0.8, 0.1]) labels = np.array([1, 0, 1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Invalid threshold: 1.5. 
Must be between 0 and 1.")): calculate_f1_score( predictions, labels, @@ -625,7 +615,7 @@ def test_non_array_inputs(self): def test_empty_arrays(self): predictions = np.array([]) labels = np.array([]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must not be empty.")): calculate_f1_score( predictions, labels, @@ -672,9 +662,7 @@ def test_binary_classification_no_positive_predictions(self): task="binary", threshold=0.5, ) - expected_f1 = f1_score( - labels, (predictions >= 0.5).astype(int), zero_division=0 - ) + expected_f1 = f1_score(labels, (predictions >= 0.5).astype(int), zero_division=0) assert np.isclose(result, expected_f1) @@ -763,7 +751,7 @@ def test_multilabel_classification_imperfect(self): def test_incorrect_shapes(self): predictions = np.array([0.9, 0.2, 0.8]) labels = np.array([1, 0]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must have the same shape.")): calculate_average_precision( predictions, labels, @@ -784,7 +772,7 @@ def test_non_array_inputs(self): def test_empty_arrays(self): predictions = np.array([]) labels = np.array([]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must not be empty.")): calculate_average_precision( predictions, labels, @@ -925,7 +913,7 @@ def test_multilabel_classification_imperfect(self): def test_incorrect_shapes(self): predictions = np.array([0.9, 0.2]) labels = np.array([1]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must have the same shape.")): calculate_auroc( predictions, labels, @@ -946,7 +934,7 @@ def test_non_array_inputs(self): def test_empty_arrays(self): predictions = np.array([]) labels = np.array([]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Predictions and labels must not be empty.")): calculate_auroc( predictions, 
labels, diff --git a/tests/assessment/test_performance_assessor.py b/tests/evaluation/assessment/test_performance_assessor.py similarity index 67% rename from tests/assessment/test_performance_assessor.py rename to tests/evaluation/assessment/test_performance_assessor.py index f2ca4358..4dce04c0 100644 --- a/tests/assessment/test_performance_assessor.py +++ b/tests/evaluation/assessment/test_performance_assessor.py @@ -1,10 +1,13 @@ -import pytest +import re + +import matplotlib import numpy as np import pandas as pd -import matplotlib - -from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor +import pytest +from birdnet_analyzer.evaluation.assessment.performance_assessor import ( + PerformanceAssessor, +) matplotlib.use("Agg") # Use non-interactive backend for plotting @@ -23,9 +26,7 @@ def test_init_with_valid_inputs(self): classes = ("Class1", "Class2", "Class3") task = "multilabel" metrics_list = ("recall", "precision", "f1") - assessor = PerformanceAssessor( - num_classes, threshold, classes, task, metrics_list - ) + assessor = PerformanceAssessor(num_classes, threshold, classes, task, metrics_list) assert assessor.num_classes == num_classes assert assessor.threshold == threshold assert assessor.classes == classes @@ -36,53 +37,54 @@ def test_init_with_invalid_num_classes(self): """ Test initializing PerformanceAssessor with invalid num_classes (non-positive integer). """ - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("num_classes must be a positive integer.")): PerformanceAssessor(num_classes=0) def test_init_with_invalid_threshold(self): """ Test initializing PerformanceAssessor with invalid threshold (not between 0 and 1). 
""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("threshold must be a float between 0 and 1 (exclusive).")): PerformanceAssessor(num_classes=3, threshold=1.5) def test_init_with_invalid_classes_length(self): """ Test initializing PerformanceAssessor when length of classes does not match num_classes. """ - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Length of classes (3) must match num_classes (2).")): PerformanceAssessor(num_classes=2, classes=("Class1", "Class2", "Class3")) def test_init_with_invalid_classes_type(self): """ Test initializing PerformanceAssessor when classes is not a tuple of strings. """ - with pytest.raises(ValueError): - PerformanceAssessor( - num_classes=2, classes=["Class1", "Class2"] - ) # Should be tuple + with pytest.raises(ValueError, match=re.escape("classes must be a tuple of strings.")): + PerformanceAssessor(num_classes=2, classes=["Class1", "Class2"]) # Should be tuple def test_init_with_invalid_task(self): """ Test initializing PerformanceAssessor with invalid task type. """ - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("task must be 'binary' or 'multilabel'.")): PerformanceAssessor(num_classes=2, task="invalid_task") def test_init_with_invalid_metrics_list(self): """ Test initializing PerformanceAssessor with invalid metrics_list containing unsupported metric. """ - with pytest.raises(ValueError): - PerformanceAssessor( - num_classes=2, metrics_list=("recall", "unsupported_metric") - ) + with pytest.raises( + ValueError, + match=re.escape( + "Invalid metrics in ('recall', 'unsupported_metric'). Valid options are ['accuracy', 'recall', 'precision', 'f1', 'ap', 'auroc']." + ), + ): + PerformanceAssessor(num_classes=2, metrics_list=("recall", "unsupported_metric")) def test_init_with_empty_metrics_list(self): """ Test initializing PerformanceAssessor with empty metrics_list. 
""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("metrics_list cannot be empty.")): PerformanceAssessor(num_classes=2, metrics_list=()) def test_init_with_large_num_classes(self): @@ -117,14 +119,20 @@ class TestPerformanceAssessorCalculateMetrics: Test suite for the PerformanceAssessor calculate_metrics method. """ + def setup_method(self): + """ + Setup method to create a PerformanceAssessor instance for testing. + """ + self.rng = np.random.default_rng(42) + def test_calculate_metrics_with_valid_inputs(self): """ Test calculate_metrics with valid predictions and labels. """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) metrics_df = assessor.calculate_metrics(predictions, labels) assert isinstance(metrics_df, pd.DataFrame) assert not metrics_df.empty @@ -135,11 +143,9 @@ def test_calculate_metrics_with_per_class_metrics(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) - metrics_df = assessor.calculate_metrics( - predictions, labels, per_class_metrics=True - ) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) + metrics_df = assessor.calculate_metrics(predictions, labels, per_class_metrics=True) assert isinstance(metrics_df, pd.DataFrame) assert not metrics_df.empty assert metrics_df.shape[1] == num_classes # Columns should be per class @@ -150,9 +156,9 @@ def test_calculate_metrics_with_invalid_predictions_shape(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100) # Invalid shape - labels = np.random.randint(0, 2, 
size=(100, num_classes)) - with pytest.raises(ValueError): + predictions = self.rng.random(100) # Invalid shape + labels = self.rng.integers(0, 2, size=(100, num_classes)) + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.calculate_metrics(predictions, labels) def test_calculate_metrics_with_invalid_labels_shape(self): @@ -161,9 +167,9 @@ def test_calculate_metrics_with_invalid_labels_shape(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100,)) # Invalid shape - with pytest.raises(ValueError): + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100,)) # Invalid shape + with pytest.raises(ValueError, match="predictions and labels must have the same shape."): assessor.calculate_metrics(predictions, labels) def test_calculate_metrics_with_mismatched_predictions_and_labels(self): @@ -172,11 +178,9 @@ def test_calculate_metrics_with_mismatched_predictions_and_labels(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint( - 0, 2, size=(90, num_classes) - ) # Different number of samples - with pytest.raises(ValueError): + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(90, num_classes)) # Different number of samples + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.calculate_metrics(predictions, labels) def test_calculate_metrics_with_invalid_predictions_type(self): @@ -186,7 +190,7 @@ def test_calculate_metrics_with_invalid_predictions_type(self): num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) predictions = [[0.1, 0.2, 0.3]] * 100 # List instead of numpy array - labels = np.random.randint(0, 2, size=(100, 
num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) with pytest.raises(TypeError): assessor.calculate_metrics(predictions, labels) @@ -196,7 +200,7 @@ def test_calculate_metrics_with_invalid_labels_type(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) + predictions = self.rng.random((100, num_classes)) labels = [[0, 1, 0]] * 100 # List instead of numpy array with pytest.raises(TypeError): assessor.calculate_metrics(predictions, labels) @@ -206,10 +210,13 @@ def test_calculate_metrics_with_invalid_metric_in_metrics_list(self): Test calculate_metrics when metrics_list contains an invalid metric. """ num_classes = 3 - with pytest.raises(ValueError): - PerformanceAssessor( - num_classes=num_classes, metrics_list=("invalid_metric",) - ) + with pytest.raises( + ValueError, + match=re.escape( + "Invalid metrics in ('invalid_metric',). Valid options are ['accuracy', 'recall', 'precision', 'f1', 'ap', 'auroc']." 
+ ), + ): + PerformanceAssessor(num_classes=num_classes, metrics_list=("invalid_metric",)) def test_calculate_metrics_with_binary_task(self): """ @@ -217,8 +224,8 @@ def test_calculate_metrics_with_binary_task(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100, 1) - labels = np.random.randint(0, 2, size=(100, 1)) + predictions = self.rng.random((100, 1)) + labels = self.rng.integers(0, 2, size=(100, 1)) metrics_df = assessor.calculate_metrics(predictions, labels) assert isinstance(metrics_df, pd.DataFrame) assert not metrics_df.empty @@ -229,11 +236,9 @@ def test_calculate_metrics_with_no_classes(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes, classes=None) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) - metrics_df = assessor.calculate_metrics( - predictions, labels, per_class_metrics=True - ) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) + metrics_df = assessor.calculate_metrics(predictions, labels, per_class_metrics=True) assert isinstance(metrics_df, pd.DataFrame) assert not metrics_df.empty expected_columns = [f"Class {i}" for i in range(num_classes)] @@ -245,14 +250,20 @@ class TestPerformanceAssessorPlotMetrics: Test suite for the PerformanceAssessor plot_metrics method. """ + def setup_method(self): + """ + Setup method to create a PerformanceAssessor instance for testing. + """ + self.rng = np.random.default_rng(42) + def test_plot_metrics_with_valid_inputs(self): """ Test plot_metrics with valid predictions and labels. 
""" num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50, num_classes) - labels = np.random.randint(0, 2, size=(50, num_classes)) + predictions = self.rng.random((50, num_classes)) + labels = self.rng.integers(0, 2, size=(50, num_classes)) assessor.plot_metrics(predictions, labels) def test_plot_metrics_with_per_class_metrics(self): @@ -261,8 +272,8 @@ def test_plot_metrics_with_per_class_metrics(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50, num_classes) - labels = np.random.randint(0, 2, size=(50, num_classes)) + predictions = self.rng.random((50, num_classes)) + labels = self.rng.integers(0, 2, size=(50, num_classes)) assessor.plot_metrics(predictions, labels, per_class_metrics=True) def test_plot_metrics_with_invalid_predictions_shape(self): @@ -271,9 +282,9 @@ def test_plot_metrics_with_invalid_predictions_shape(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50) # Invalid shape - labels = np.random.randint(0, 2, size=(50, num_classes)) - with pytest.raises(ValueError): + predictions = self.rng.random(50) # Invalid shape + labels = self.rng.integers(0, 2, size=(50, num_classes)) + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_metrics(predictions, labels) def test_plot_metrics_with_invalid_labels_shape(self): @@ -282,9 +293,9 @@ def test_plot_metrics_with_invalid_labels_shape(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50, num_classes) - labels = np.random.randint(0, 2, size=(50,)) # Invalid shape - with pytest.raises(ValueError): + predictions = self.rng.random((50, num_classes)) + labels = self.rng.integers(0, 2, size=(50,)) # Invalid shape + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same 
shape.")): assessor.plot_metrics(predictions, labels) def test_plot_metrics_with_mismatched_predictions_and_labels(self): @@ -293,11 +304,9 @@ def test_plot_metrics_with_mismatched_predictions_and_labels(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50, num_classes) - labels = np.random.randint( - 0, 2, size=(40, num_classes) - ) # Different number of samples - with pytest.raises(ValueError): + predictions = self.rng.random((50, num_classes)) + labels = self.rng.integers(0, 2, size=(40, num_classes)) # Different number of samples + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_metrics(predictions, labels) def test_plot_metrics_with_binary_task(self): @@ -306,8 +315,8 @@ def test_plot_metrics_with_binary_task(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(50, 1) - labels = np.random.randint(0, 2, size=(50, 1)) + predictions = self.rng.random((50, 1)) + labels = self.rng.integers(0, 2, size=(50, 1)) assessor.plot_metrics(predictions, labels) def test_plot_metrics_with_no_classes(self): @@ -316,8 +325,8 @@ def test_plot_metrics_with_no_classes(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes, classes=None) - predictions = np.random.rand(50, num_classes) - labels = np.random.randint(0, 2, size=(50, num_classes)) + predictions = self.rng.random((50, num_classes)) + labels = self.rng.integers(0, 2, size=(50, num_classes)) assessor.plot_metrics(predictions, labels, per_class_metrics=True) def test_plot_metrics_with_invalid_predictions_type(self): @@ -327,7 +336,7 @@ def test_plot_metrics_with_invalid_predictions_type(self): num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) predictions = [[0.1, 0.2, 0.3]] * 50 # List instead of numpy array - labels = np.random.randint(0, 2, size=(50, num_classes)) + 
labels = self.rng.integers(0, 2, size=(50, num_classes)) with pytest.raises(TypeError): assessor.plot_metrics(predictions, labels) @@ -337,7 +346,7 @@ def test_plot_metrics_with_invalid_labels_type(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50, num_classes) + predictions = self.rng.random((50, num_classes)) labels = [[0, 1, 0]] * 50 # List instead of numpy array with pytest.raises(TypeError): assessor.plot_metrics(predictions, labels) @@ -348,8 +357,8 @@ def test_plot_metrics_with_large_number_of_classes(self): """ num_classes = 100 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(50, num_classes) - labels = np.random.randint(0, 2, size=(50, num_classes)) + predictions = self.rng.random((50, num_classes)) + labels = self.rng.integers(0, 2, size=(50, num_classes)) assessor.plot_metrics(predictions, labels, per_class_metrics=True) @@ -358,14 +367,20 @@ class TestPerformanceAssessorPlotMetricsAllThresholds: Test suite for the PerformanceAssessor plot_metrics_all_thresholds method. """ + def setup_method(self): + """ + Setup method to create a PerformanceAssessor instance for testing. + """ + self.rng = np.random.default_rng(42) + def test_plot_metrics_all_thresholds_with_valid_inputs(self): """ Test plot_metrics_all_thresholds with valid predictions and labels. 
""" num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) assessor.plot_metrics_all_thresholds(predictions, labels) def test_plot_metrics_all_thresholds_with_per_class_metrics(self): @@ -374,11 +389,9 @@ def test_plot_metrics_all_thresholds_with_per_class_metrics(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) - assessor.plot_metrics_all_thresholds( - predictions, labels, per_class_metrics=True - ) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) + assessor.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=True) def test_plot_metrics_all_thresholds_with_invalid_predictions_shape(self): """ @@ -386,9 +399,9 @@ def test_plot_metrics_all_thresholds_with_invalid_predictions_shape(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100) # Invalid shape - labels = np.random.randint(0, 2, size=(100, num_classes)) - with pytest.raises(ValueError): + predictions = self.rng.random(100) # Invalid shape + labels = self.rng.integers(0, 2, size=(100, num_classes)) + with pytest.raises(ValueError, match="predictions and labels must have the same shape."): assessor.plot_metrics_all_thresholds(predictions, labels) def test_plot_metrics_all_thresholds_with_invalid_labels_shape(self): @@ -397,9 +410,9 @@ def test_plot_metrics_all_thresholds_with_invalid_labels_shape(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100,)) # Invalid shape - with 
pytest.raises(ValueError): + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100,)) # Invalid shape + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_metrics_all_thresholds(predictions, labels) def test_plot_metrics_all_thresholds_with_mismatched_predictions_and_labels(self): @@ -408,11 +421,9 @@ def test_plot_metrics_all_thresholds_with_mismatched_predictions_and_labels(self """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint( - 0, 2, size=(90, num_classes) - ) # Different number of samples - with pytest.raises(ValueError): + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(90, num_classes)) # Different number of samples + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_metrics_all_thresholds(predictions, labels) def test_plot_metrics_all_thresholds_with_binary_task(self): @@ -421,8 +432,8 @@ def test_plot_metrics_all_thresholds_with_binary_task(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100, 1) - labels = np.random.randint(0, 2, size=(100, 1)) + predictions = self.rng.random((100, 1)) + labels = self.rng.integers(0, 2, size=(100, 1)) assessor.plot_metrics_all_thresholds(predictions, labels) def test_plot_metrics_all_thresholds_with_no_classes(self): @@ -431,11 +442,9 @@ def test_plot_metrics_all_thresholds_with_no_classes(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes, classes=None) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) - assessor.plot_metrics_all_thresholds( - predictions, labels, per_class_metrics=True - ) + predictions = self.rng.random((100, num_classes)) 
+ labels = self.rng.integers(0, 2, size=(100, num_classes)) + assessor.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=True) def test_plot_metrics_all_thresholds_with_invalid_predictions_type(self): """ @@ -444,7 +453,7 @@ def test_plot_metrics_all_thresholds_with_invalid_predictions_type(self): num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) predictions = [[0.1, 0.2, 0.3]] * 100 # List instead of numpy array - labels = np.random.randint(0, 2, size=(100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) with pytest.raises(TypeError): assessor.plot_metrics_all_thresholds(predictions, labels) @@ -454,7 +463,7 @@ def test_plot_metrics_all_thresholds_with_invalid_labels_type(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) + predictions = self.rng.random((100, num_classes)) labels = [[0, 1, 0]] * 100 # List instead of numpy array with pytest.raises(TypeError): assessor.plot_metrics_all_thresholds(predictions, labels) @@ -465,11 +474,9 @@ def test_plot_metrics_all_thresholds_with_large_number_of_classes(self): """ num_classes = 50 assessor = PerformanceAssessor(num_classes=num_classes) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) - assessor.plot_metrics_all_thresholds( - predictions, labels, per_class_metrics=True - ) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) + assessor.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=True) class TestPerformanceAssessorPlotConfusionMatrix: @@ -477,14 +484,20 @@ class TestPerformanceAssessorPlotConfusionMatrix: Test suite for the PerformanceAssessor plot_confusion_matrix method. """ + def setup_method(self): + """ + Setup method to create a PerformanceAssessor instance for testing. 
+ """ + self.rng = np.random.default_rng(42) # For reproducibility + def test_plot_confusion_matrix_with_valid_inputs(self): """ Test plot_confusion_matrix with valid predictions and labels. """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100, 1) - labels = np.random.randint(0, 2, size=(100, 1)) + predictions = self.rng.random((100, 1)) + labels = self.rng.integers(0, 2, size=(100, 1)) assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_multilabel_task(self): @@ -493,8 +506,8 @@ def test_plot_confusion_matrix_with_multilabel_task(self): """ num_classes = 3 assessor = PerformanceAssessor(num_classes=num_classes, task="multilabel") - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_invalid_predictions_shape(self): @@ -503,9 +516,9 @@ def test_plot_confusion_matrix_with_invalid_predictions_shape(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100) # Invalid shape - labels = np.random.randint(0, 2, size=(100, 1)) - with pytest.raises(ValueError): + predictions = self.rng.random(100) # Invalid shape + labels = self.rng.integers(0, 2, size=(100, 1)) + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_invalid_labels_shape(self): @@ -514,9 +527,9 @@ def test_plot_confusion_matrix_with_invalid_labels_shape(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100, 1) - labels = np.random.randint(0, 2, size=(100,)) # Invalid 
shape - with pytest.raises(ValueError): + predictions = self.rng.random((100, 1)) + labels = self.rng.integers(0, 2, size=(100,)) # Invalid shape + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_mismatched_predictions_and_labels(self): @@ -525,9 +538,9 @@ def test_plot_confusion_matrix_with_mismatched_predictions_and_labels(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100, 1) - labels = np.random.randint(0, 2, size=(90, 1)) # Different number of samples - with pytest.raises(ValueError): + predictions = self.rng.random((100, 1)) + labels = self.rng.integers(0, 2, size=(90, 1)) # Different number of samples + with pytest.raises(ValueError, match=re.escape("predictions and labels must have the same shape.")): assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_invalid_predictions_type(self): @@ -537,8 +550,8 @@ def test_plot_confusion_matrix_with_invalid_predictions_type(self): num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") predictions = [0.1] * 100 # List instead of numpy array - labels = np.random.randint(0, 2, size=(100, 1)) - with pytest.raises(TypeError): + labels = self.rng.integers(0, 2, size=(100, 1)) + with pytest.raises(TypeError, match=re.escape("predictions must be a NumPy array.")): assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_invalid_labels_type(self): @@ -547,9 +560,9 @@ def test_plot_confusion_matrix_with_invalid_labels_type(self): """ num_classes = 1 assessor = PerformanceAssessor(num_classes=num_classes, task="binary") - predictions = np.random.rand(100, 1) + predictions = self.rng.random((100, 1)) labels = [0] * 100 # List instead of numpy array - with pytest.raises(TypeError): + with pytest.raises(TypeError, 
match=re.escape("labels must be a NumPy array.")): assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_large_number_of_classes(self): @@ -558,8 +571,8 @@ def test_plot_confusion_matrix_with_large_number_of_classes(self): """ num_classes = 20 assessor = PerformanceAssessor(num_classes=num_classes, task="multilabel") - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) assessor.plot_confusion_matrix(predictions, labels) def test_plot_confusion_matrix_with_no_classes(self): @@ -567,9 +580,7 @@ def test_plot_confusion_matrix_with_no_classes(self): Test plot_confusion_matrix when no classes are provided (classes=None). """ num_classes = 3 - assessor = PerformanceAssessor( - num_classes=num_classes, classes=None, task="multilabel" - ) - predictions = np.random.rand(100, num_classes) - labels = np.random.randint(0, 2, size=(100, num_classes)) + assessor = PerformanceAssessor(num_classes=num_classes, classes=None, task="multilabel") + predictions = self.rng.random((100, num_classes)) + labels = self.rng.integers(0, 2, size=(100, num_classes)) assessor.plot_confusion_matrix(predictions, labels) diff --git a/tests/assessment/test_plotting.py b/tests/evaluation/assessment/test_plotting.py similarity index 78% rename from tests/assessment/test_plotting.py rename to tests/evaluation/assessment/test_plotting.py index ada66f6e..d5bcb361 100644 --- a/tests/assessment/test_plotting.py +++ b/tests/evaluation/assessment/test_plotting.py @@ -1,15 +1,17 @@ -import pytest -import pandas as pd -import numpy as np +import re + import matplotlib import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest from birdnet_analyzer.evaluation.assessment.plotting import ( - plot_overall_metrics, - plot_metrics_per_class, + plot_confusion_matrices, 
plot_metrics_across_thresholds, plot_metrics_across_thresholds_per_class, - plot_confusion_matrices, + plot_metrics_per_class, + plot_overall_metrics, ) # Set the matplotlib backend to 'Agg' to prevent GUI issues during testing @@ -28,9 +30,7 @@ def test_valid_input(self): """ Test with valid inputs to ensure the function runs without errors. """ - metrics_df = pd.DataFrame( - {"Overall": [0.8, 0.75, 0.9]}, index=["Precision", "Recall", "F1"] - ) + metrics_df = pd.DataFrame({"Overall": [0.8, 0.75, 0.9]}, index=["Precision", "Recall", "F1"]) colors = ["blue", "green", "red"] plot_overall_metrics(metrics_df, colors) @@ -40,17 +40,16 @@ def test_empty_metrics_df(self): """ metrics_df = pd.DataFrame({"Overall": []}) colors = [] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("metrics_df is empty.")): plot_overall_metrics(metrics_df, colors) def test_missing_overall_column(self): """ Test with metrics_df missing 'Overall' column to ensure it raises KeyError. """ - metrics_df = pd.DataFrame( - {"Value": [0.8, 0.75, 0.9]}, index=["Precision", "Recall", "F1"] - ) + metrics_df = pd.DataFrame({"Value": [0.8, 0.75, 0.9]}, index=["Precision", "Recall", "F1"]) colors = ["blue", "green", "red"] + with pytest.raises(KeyError): plot_overall_metrics(metrics_df, colors) @@ -58,9 +57,7 @@ def test_colors_shorter_than_metrics(self): """ Test with fewer colors than metrics to check color assignment. """ - metrics_df = pd.DataFrame( - {"Overall": [0.8, 0.75, 0.9]}, index=["Precision", "Recall", "F1"] - ) + metrics_df = pd.DataFrame({"Overall": [0.8, 0.75, 0.9]}, index=["Precision", "Recall", "F1"]) colors = ["blue", "green"] # Only two colors for three metrics plot_overall_metrics(metrics_df, colors) @@ -68,9 +65,7 @@ def test_colors_longer_than_metrics(self): """ Test with more colors than metrics to ensure extra colors are ignored. 
""" - metrics_df = pd.DataFrame( - {"Overall": [0.8, 0.75]}, index=["Precision", "Recall"] - ) + metrics_df = pd.DataFrame({"Overall": [0.8, 0.75]}, index=["Precision", "Recall"]) colors = ["blue", "green", "red", "yellow"] plot_overall_metrics(metrics_df, colors) @@ -87,9 +82,7 @@ def test_invalid_colors_type(self): """ Test with invalid type for colors to ensure it raises TypeError. """ - metrics_df = pd.DataFrame( - {"Overall": [0.8, 0.75]}, index=["Precision", "Recall"] - ) + metrics_df = pd.DataFrame({"Overall": [0.8, 0.75]}, index=["Precision", "Recall"]) colors = "blue" # Should be a list with pytest.raises(TypeError): plot_overall_metrics(metrics_df, colors) @@ -98,9 +91,7 @@ def test_nan_values(self): """ Test with NaN values in metrics_df to ensure it handles missing data. """ - metrics_df = pd.DataFrame( - {"Overall": [0.8, np.nan, 0.9]}, index=["Precision", "Recall", "F1"] - ) + metrics_df = pd.DataFrame({"Overall": [0.8, np.nan, 0.9]}, index=["Precision", "Recall", "F1"]) colors = ["blue", "green", "red"] plot_overall_metrics(metrics_df, colors) @@ -108,9 +99,7 @@ def test_non_unique_metric_names(self): """ Test with non-unique metric names to check handling of index duplication. """ - metrics_df = pd.DataFrame( - {"Overall": [0.8, 0.75, 0.9]}, index=["Precision", "Precision", "F1"] - ) + metrics_df = pd.DataFrame({"Overall": [0.8, 0.75, 0.9]}, index=["Precision", "Precision", "F1"]) colors = ["blue", "green", "red"] plot_overall_metrics(metrics_df, colors) @@ -118,9 +107,7 @@ def test_extremely_large_values(self): """ Test with extremely large values to check plot scaling. 
""" - metrics_df = pd.DataFrame( - {"Overall": [1e10, 5e10, 1e11]}, index=["Metric1", "Metric2", "Metric3"] - ) + metrics_df = pd.DataFrame({"Overall": [1e10, 5e10, 1e11]}, index=["Metric1", "Metric2", "Metric3"]) colors = ["blue", "green", "red"] plot_overall_metrics(metrics_df, colors) @@ -130,6 +117,9 @@ class TestPlotMetricsPerClass: Test suite for the plot_metrics_per_class function. """ + def setup_method(self): + self.rng = np.random.default_rng(42) + def test_valid_input(self): """ Test with valid inputs to ensure the function runs without errors. @@ -151,7 +141,7 @@ def test_empty_metrics_df(self): """ metrics_df = pd.DataFrame() colors = [] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("metrics_df is empty.")): plot_metrics_per_class(metrics_df, colors) def test_mismatched_colors_length(self): @@ -224,9 +214,7 @@ def test_non_string_metric_names(self): """ Test with non-string metric names to ensure labels are handled correctly. """ - metrics_df = pd.DataFrame( - {"Class1": [0.8, 0.7, 0.9], "Class2": [0.85, 0.75, 0.95]}, index=[1, 2, 3] - ) + metrics_df = pd.DataFrame({"Class1": [0.8, 0.7, 0.9], "Class2": [0.85, 0.75, 0.95]}, index=[1, 2, 3]) colors = ["blue", "green"] plot_metrics_per_class(metrics_df, colors) @@ -235,10 +223,8 @@ def test_many_classes(self): Test with many classes to check plotting scales correctly. """ classes = [f"Class{i}" for i in range(20)] - data = np.random.rand(3, 20) - metrics_df = pd.DataFrame( - data, index=["Precision", "Recall", "F1"], columns=classes - ) + data = self.rng.random((3, 20)) + metrics_df = pd.DataFrame(data, index=["Precision", "Recall", "F1"], columns=classes) colors = ["blue", "green", "red"] plot_metrics_per_class(metrics_df, colors) @@ -248,21 +234,22 @@ class TestPlotMetricsAcrossThresholds: Test suite for the plot_metrics_across_thresholds function. 
""" + def setup_method(self): + self.rng = np.random.default_rng(42) + def test_valid_input(self): """ Test with valid inputs to ensure the function runs without errors. """ thresholds = np.linspace(0, 1, 10) metric_values_dict = { - "precision": np.random.rand(10), - "recall": np.random.rand(10), - "f1": np.random.rand(10), + "precision": self.rng.random(10), + "recall": self.rng.random(10), + "f1": self.rng.random(10), } metrics_to_plot = ["precision", "recall", "f1"] colors = ["blue", "green", "red"] - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_empty_thresholds(self): """ @@ -272,10 +259,9 @@ def test_empty_thresholds(self): metric_values_dict = {} metrics_to_plot = [] colors = [] - with pytest.raises(ValueError): - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + + with pytest.raises(ValueError, match=re.escape("thresholds array is empty.")): + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_mismatched_lengths(self): """ @@ -283,16 +269,14 @@ def test_mismatched_lengths(self): """ thresholds = np.linspace(0, 1, 10) metric_values_dict = { - "precision": np.random.rand(8), # Should be length 10 - "recall": np.random.rand(10), - "f1": np.random.rand(10), + "precision": self.rng.random(8), # Should be length 10 + "recall": self.rng.random(10), + "f1": self.rng.random(10), } metrics_to_plot = ["precision", "recall", "f1"] colors = ["blue", "green", "red"] - with pytest.raises(ValueError): - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + with pytest.raises(ValueError, match=re.escape("Length of metric 'precision' values does not match length of thresholds.")): + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def 
test_invalid_thresholds_type(self): """ @@ -303,35 +287,29 @@ def test_invalid_thresholds_type(self): metrics_to_plot = [] colors = [] with pytest.raises(TypeError): - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_invalid_metrics_dict_type(self): """ Test with invalid type for metric_values_dict. """ thresholds = np.linspace(0, 1, 10) - metric_values_dict = [("precision", np.random.rand(10))] + metric_values_dict = [("precision", self.rng.random(10))] metrics_to_plot = ["precision"] colors = ["blue"] with pytest.raises(TypeError): - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_invalid_metrics_to_plot_type(self): """ Test with invalid type for metrics_to_plot. """ thresholds = np.linspace(0, 1, 10) - metric_values_dict = {"precision": np.random.rand(10)} + metric_values_dict = {"precision": self.rng.random(10)} metrics_to_plot = "precision" # Should be a list colors = ["blue"] with pytest.raises(TypeError): - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_nan_values(self): """ @@ -339,15 +317,13 @@ def test_nan_values(self): """ thresholds = np.linspace(0, 1, 10) metric_values_dict = { - "precision": np.append(np.random.rand(9), np.nan), - "recall": np.random.rand(10), - "f1": np.random.rand(10), + "precision": np.append(self.rng.random(9), np.nan), + "recall": self.rng.random(10), + "f1": self.rng.random(10), } metrics_to_plot = ["precision", "recall", "f1"] colors = ["blue", "green", "red"] - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, 
metric_values_dict, metrics_to_plot, colors) def test_empty_colors_list(self): """ @@ -355,15 +331,13 @@ def test_empty_colors_list(self): """ thresholds = np.linspace(0, 1, 10) metric_values_dict = { - "precision": np.random.rand(10), - "recall": np.random.rand(10), - "f1": np.random.rand(10), + "precision": self.rng.random(10), + "recall": self.rng.random(10), + "f1": self.rng.random(10), } metrics_to_plot = ["precision", "recall", "f1"] colors = [] - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_many_metrics(self): """ @@ -371,11 +345,9 @@ def test_many_metrics(self): """ thresholds = np.linspace(0, 1, 10) metrics_to_plot = [f"metric{i}" for i in range(20)] - metric_values_dict = {metric: np.random.rand(10) for metric in metrics_to_plot} + metric_values_dict = {metric: self.rng.random(10) for metric in metrics_to_plot} colors = ["blue", "green", "red"] * 7 - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_mismatched_colors_length(self): """ @@ -383,14 +355,12 @@ def test_mismatched_colors_length(self): """ thresholds = np.linspace(0, 1, 10) metric_values_dict = { - "precision": np.random.rand(10), - "recall": np.random.rand(10), + "precision": self.rng.random(10), + "recall": self.rng.random(10), } metrics_to_plot = ["precision", "recall"] colors = ["blue"] # Only one color provided - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) def test_large_thresholds_array(self): """ @@ -398,15 +368,13 @@ def test_large_thresholds_array(self): """ thresholds = np.linspace(0, 1, 1000) metric_values_dict = { - "precision": np.random.rand(1000), - 
"recall": np.random.rand(1000), - "f1": np.random.rand(1000), + "precision": self.rng.random(1000), + "recall": self.rng.random(1000), + "f1": self.rng.random(1000), } metrics_to_plot = ["precision", "recall", "f1"] colors = ["blue", "green", "red"] - plot_metrics_across_thresholds( - thresholds, metric_values_dict, metrics_to_plot, colors - ) + plot_metrics_across_thresholds(thresholds, metric_values_dict, metrics_to_plot, colors) class TestPlotMetricsAcrossThresholdsPerClass: @@ -414,6 +382,9 @@ class TestPlotMetricsAcrossThresholdsPerClass: Test suite for the plot_metrics_across_thresholds_per_class function. """ + def setup_method(self): + self.rng = np.random.default_rng(42) + def test_valid_input(self): """ Test with valid inputs to ensure the function runs without errors. @@ -422,8 +393,8 @@ def test_valid_input(self): class_names = ["Class1", "Class2"] metrics_to_plot = ["precision", "recall"] metric_values_dict_per_class = { - "Class1": {"precision": np.random.rand(10), "recall": np.random.rand(10)}, - "Class2": {"precision": np.random.rand(10), "recall": np.random.rand(10)}, + "Class1": {"precision": self.rng.random(10), "recall": self.rng.random(10)}, + "Class2": {"precision": self.rng.random(10), "recall": self.rng.random(10)}, } colors = ["blue", "green"] plot_metrics_across_thresholds_per_class( @@ -443,7 +414,7 @@ def test_empty_thresholds(self): metrics_to_plot = [] metric_values_dict_per_class = {} colors = [] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("thresholds array is empty.")): plot_metrics_across_thresholds_per_class( thresholds, metric_values_dict_per_class, @@ -477,9 +448,7 @@ def test_nan_values(self): thresholds = np.linspace(0, 1, 10) class_names = ["Class1"] metrics_to_plot = ["precision"] - metric_values_dict_per_class = { - "Class1": {"precision": np.append(np.random.rand(9), np.nan)} - } + metric_values_dict_per_class = {"Class1": {"precision": np.append(self.rng.random(9), np.nan)}} colors = 
["blue"] plot_metrics_across_thresholds_per_class( thresholds, @@ -496,7 +465,7 @@ def test_empty_colors_list(self): thresholds = np.linspace(0, 1, 10) class_names = ["Class1"] metrics_to_plot = ["precision"] - metric_values_dict_per_class = {"Class1": {"precision": np.random.rand(10)}} + metric_values_dict_per_class = {"Class1": {"precision": self.rng.random(10)}} colors = [] plot_metrics_across_thresholds_per_class( thresholds, @@ -513,9 +482,7 @@ def test_many_classes(self): thresholds = np.linspace(0, 1, 10) class_names = [f"Class{i}" for i in range(20)] metrics_to_plot = ["precision"] - metric_values_dict_per_class = { - class_name: {"precision": np.random.rand(10)} for class_name in class_names - } + metric_values_dict_per_class = {class_name: {"precision": self.rng.random(10)} for class_name in class_names} colors = ["blue", "green", "red"] * 7 plot_metrics_across_thresholds_per_class( thresholds, @@ -533,8 +500,8 @@ def test_mismatched_colors_length(self): class_names = ["Class1", "Class2"] metrics_to_plot = ["precision"] metric_values_dict_per_class = { - "Class1": {"precision": np.random.rand(10)}, - "Class2": {"precision": np.random.rand(10)}, + "Class1": {"precision": self.rng.random(10)}, + "Class2": {"precision": self.rng.random(10)}, } colors = ["blue"] # Only one color provided plot_metrics_across_thresholds_per_class( @@ -552,7 +519,7 @@ def test_invalid_metrics_to_plot(self): thresholds = np.linspace(0, 1, 10) class_names = ["Class1"] metrics_to_plot = "precision" # Should be a list - metric_values_dict_per_class = {"Class1": {"precision": np.random.rand(10)}} + metric_values_dict_per_class = {"Class1": {"precision": self.rng.random(10)}} colors = ["blue"] with pytest.raises(TypeError): plot_metrics_across_thresholds_per_class( @@ -571,7 +538,7 @@ def test_missing_class_in_dict(self): class_names = ["Class1", "Class2"] metrics_to_plot = ["precision"] metric_values_dict_per_class = { - "Class1": {"precision": np.random.rand(10)} + "Class1": 
{"precision": self.rng.random(10)} # 'Class2' is missing } colors = ["blue", "green"] @@ -592,10 +559,13 @@ def test_mismatched_lengths(self): class_names = ["Class1"] metrics_to_plot = ["precision"] metric_values_dict_per_class = { - "Class1": {"precision": np.random.rand(9)} # Length should be 10 + "Class1": {"precision": self.rng.random(9)} # Length should be 10 } colors = ["blue"] - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=re.escape("Length of metric 'precision' values for class 'Class1' does not match length of thresholds."), + ): plot_metrics_across_thresholds_per_class( thresholds, metric_values_dict_per_class, @@ -611,7 +581,7 @@ def test_large_thresholds_array(self): thresholds = np.linspace(0, 1, 1000) class_names = ["Class1"] metrics_to_plot = ["precision"] - metric_values_dict_per_class = {"Class1": {"precision": np.random.rand(1000)}} + metric_values_dict_per_class = {"Class1": {"precision": self.rng.random(1000)}} colors = ["blue"] plot_metrics_across_thresholds_per_class( thresholds, @@ -627,6 +597,9 @@ class TestPlotConfusionMatrices: Test suite for the plot_confusion_matrices function. """ + def setup_method(self): + self.rng = np.random.default_rng(42) # For reproducibility + def test_binary_task(self): """ Test with binary task to ensure it runs without errors. @@ -652,7 +625,7 @@ def test_invalid_task(self): conf_mat = np.array([[50, 10], [5, 35]]) task = "invalid_task" class_names = ["Positive", "Negative"] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("Invalid task. 
Expected 'binary', 'multiclass', or 'multilabel'.")): plot_confusion_matrices(conf_mat, task, class_names) def test_empty_conf_mat(self): @@ -662,7 +635,7 @@ def test_empty_conf_mat(self): conf_mat = np.array([]) task = "binary" class_names = ["Positive", "Negative"] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("conf_mat is empty.")): plot_confusion_matrices(conf_mat, task, class_names) def test_mismatched_class_names(self): @@ -672,7 +645,7 @@ def test_mismatched_class_names(self): conf_mat = np.array([[50, 10], [5, 35]]) task = "binary" class_names = ["Positive"] # Should be two class names - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="For binary task, class_names must have exactly two elements."): plot_confusion_matrices(conf_mat, task, class_names) def test_invalid_conf_mat_type(self): @@ -708,7 +681,7 @@ def test_many_classes(self): Test with many classes to check plotting scales correctly. """ num_classes = 10 - conf_mat = np.random.randint(0, 100, size=(num_classes, 2, 2)) + conf_mat = self.rng.integers(0, 100, size=(num_classes, 2, 2)) task = "multilabel" class_names = [f"Class{i}" for i in range(num_classes)] plot_confusion_matrices(conf_mat, task, class_names) @@ -720,5 +693,5 @@ def test_invalid_conf_mat_shape(self): conf_mat = np.array([50, 10, 5, 35]) # Should be 2x2 or Nx2x2 task = "binary" class_names = ["Positive", "Negative"] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("For binary task, conf_mat must be of shape (2, 2).")): plot_confusion_matrices(conf_mat, task, class_names) diff --git a/tests/evaluation/preprocessing/__init__.py b/tests/evaluation/preprocessing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/preprocessing/test_data_processor.py b/tests/evaluation/preprocessing/test_data_processor.py similarity index 86% rename from tests/preprocessing/test_data_processor.py rename to 
tests/evaluation/preprocessing/test_data_processor.py index 43bb98a4..37fe073e 100644 --- a/tests/preprocessing/test_data_processor.py +++ b/tests/evaluation/preprocessing/test_data_processor.py @@ -8,7 +8,6 @@ class TestDataProcessorInit: - @patch("pandas.read_csv") def test_init_with_all_parameters(self, mock_read_csv): """Test initializing DataProcessor with all parameters.""" @@ -54,7 +53,7 @@ def test_init_with_all_parameters(self, mock_read_csv): ) # Your assertions here - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_init_with_minimal_parameters(self, mock_read_concat): """Test initializing DataProcessor with minimal parameters.""" # Mock the dataframes returned by the read function @@ -101,9 +100,7 @@ def test_init_with_invalid_sample_duration(self, mock_load_data): sample_duration=0, ) - with pytest.raises( - ValueError, match="Sample duration cannot exceed the recording duration." - ): + with pytest.raises(ValueError, match="Sample duration cannot exceed the recording duration."): DataProcessor( prediction_directory_path="", annotation_directory_path="", @@ -128,16 +125,14 @@ def test_init_with_invalid_min_overlap(self, mock_load_data): min_overlap=0, ) - with pytest.raises( - ValueError, match="Min overlap cannot exceed the sample duration." 
- ): + with pytest.raises(ValueError, match="Min overlap cannot exceed the sample duration."): DataProcessor( prediction_directory_path="", annotation_directory_path="", min_overlap=6, # Greater than default sample_duration=3 ) - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_init_with_nonexistent_paths(self, mock_read_concat): """Test initializing with paths that do not exist.""" # Mock the dataframes to be empty but with required columns @@ -169,7 +164,7 @@ def test_init_with_nonexistent_paths(self, mock_read_concat): pd.testing.assert_frame_equal(dp.predictions_df, mock_predictions_df) pd.testing.assert_frame_equal(dp.annotations_df, mock_annotations_df) - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_init_with_none_columns(self, mock_read_concat): """Test initializing with None columns mappings.""" # Mock the dataframes returned by the read function @@ -201,7 +196,7 @@ def test_init_with_none_columns(self, mock_read_concat): assert dp.columns_predictions == dp.DEFAULT_COLUMNS_PREDICTIONS assert dp.columns_annotations == dp.DEFAULT_COLUMNS_ANNOTATIONS - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_init_with_empty_class_mapping(self, mock_read_concat): """Test initializing with empty class_mapping.""" # Mock the dataframes returned by the read function @@ -224,24 +219,20 @@ def test_init_with_empty_class_mapping(self, mock_read_concat): # Set side effect for the mock mock_read_concat.side_effect = [mock_predictions_df, mock_annotations_df] - dp = DataProcessor( - prediction_directory_path="", 
annotation_directory_path="", class_mapping={} - ) + dp = DataProcessor(prediction_directory_path="", annotation_directory_path="", class_mapping={}) assert dp.class_mapping == {} @patch.object(DataProcessor, "load_data") def test_init_with_invalid_recording_duration(self, mock_load_data): """Test initializing with negative recording_duration.""" - with pytest.raises( - ValueError, match="Recording duration must be greater than 0." - ): + with pytest.raises(ValueError, match="Recording duration must be greater than 0."): DataProcessor( prediction_directory_path="", annotation_directory_path="", recording_duration=-10, ) - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_init_with_large_sample_duration(self, mock_read_concat): """Test initializing with large sample_duration.""" # Mock the dataframes returned by the read function @@ -271,7 +262,7 @@ def test_init_with_large_sample_duration(self, mock_read_concat): ) assert dp.sample_duration == 1000 - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_init_with_non_default_columns(self, mock_read_concat): """Test initializing with custom columns mappings.""" # Mock the dataframes returned by the read function @@ -323,18 +314,17 @@ def test_init_with_non_default_columns(self, mock_read_concat): # Since all required columns are provided, this isn't necessary, # but if you want to check optional columns: optional_col = "Confidence" - assert dp.get_column_name( - optional_col, prediction=True - ) == dp.DEFAULT_COLUMNS_PREDICTIONS.get(optional_col, optional_col) + assert dp.get_column_name(optional_col, prediction=True) == dp.DEFAULT_COLUMNS_PREDICTIONS.get( + optional_col, optional_col + ) optional_col = "Recording" - 
assert dp.get_column_name( - optional_col, prediction=False - ) == dp.DEFAULT_COLUMNS_ANNOTATIONS.get(optional_col, optional_col) + assert dp.get_column_name(optional_col, prediction=False) == dp.DEFAULT_COLUMNS_ANNOTATIONS.get( + optional_col, optional_col + ) class TestDataProcessorLoadData: - - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_load_data_with_none_filenames(self, mock_read_concat): """Test load_data when prediction_file_name and annotation_file_name are None.""" # Mocking the DataFrames returned by the utility function @@ -417,12 +407,8 @@ def test_load_data_with_specific_filenames(self, mock_read_csv): ) # Ensure that predictions_df and annotations_df are set correctly - pd.testing.assert_frame_equal( - dp.predictions_df, mock_predictions_df.assign(source_file="predictions.txt") - ) - pd.testing.assert_frame_equal( - dp.annotations_df, mock_annotations_df.assign(source_file="annotations.txt") - ) + pd.testing.assert_frame_equal(dp.predictions_df, mock_predictions_df.assign(source_file="predictions.txt")) + pd.testing.assert_frame_equal(dp.annotations_df, mock_annotations_df.assign(source_file="annotations.txt")) @patch("pandas.read_csv") def test_load_data_missing_prediction_file(self, mock_read_csv): @@ -432,14 +418,13 @@ def test_load_data_missing_prediction_file(self, mock_read_csv): def side_effect(*args, **kwargs): if "predictions.txt" in args[0]: raise FileNotFoundError("File not found") - else: - return pd.DataFrame( - { - "Class": ["A", "C"], - "Start Time": [0.5, 1.5], - "End Time": [1.5, 2.5], - } - ) + return pd.DataFrame( + { + "Class": ["A", "C"], + "Start Time": [0.5, 1.5], + "End Time": [1.5, 2.5], + } + ) mock_read_csv.side_effect = side_effect @@ -470,14 +455,14 @@ def test_load_data_missing_annotation_file(self, mock_read_csv): def side_effect(*args, **kwargs): if 
"annotations.txt" in args[0]: raise FileNotFoundError("File not found") - else: - return pd.DataFrame( - { - "Class": ["A", "B"], - "Start Time": [0, 1], - "End Time": [1, 2], - } - ) + + return pd.DataFrame( + { + "Class": ["A", "B"], + "Start Time": [0, 1], + "End Time": [1, 2], + } + ) mock_read_csv.side_effect = side_effect @@ -500,7 +485,7 @@ def side_effect(*args, **kwargs): recording_duration=10, ) - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_load_data_with_empty_directories(self, mock_read_concat): """Test load_data when directories are empty.""" mock_read_concat.return_value = pd.DataFrame( @@ -570,12 +555,10 @@ def test_load_data_filenames_do_not_match(self, mock_read_csv): recording_duration=10, ) - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_load_data_inconsistent_columns(self, mock_read_concat): """Test load_data when files have inconsistent columns.""" - mock_read_concat.side_effect = ValueError( - "File has different columns than previous files." 
- ) + mock_read_concat.side_effect = ValueError("File has different columns than previous files.") with pytest.raises(ValueError, match="different columns than previous files"): DataProcessor( @@ -596,7 +579,7 @@ def test_load_data_inconsistent_columns(self, mock_read_concat): recording_duration=10, ) - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_load_data_handle_different_encodings(self, mock_read_concat): """Test load_data handling different file encodings.""" mock_read_concat.return_value = pd.DataFrame( @@ -626,7 +609,7 @@ def test_load_data_handle_different_encodings(self, mock_read_concat): ) # Should proceed without errors - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_load_data_with_class_mapping(self, mock_read_concat): """Test load_data applying class mapping.""" mock_predictions_df = pd.DataFrame( @@ -670,7 +653,7 @@ def test_load_data_with_class_mapping(self, mock_read_concat): expected_classes = ("C", "ClassA", "ClassB", "ClassC") assert dp.classes == expected_classes - @patch("bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory") + @patch("birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory") def test_load_data_without_class_mapping(self, mock_read_concat): """Test load_data without class mapping.""" mock_predictions_df = pd.DataFrame( @@ -715,7 +698,6 @@ def test_load_data_without_class_mapping(self, mock_read_concat): class TestDataProcessorValidateParameters: - @patch.object(DataProcessor, "load_data") def test_sample_duration_zero(self, mock_load_data): """Test sample_duration=0 raises ValueError.""" @@ -762,13 +744,9 @@ def test_min_overlap_negative(self, 
mock_load_data, mock_process_data): @patch.object(DataProcessor, "process_data") @patch.object(DataProcessor, "load_data") - def test_min_overlap_greater_than_sample_duration( - self, mock_load_data, mock_process_data - ): + def test_min_overlap_greater_than_sample_duration(self, mock_load_data, mock_process_data): """Test min_overlap > sample_duration raises ValueError.""" - with pytest.raises( - ValueError, match="Min overlap cannot exceed the sample duration." - ): + with pytest.raises(ValueError, match="Min overlap cannot exceed the sample duration."): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -816,9 +794,7 @@ def test_recording_duration_none(self, mock_load_data, mock_process_data): @patch.object(DataProcessor, "load_data") def test_recording_duration_zero(self, mock_load_data): """Test recording_duration=0 raises ValueError.""" - with pytest.raises( - ValueError, match="Recording duration must be greater than 0." - ): + with pytest.raises(ValueError, match="Recording duration must be greater than 0."): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -828,9 +804,7 @@ def test_recording_duration_zero(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_recording_duration_negative(self, mock_load_data): """Test negative recording_duration raises ValueError.""" - with pytest.raises( - ValueError, match="Recording duration must be greater than 0." 
- ): + with pytest.raises(ValueError, match="Recording duration must be greater than 0."): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -839,7 +813,6 @@ def test_recording_duration_negative(self, mock_load_data): class TestDataProcessorValidateColumns: - @patch.object(DataProcessor, "process_data") @patch.object(DataProcessor, "load_data") def test_columns_all_required_present(self, mock_load_data, mock_process_data): @@ -860,16 +833,12 @@ def test_columns_all_required_present(self, mock_load_data, mock_process_data): }, ) except ValueError: - pytest.fail( - "Unexpected ValueError raised with all required columns present" - ) + pytest.fail("Unexpected ValueError raised with all required columns present") @patch.object(DataProcessor, "load_data") def test_columns_predictions_missing_start_time(self, mock_load_data): """Test missing 'Start Time' in columns_predictions raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None prediction columns: Start Time" - ): + with pytest.raises(ValueError, match="Missing or None prediction columns: Start Time"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -888,9 +857,7 @@ def test_columns_predictions_missing_start_time(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_predictions_missing_end_time(self, mock_load_data): """Test missing 'End Time' in columns_predictions raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None prediction columns: End Time" - ): + with pytest.raises(ValueError, match="Missing or None prediction columns: End Time"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -909,9 +876,7 @@ def test_columns_predictions_missing_end_time(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_predictions_missing_class(self, 
mock_load_data): """Test missing 'Class' in columns_predictions raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None prediction columns: Class" - ): + with pytest.raises(ValueError, match="Missing or None prediction columns: Class"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -930,9 +895,7 @@ def test_columns_predictions_missing_class(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_annotations_missing_start_time(self, mock_load_data): """Test missing 'Start Time' in columns_annotations raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None annotation columns: Start Time" - ): + with pytest.raises(ValueError, match="Missing or None annotation columns: Start Time"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -951,9 +914,7 @@ def test_columns_annotations_missing_start_time(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_annotations_missing_end_time(self, mock_load_data): """Test missing 'End Time' in columns_annotations raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None annotation columns: End Time" - ): + with pytest.raises(ValueError, match="Missing or None annotation columns: End Time"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -972,9 +933,7 @@ def test_columns_annotations_missing_end_time(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_annotations_missing_class(self, mock_load_data): """Test missing 'Class' in columns_annotations raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None annotation columns: Class" - ): + with pytest.raises(ValueError, match="Missing or None annotation columns: Class"): DataProcessor( prediction_directory_path="dummy_pred_path", 
annotation_directory_path="dummy_annot_path", @@ -993,9 +952,7 @@ def test_columns_annotations_missing_class(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_predictions_start_time_none(self, mock_load_data): """Test 'Start Time' in columns_predictions set to None raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None prediction columns: Start Time" - ): + with pytest.raises(ValueError, match="Missing or None prediction columns: Start Time"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -1014,9 +971,7 @@ def test_columns_predictions_start_time_none(self, mock_load_data): @patch.object(DataProcessor, "load_data") def test_columns_annotations_end_time_none(self, mock_load_data): """Test 'End Time' in columns_annotations set to None raises ValueError.""" - with pytest.raises( - ValueError, match="Missing or None annotation columns: End Time" - ): + with pytest.raises(ValueError, match="Missing or None annotation columns: End Time"): DataProcessor( prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", @@ -1070,12 +1025,11 @@ def test_columns_annotations_empty_dict(self, mock_load_data): class TestDataProcessorPrepareDataFrame: - def setup_method(self): """Set up a DataProcessor instance for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock empty DataFrames for predictions and annotations @@ -1112,9 +1066,7 @@ def test_with_recording_column(self): def test_with_source_file_column(self): """Test DataFrame without 'Recording' but with 'source_file'.""" - df = pd.DataFrame( - {"source_file": ["file1.txt", "file2.txt"], "OtherColumn": [1, 2]} - ) + df = pd.DataFrame({"source_file": 
["file1.txt", "file2.txt"], "OtherColumn": [1, 2]}) result_df = self.dp._prepare_dataframe(df.copy(), prediction=False) expected_filenames = ["file1", "file2"] assert result_df["recording_filename"].tolist() == expected_filenames @@ -1159,9 +1111,7 @@ def test_recording_column_with_non_string_values(self): def test_source_file_with_none_values(self): """Test 'source_file' column containing None or NaN.""" - df = pd.DataFrame( - {"source_file": [None, float("nan"), "file.txt"], "OtherColumn": [1, 2, 3]} - ) + df = pd.DataFrame({"source_file": [None, float("nan"), "file.txt"], "OtherColumn": [1, 2, 3]}) result_df = self.dp._prepare_dataframe(df.copy(), prediction=True) expected_filenames = [None, float("nan"), "file"] pd.testing.assert_series_equal( @@ -1179,9 +1129,7 @@ def test_recording_column_with_empty_strings(self): def test_complex_paths_in_recording_column(self): """Test 'Recording' column with complex paths.""" - df = pd.DataFrame( - {"Recording": ["/a/b/c/d/e.wav", "C:\\folder\\subfolder\\file.wav"]} - ) + df = pd.DataFrame({"Recording": ["/a/b/c/d/e.wav", "C:/folder/subfolder/file.wav"]}) result_df = self.dp._prepare_dataframe(df.copy(), prediction=True) expected_filenames = ["e", "file"] assert result_df["recording_filename"].tolist() == expected_filenames @@ -1201,12 +1149,11 @@ def test_both_recording_and_source_file_columns(self): class TestDataProcessorProcessData: - def setup_method(self): """Set up a DataProcessor instance for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock empty DataFrames for predictions and annotations @@ -1224,9 +1171,7 @@ def setup_method(self): prediction_directory_path="dummy_pred_path", annotation_directory_path="dummy_annot_path", ) - self.dp.process_recording = MagicMock( - 
return_value=pd.DataFrame({"sample_data": [1, 2, 3]}) - ) + self.dp.process_recording = MagicMock(return_value=pd.DataFrame({"sample_data": [1, 2, 3]})) def teardown_method(self): """Stop patching.""" @@ -1241,87 +1186,59 @@ def test_empty_predictions_and_annotations(self): def test_single_recording(self): """Test with single recording in predictions and annotations.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [1]} - ) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [2]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [1]}) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [2]}) self.dp.process_data() assert not self.dp.samples_df.empty assert self.dp.process_recording.call_count == 1 def test_multiple_recordings(self): """Test with multiple recordings.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1", "rec2"], "OtherColumn": [1, 2]} - ) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec1", "rec2"], "OtherColumn": [3, 4]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1", "rec2"], "OtherColumn": [1, 2]}) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec1", "rec2"], "OtherColumn": [3, 4]}) self.dp.process_data() assert not self.dp.samples_df.empty assert self.dp.process_recording.call_count == 2 def test_predictions_extra_recordings(self): """Test with recordings in predictions not in annotations.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1", "rec2"], "OtherColumn": [1, 2]} - ) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [3]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1", "rec2"], "OtherColumn": [1, 2]}) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [3]}) 
self.dp.process_data() assert self.dp.process_recording.call_count == 2 def test_annotations_extra_recordings(self): """Test with recordings in annotations not in predictions.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [1]} - ) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec1", "rec2"], "OtherColumn": [3, 4]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [1]}) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec1", "rec2"], "OtherColumn": [3, 4]}) self.dp.process_data() assert self.dp.process_recording.call_count == 2 def test_duplicate_recording_entries(self): """Test with duplicate recording filenames.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1", "rec1"], "OtherColumn": [1, 2]} - ) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [3]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1", "rec1"], "OtherColumn": [1, 2]}) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [3]}) self.dp.process_data() assert self.dp.process_recording.call_count == 1 def test_missing_recording_filename_in_predictions(self): """Test missing 'recording_filename' in predictions.""" self.dp.predictions_df = pd.DataFrame({"OtherColumn": [1]}) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [2]} - ) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [2]}) with pytest.raises(KeyError): self.dp.process_data() def test_missing_recording_filename_in_annotations(self): """Test missing 'recording_filename' in annotations.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [1]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [1]}) self.dp.annotations_df = 
pd.DataFrame({"OtherColumn": [2]}) with pytest.raises(KeyError): self.dp.process_data() def test_no_overlapping_recordings(self): """Test with no overlapping recording filenames.""" - self.dp.predictions_df = pd.DataFrame( - {"recording_filename": ["rec1"], "OtherColumn": [1]} - ) - self.dp.annotations_df = pd.DataFrame( - {"recording_filename": ["rec2"], "OtherColumn": [2]} - ) + self.dp.predictions_df = pd.DataFrame({"recording_filename": ["rec1"], "OtherColumn": [1]}) + self.dp.annotations_df = pd.DataFrame({"recording_filename": ["rec2"], "OtherColumn": [2]}) self.dp.process_data() assert self.dp.process_recording.call_count == 2 @@ -1345,12 +1262,11 @@ def test_large_number_of_recordings(self): class TestDataProcessorProcessRecording: - def setup_method(self): """Set up a DataProcessor instance for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock empty DataFrames for predictions and annotations @@ -1379,12 +1295,8 @@ def teardown_method(self): def test_valid_predictions_and_annotations(self): """Test with valid predictions and annotations.""" - pred_df = pd.DataFrame( - {"Class": ["A", "B"], "Start Time": [0, 5], "End Time": [5, 10]} - ) - annot_df = pd.DataFrame( - {"Class": ["A", "C"], "Start Time": [2, 7], "End Time": [7, 12]} - ) + pred_df = pd.DataFrame({"Class": ["A", "B"], "Start Time": [0, 5], "End Time": [5, 10]}) + annot_df = pd.DataFrame({"Class": ["A", "C"], "Start Time": [2, 7], "End Time": [7, 12]}) samples_df = self.dp.process_recording("rec1", pred_df, annot_df) assert not samples_df.empty assert len(samples_df) == 3 # 15 / 5 = 3 samples @@ -1395,25 +1307,15 @@ def test_empty_predictions_and_annotations(self): annot_df = pd.DataFrame() samples_df = self.dp.process_recording("rec1", pred_df, annot_df) assert not 
samples_df.empty - assert ( - (samples_df[[f"{cls}_confidence" for cls in self.dp.classes]] == 0) - .all() - .all() - ) - assert ( - (samples_df[[f"{cls}_annotation" for cls in self.dp.classes]] == 0) - .all() - .all() - ) + assert (samples_df[[f"{cls}_confidence" for cls in self.dp.classes]] == 0).all().all() + assert (samples_df[[f"{cls}_annotation" for cls in self.dp.classes]] == 0).all().all() def test_only_predictions_present(self): """Test with only predictions present.""" pred_df = pd.DataFrame({"Class": ["A"], "Start Time": [0], "End Time": [5]}) annot_df = pd.DataFrame() samples_df = self.dp.process_recording("rec1", pred_df, annot_df) - assert ( - samples_df["A_confidence"].iloc[0] == 0.0 - ) # Default confidence since 'Confidence' column missing + assert samples_df["A_confidence"].iloc[0] == 0.0 # Default confidence since 'Confidence' column missing assert samples_df["A_annotation"].iloc[0] == 0 # No annotations def test_only_annotations_present(self): @@ -1455,9 +1357,7 @@ def test_custom_sample_duration_and_min_overlap(self): self.dp.sample_duration = 3 self.dp.min_overlap = 0.1 pred_df = pd.DataFrame({"Class": ["A"], "Start Time": [1], "End Time": [2]}) - annot_df = pd.DataFrame( - {"Class": ["A"], "Start Time": [1.5], "End Time": [2.5]} - ) + annot_df = pd.DataFrame({"Class": ["A"], "Start Time": [1.5], "End Time": [2.5]}) samples_df = self.dp.process_recording("rec1", pred_df, annot_df) assert len(samples_df) == 5 # 15 / 3 = 5 samples # Check if overlaps are correctly calculated @@ -1468,16 +1368,8 @@ def test_classes_not_in_self_classes(self): annot_df = pd.DataFrame({"Class": ["D"], "Start Time": [0], "End Time": [5]}) samples_df = self.dp.process_recording("rec1", pred_df, annot_df) # Since 'D' is not in self.classes, it should be skipped - assert ( - (samples_df[[f"{cls}_confidence" for cls in self.dp.classes]] == 0) - .all() - .all() - ) - assert ( - (samples_df[[f"{cls}_annotation" for cls in self.dp.classes]] == 0) - .all() - .all() - ) + 
assert (samples_df[[f"{cls}_confidence" for cls in self.dp.classes]] == 0).all().all() + assert (samples_df[[f"{cls}_annotation" for cls in self.dp.classes]] == 0).all().all() def test_zero_recording_duration(self): """Test with zero recording duration.""" @@ -1493,7 +1385,7 @@ def setup_method(self): """Set up a DataProcessor instance for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock the function to return DataFrames with expected columns @@ -1535,9 +1427,7 @@ def test_recording_duration_set_with_data(self): def test_duration_column_in_predictions(self): """Test when 'Duration' column is in predictions DataFrame.""" self.dp.recording_duration = None - pred_df = pd.DataFrame( - {"Start Time": [0], "End Time": [50], "Duration": [100.0]} - ) + pred_df = pd.DataFrame({"Start Time": [0], "End Time": [50], "Duration": [100.0]}) annot_df = pd.DataFrame() duration = self.dp.determine_file_duration(pred_df, annot_df) assert duration == 100.0 @@ -1546,9 +1436,7 @@ def test_duration_column_in_annotations(self): """Test when 'Duration' column is in annotations DataFrame.""" self.dp.recording_duration = None pred_df = pd.DataFrame() - annot_df = pd.DataFrame( - {"Start Time": [0], "End Time": [50], "Duration": [90.0]} - ) + annot_df = pd.DataFrame({"Start Time": [0], "End Time": [50], "Duration": [90.0]}) duration = self.dp.determine_file_duration(pred_df, annot_df) assert duration == 90.0 @@ -1571,12 +1459,8 @@ def test_empty_dataframes(self): def test_null_duration_columns(self): """Test when 'Duration' columns are present but all null.""" self.dp.recording_duration = None - pred_df = pd.DataFrame( - {"Start Time": [10], "End Time": [20], "Duration": [None]} - ) - annot_df = pd.DataFrame( - {"Start Time": [30], "End Time": [40], "Duration": 
[None]} - ) + pred_df = pd.DataFrame({"Start Time": [10], "End Time": [20], "Duration": [None]}) + annot_df = pd.DataFrame({"Start Time": [30], "End Time": [40], "Duration": [None]}) duration = self.dp.determine_file_duration(pred_df, annot_df) assert duration == 40.0 @@ -1600,12 +1484,8 @@ def test_missing_end_time_column(self): def test_mixed_null_and_non_null_duration_columns(self): """Test with mixed null and non-null 'Duration' values.""" self.dp.recording_duration = None - pred_df = pd.DataFrame( - {"Start Time": [10], "End Time": [20], "Duration": [None]} - ) - annot_df = pd.DataFrame( - {"Start Time": [30], "End Time": [40], "Duration": [50.0]} - ) + pred_df = pd.DataFrame({"Start Time": [10], "End Time": [20], "Duration": [None]}) + annot_df = pd.DataFrame({"Start Time": [30], "End Time": [40], "Duration": [50.0]}) duration = self.dp.determine_file_duration(pred_df, annot_df) assert duration == 50.0 @@ -1615,7 +1495,7 @@ def setup_method(self): """Set up a DataProcessor instance for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock the function to return DataFrames with expected columns @@ -1686,8 +1566,7 @@ def test_multiple_classes(self): self.dp.classes = ("A", "B") samples_df = self.dp.initialize_samples("rec1", 5) assert all( - col in samples_df.columns - for col in ["A_confidence", "B_confidence", "A_annotation", "B_annotation"] + col in samples_df.columns for col in ["A_confidence", "B_confidence", "A_annotation", "B_annotation"] ) def test_empty_recording_filename(self): @@ -1719,7 +1598,7 @@ def setup_method(self): """Set up a DataProcessor instance and samples DataFrame for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + 
"birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock the function to return DataFrames with expected columns @@ -1763,45 +1642,38 @@ def teardown_method(self): def test_single_prediction_overlapping_one_sample(self): """Test single prediction overlapping one sample.""" - pred_df = pd.DataFrame( - {"Class": ["A"], "Start Time": [1], "End Time": [4], "Confidence": [0.8]} - ) + target_confidence = 0.6 + pred_df = pd.DataFrame({"Class": ["A"], "Start Time": [1], "End Time": [4], "Confidence": [target_confidence]}) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.8 + assert self.samples_df.loc[0, "A_confidence"] == target_confidence def test_single_prediction_overlapping_multiple_samples(self): """Test single prediction overlapping multiple samples.""" - pred_df = pd.DataFrame( - {"Class": ["A"], "Start Time": [3], "End Time": [8], "Confidence": [0.6]} - ) + target_confidence = 0.6 + pred_df = pd.DataFrame({"Class": ["A"], "Start Time": [3], "End Time": [8], "Confidence": [target_confidence]}) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.6 - assert self.samples_df.loc[1, "A_confidence"] == 0.6 + assert self.samples_df.loc[0, "A_confidence"] == target_confidence + assert self.samples_df.loc[1, "A_confidence"] == target_confidence def test_multiple_predictions_same_sample(self): """Test multiple predictions overlapping the same sample for the same class.""" + target_confidence = 0.7 pred_df = pd.DataFrame( { "Class": ["A", "A"], "Start Time": [1, 1], "End Time": [4, 4], - "Confidence": [0.5, 0.7], + "Confidence": [0.5, target_confidence], } ) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.7 # Max confidence + assert self.samples_df.loc[0, "A_confidence"] == 
target_confidence # Max confidence def test_predictions_classes_not_in_self_classes(self): """Test predictions with classes not in self.classes.""" - pred_df = pd.DataFrame( - {"Class": ["D"], "Start Time": [0], "End Time": [5], "Confidence": [0.9]} - ) + pred_df = pd.DataFrame({"Class": ["D"], "Start Time": [0], "End Time": [5], "Confidence": [0.9]}) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert ( - (self.samples_df[["A_confidence", "B_confidence", "C_confidence"]] == 0.0) - .all() - .all() - ) + assert (self.samples_df[["A_confidence", "B_confidence", "C_confidence"]] == 0.0).all().all() def test_predictions_missing_confidence(self): """Test predictions missing 'Confidence' column.""" @@ -1818,28 +1690,26 @@ def test_predictions_missing_confidence(self): def test_predictions_with_negative_times(self): """Test predictions with negative times.""" - pred_df = pd.DataFrame( - {"Class": ["A"], "Start Time": [-3], "End Time": [2], "Confidence": [0.5]} - ) + target_confidence = 0.5 + pred_df = pd.DataFrame({"Class": ["A"], "Start Time": [-3], "End Time": [2], "Confidence": [target_confidence]}) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.5 + assert self.samples_df.loc[0, "A_confidence"] == target_confidence def test_predictions_no_overlap(self): """Test predictions that do not overlap any samples.""" - pred_df = pd.DataFrame( - {"Class": ["A"], "Start Time": [15], "End Time": [20], "Confidence": [0.9]} - ) + pred_df = pd.DataFrame({"Class": ["A"], "Start Time": [15], "End Time": [20], "Confidence": [0.9]}) self.dp.update_samples_with_predictions(pred_df, self.samples_df) assert (self.samples_df["A_confidence"] == 0.0).all() def test_predictions_with_different_min_overlap(self): """Test predictions with different min_overlap values.""" + target_confidence = 0.8 pred_df = pd.DataFrame( { "Class": ["A"], "Start Time": [4.6], "End Time": [5.1], - "Confidence": [0.8], + 
"Confidence": [target_confidence], } ) # With min_overlap 0.5, should not overlap @@ -1850,48 +1720,46 @@ def test_predictions_with_different_min_overlap(self): # With min_overlap 0.0, should overlap self.dp.min_overlap = 0.0 self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.8 + assert self.samples_df.loc[0, "A_confidence"] == target_confidence def test_predictions_overlapping_different_classes(self): """Test predictions overlapping different classes.""" + target_confidence_1 = 0.7 + target_confidence_2 = 0.9 pred_df = pd.DataFrame( { "Class": ["A", "B"], "Start Time": [1, 6], "End Time": [4, 9], - "Confidence": [0.7, 0.9], + "Confidence": [target_confidence_1, target_confidence_2], } ) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.7 - assert self.samples_df.loc[1, "B_confidence"] == 0.9 + assert self.samples_df.loc[0, "A_confidence"] == target_confidence_1 + assert self.samples_df.loc[1, "B_confidence"] == target_confidence_2 def test_multiple_predictions_overlapping_multiple_samples(self): """Test multiple predictions overlapping multiple samples.""" + target_confidence_1 = 0.5 + target_confidence_2 = 0.6 pred_df = pd.DataFrame( { "Class": ["A", "A"], "Start Time": [2, 7], "End Time": [6, 12], - "Confidence": [0.5, 0.6], + "Confidence": [target_confidence_1, target_confidence_2], } ) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert self.samples_df.loc[0, "A_confidence"] == 0.5 - assert self.samples_df.loc[1, "A_confidence"] == 0.6 - assert self.samples_df.loc[2, "A_confidence"] == 0.6 + assert self.samples_df.loc[0, "A_confidence"] == target_confidence_1 + assert self.samples_df.loc[1, "A_confidence"] == target_confidence_2 + assert self.samples_df.loc[2, "A_confidence"] == target_confidence_2 def test_empty_predictions_dataframe(self): """Test when pred_df is empty.""" - pred_df = pd.DataFrame( - 
columns=["Class", "Start Time", "End Time", "Confidence"] - ) + pred_df = pd.DataFrame(columns=["Class", "Start Time", "End Time", "Confidence"]) self.dp.update_samples_with_predictions(pred_df, self.samples_df) - assert ( - (self.samples_df[["A_confidence", "B_confidence", "C_confidence"]] == 0.0) - .all() - .all() - ) + assert (self.samples_df[["A_confidence", "B_confidence", "C_confidence"]] == 0.0).all().all() class TestUpdateSamplesWithAnnotations: @@ -1899,7 +1767,7 @@ def setup_method(self): """Set up a DataProcessor instance and samples DataFrame for testing.""" # Start patching self.patcher = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory" + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory" ) self.mock_read_concat = self.patcher.start() # Mock the function to return DataFrames with expected columns @@ -1958,11 +1826,7 @@ def test_annotations_classes_not_in_self_classes(self): """Test annotations with classes not in self.classes.""" annot_df = pd.DataFrame({"Class": ["D"], "Start Time": [0], "End Time": [5]}) self.dp.update_samples_with_annotations(annot_df, self.samples_df) - assert ( - (self.samples_df[["A_annotation", "B_annotation", "C_annotation"]] == 0) - .all() - .all() - ) + assert (self.samples_df[["A_annotation", "B_annotation", "C_annotation"]] == 0).all().all() def test_annotations_with_negative_times(self): """Test annotations with negative times.""" @@ -1978,9 +1842,7 @@ def test_annotations_no_overlap(self): def test_annotations_with_different_min_overlap(self): """Test annotations with different min_overlap values.""" - annot_df = pd.DataFrame( - {"Class": ["A"], "Start Time": [4.6], "End Time": [5.1]} - ) + annot_df = pd.DataFrame({"Class": ["A"], "Start Time": [4.6], "End Time": [5.1]}) # With min_overlap 0.5, should not overlap self.dp.min_overlap = 0.5 self.dp.update_samples_with_annotations(annot_df, self.samples_df) @@ -1993,17 +1855,13 @@ def 
test_annotations_with_different_min_overlap(self): def test_multiple_annotations_same_sample(self): """Test multiple annotations overlapping the same sample for the same class.""" - annot_df = pd.DataFrame( - {"Class": ["A", "A"], "Start Time": [1, 2], "End Time": [4, 5]} - ) + annot_df = pd.DataFrame({"Class": ["A", "A"], "Start Time": [1, 2], "End Time": [4, 5]}) self.dp.update_samples_with_annotations(annot_df, self.samples_df) assert self.samples_df.loc[0, "A_annotation"] == 1 # Should be set to 1 def test_multiple_annotations_overlapping_multiple_samples(self): """Test multiple annotations overlapping multiple samples.""" - annot_df = pd.DataFrame( - {"Class": ["A", "A"], "Start Time": [2, 7], "End Time": [6, 12]} - ) + annot_df = pd.DataFrame({"Class": ["A", "A"], "Start Time": [2, 7], "End Time": [6, 12]}) self.dp.update_samples_with_annotations(annot_df, self.samples_df) assert self.samples_df.loc[0, "A_annotation"] == 1 assert self.samples_df.loc[1, "A_annotation"] == 1 @@ -2013,17 +1871,11 @@ def test_empty_annotations_dataframe(self): """Test when annot_df is empty.""" annot_df = pd.DataFrame(columns=["Class", "Start Time", "End Time"]) self.dp.update_samples_with_annotations(annot_df, self.samples_df) - assert ( - (self.samples_df[["A_annotation", "B_annotation", "C_annotation"]] == 0) - .all() - .all() - ) + assert (self.samples_df[["A_annotation", "B_annotation", "C_annotation"]] == 0).all().all() def test_annotations_overlapping_different_classes(self): """Test annotations overlapping different classes.""" - annot_df = pd.DataFrame( - {"Class": ["A", "B"], "Start Time": [1, 6], "End Time": [4, 9]} - ) + annot_df = pd.DataFrame({"Class": ["A", "B"], "Start Time": [1, 6], "End Time": [4, 9]}) self.dp.update_samples_with_annotations(annot_df, self.samples_df) assert self.samples_df.loc[0, "A_annotation"] == 1 assert self.samples_df.loc[1, "B_annotation"] == 1 @@ -2034,7 +1886,7 @@ def setup_method(self): """Set up a DataProcessor instance for 
testing.""" # Mock the file reading functions to prevent actual file I/O self.patcher_pred = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", return_value=pd.DataFrame( { "Class": [], # Required column @@ -2049,7 +1901,7 @@ def setup_method(self): self.mock_read_concat_pred = self.patcher_pred.start() self.patcher_annot = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", return_value=pd.DataFrame( { "Class": [], # Required column @@ -2067,6 +1919,8 @@ def setup_method(self): annotation_directory_path="dummy_path", ) + self.rng = np.random.default_rng(seed=42) # For reproducibility + def teardown_method(self): """Stop patching.""" self.patcher_pred.stop() @@ -2092,7 +1946,7 @@ def test_single_sample_single_class(self): self.dp.create_tensors() assert self.dp.prediction_tensors.shape == (1, 1) assert self.dp.label_tensors.shape == (1, 1) - assert self.dp.prediction_tensors[0, 0] == 0.8 + np.testing.assert_almost_equal(self.dp.prediction_tensors[0, 0], 0.8) assert self.dp.label_tensors[0, 0] == 1 def test_multiple_samples_multiple_classes(self): @@ -2177,7 +2031,7 @@ def test_non_numeric_confidence(self): "A_annotation": [1], } ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="could not convert string to float: 'high'"): self.dp.create_tensors() def test_large_number_of_samples_and_classes(self): @@ -2188,8 +2042,8 @@ def test_large_number_of_samples_and_classes(self): self.dp.classes = classes data = {} for cls in classes: - data[f"{cls}_confidence"] = np.random.rand(num_samples) - data[f"{cls}_annotation"] = np.random.randint(0, 2, size=num_samples) + data[f"{cls}_confidence"] = self.rng.random(num_samples) + data[f"{cls}_annotation"] = self.rng.integers(0, 2, 
size=num_samples) self.dp.samples_df = pd.DataFrame(data) self.dp.create_tensors() assert self.dp.prediction_tensors.shape == (num_samples, num_classes) @@ -2201,7 +2055,7 @@ def setup_method(self): """Set up a DataProcessor instance for testing.""" # Mock the file reading functions to prevent actual file I/O self.patcher_pred = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", return_value=pd.DataFrame( { "Class": [], # Required column @@ -2216,7 +2070,7 @@ def setup_method(self): self.mock_read_concat_pred = self.patcher_pred.start() self.patcher_annot = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", return_value=pd.DataFrame( { "Class": [], # Required column @@ -2297,7 +2151,7 @@ def setup_method(self): """Set up a DataProcessor instance for testing.""" # Mock the file reading functions to prevent actual file I/O self.patcher_pred = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", return_value=pd.DataFrame( { "Class": [], # Required column @@ -2312,7 +2166,7 @@ def setup_method(self): self.mock_read_concat_pred = self.patcher_pred.start() self.patcher_annot = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", return_value=pd.DataFrame( { "Class": [], # Required column @@ -2330,6 +2184,8 @@ def setup_method(self): annotation_directory_path="dummy_path", ) + self.rng = np.random.default_rng(seed=42) # For reproducibility + def teardown_method(self): """Stop patching.""" self.patcher_pred.stop() @@ -2357,10 +2213,11 
@@ def test_samples_df_with_data(self): def test_modifying_returned_df_does_not_affect_samples_df(self): """Test that modifying returned DataFrame does not affect samples_df.""" - self.dp.samples_df = pd.DataFrame({"A_confidence": [0.8]}) + target_value = 0.8 + self.dp.samples_df = pd.DataFrame({"A_confidence": [target_value]}) sample_data = self.dp.get_sample_data() sample_data["A_confidence"] = [0.5] - assert self.dp.samples_df["A_confidence"][0] == 0.8 + assert self.dp.samples_df["A_confidence"][0] == target_value def test_samples_df_with_nan_values(self): """Test when samples_df contains NaN values.""" @@ -2378,15 +2235,16 @@ def test_samples_df_columns(self): def test_samples_df_large_data(self): """Test with a large samples_df.""" num_samples = 1000 - self.dp.samples_df = pd.DataFrame({"A_confidence": np.random.rand(num_samples)}) + self.dp.samples_df = pd.DataFrame({"A_confidence": self.rng.random(num_samples)}) sample_data = self.dp.get_sample_data() pd.testing.assert_frame_equal(sample_data, self.dp.samples_df) def test_samples_df_with_custom_index(self): """Test that index is preserved.""" - self.dp.samples_df = pd.DataFrame({"A_confidence": [0.8]}, index=[10]) + target_index = 10 + self.dp.samples_df = pd.DataFrame({"A_confidence": [0.8]}, index=[target_index]) sample_data = self.dp.get_sample_data() - assert sample_data.index[0] == 10 + assert sample_data.index[0] == target_index def test_samples_df_with_different_dtypes(self): """Test that data types are preserved.""" @@ -2403,10 +2261,11 @@ def test_samples_df_with_different_dtypes(self): def test_modifications_after_get_sample_data(self): """Test that modifications to samples_df after get_sample_data do not affect returned DataFrame.""" - self.dp.samples_df = pd.DataFrame({"A_confidence": [0.8]}) + target_value = 0.8 + self.dp.samples_df = pd.DataFrame({"A_confidence": [target_value]}) sample_data = self.dp.get_sample_data() self.dp.samples_df["A_confidence"] = [0.5] - assert 
sample_data["A_confidence"][0] == 0.8 + assert sample_data["A_confidence"][0] == target_value def test_samples_df_with_multiindex(self): """Test when samples_df has a MultiIndex.""" @@ -2421,7 +2280,7 @@ def setup_method(self): """Set up a DataProcessor instance for testing.""" # Mock the file reading functions to prevent actual file I/O self.patcher_pred = patch( - "bapat.preprocessing.data_processor.read_and_concatenate_files_in_directory", + "birdnet_analyzer.evaluation.preprocessing.data_processor.read_and_concatenate_files_in_directory", side_effect=[ pd.DataFrame( { @@ -2466,6 +2325,8 @@ def setup_method(self): # Create tensors for the DataProcessor self.dp.create_tensors() + self.rng = np.random.default_rng(123) + def teardown_method(self): """Stop patching.""" self.patcher_pred.stop() @@ -2481,10 +2342,8 @@ def test_valid_classes_and_recordings(self): def test_selected_classes_not_in_data(self): """Test when selected classes are not in data.""" - with pytest.raises(ValueError): - self.dp.get_filtered_tensors( - selected_classes=["C"], selected_recordings=["rec1"] - ) + with pytest.raises(ValueError, match="No valid classes selected."): + self.dp.get_filtered_tensors(selected_classes=["C"], selected_recordings=["rec1"]) def test_selected_recordings_not_in_data(self): """Test when selected recordings are not in data.""" @@ -2497,16 +2356,12 @@ def test_selected_recordings_not_in_data(self): def test_empty_selected_classes(self): """Test when selected_classes is empty.""" - with pytest.raises(ValueError): - self.dp.get_filtered_tensors( - selected_classes=[], selected_recordings=["rec1"] - ) + with pytest.raises(ValueError, match="No valid classes selected."): + self.dp.get_filtered_tensors(selected_classes=[], selected_recordings=["rec1"]) def test_empty_selected_recordings(self): """Test when selected_recordings is empty.""" - predictions, labels, classes = self.dp.get_filtered_tensors( - selected_classes=["A"], selected_recordings=[] - ) + predictions, 
labels, classes = self.dp.get_filtered_tensors(selected_classes=["A"], selected_recordings=[]) assert predictions.shape == (0, 1) assert labels.shape == (0, 1) assert classes == ("A",) @@ -2515,9 +2370,7 @@ def test_samples_df_is_empty(self): """Test when samples_df is empty.""" self.dp.samples_df = pd.DataFrame() with pytest.raises(ValueError, match="samples_df is empty."): - self.dp.get_filtered_tensors( - selected_classes=["A"], selected_recordings=["rec1"] - ) + self.dp.get_filtered_tensors(selected_classes=["A"], selected_recordings=["rec1"]) def test_missing_confidence_or_annotation_columns(self): """Test when required columns are missing.""" @@ -2529,9 +2382,7 @@ def test_missing_confidence_or_annotation_columns(self): } ) with pytest.raises(KeyError): - self.dp.get_filtered_tensors( - selected_classes=["A"], selected_recordings=["rec1"] - ) + self.dp.get_filtered_tensors(selected_classes=["A"], selected_recordings=["rec1"]) def test_nan_values_in_data(self): """Test when data contains NaN values.""" @@ -2564,8 +2415,8 @@ def test_large_data(self): self.dp.samples_df = pd.DataFrame( { "filename": ["rec1"] * num_samples, - "A_confidence": np.random.rand(num_samples), - "A_annotation": np.random.randint(0, 2, size=num_samples), + "A_confidence": self.rng.random(num_samples), + "A_annotation": self.rng.integers(0, 2, size=num_samples), } ) predictions, labels, classes = self.dp.get_filtered_tensors( diff --git a/tests/preprocessing/test_utils.py b/tests/evaluation/preprocessing/test_utils.py similarity index 96% rename from tests/preprocessing/test_utils.py rename to tests/evaluation/preprocessing/test_utils.py index 9e890b9f..075b4dff 100644 --- a/tests/preprocessing/test_utils.py +++ b/tests/evaluation/preprocessing/test_utils.py @@ -39,9 +39,7 @@ def test_extract_recording_filename_multiple_dots(): Ensures that the function correctly extracts the base filename when there are multiple dots in the filename. 
""" - input_series = pd.Series( - ["/path/to/file.name.ext", "/path/to/another.file.name.ext"] - ) + input_series = pd.Series(["/path/to/file.name.ext", "/path/to/another.file.name.ext"]) expected_output = pd.Series(["file.name", "another.file.name"]) output_series = extract_recording_filename(input_series) pd.testing.assert_series_equal(output_series, expected_output) @@ -298,7 +296,7 @@ def test_read_and_concatenate_files_different_structures(tmp_path): df1.to_csv(tmp_path / "file1.txt", sep="\t", index=False) df2.to_csv(tmp_path / "file2.txt", sep="\t", index=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="File file2.txt has different columns than the previous files."): read_and_concatenate_files_in_directory(str(tmp_path)) @@ -316,9 +314,7 @@ def test_read_and_concatenate_files_ignores_non_txt(tmp_path): result_df = read_and_concatenate_files_in_directory(str(tmp_path)) expected_df = df_txt.assign(source_file="file1.txt") - pd.testing.assert_frame_equal( - result_df.reset_index(drop=True), expected_df.reset_index(drop=True) - ) + pd.testing.assert_frame_equal(result_df.reset_index(drop=True), expected_df.reset_index(drop=True)) def test_read_and_concatenate_files_nonexistent_directory(): @@ -354,9 +350,7 @@ def test_read_and_concatenate_files_large_files(tmp_path): result_df = read_and_concatenate_files_in_directory(str(tmp_path)) expected_df = df_large.assign(source_file="large_file.txt") - pd.testing.assert_frame_equal( - result_df.reset_index(drop=True), expected_df.reset_index(drop=True) - ) + pd.testing.assert_frame_equal(result_df.reset_index(drop=True), expected_df.reset_index(drop=True)) def test_read_and_concatenate_files_invalid_path(): @@ -380,9 +374,7 @@ def test_read_and_concatenate_files_different_encodings(tmp_path): # Write files with utf-8 encoding df_utf8.to_csv(tmp_path / "utf8_file.txt", sep="\t", index=False, encoding="utf-8") - df_ascii.to_csv( - tmp_path / "ascii_file.txt", sep="\t", index=False, 
encoding="utf-8" - ) + df_ascii.to_csv(tmp_path / "ascii_file.txt", sep="\t", index=False, encoding="utf-8") # Call the function to read and concatenate result_df = read_and_concatenate_files_in_directory(str(tmp_path))