diff --git a/src/osekit/core/annotation.py b/src/osekit/core/annotation.py new file mode 100644 index 00000000..d96e26ea --- /dev/null +++ b/src/osekit/core/annotation.py @@ -0,0 +1,409 @@ +"""The Annotation class represents an annotation made on APLOSE.""" + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, Self + +import pandas as pd +from matplotlib.patches import Rectangle +from pandas import Timestamp + +from osekit.core.event import Event + +KNOWN_KEYS = { + "dataset", + "project", + "filename", + "annotation_id", + "is_update_of_id", + "start_time", + "end_time", + "start_frequency", + "end_frequency", + "min_frequency", + "max_frequency", + "annotation", + "annotator", + "annotator_expertise", + "start_datetime", + "end_datetime", + "is_box", + "type", + "confidence_indicator_label", + "confidence_indicator_level", + "comments", + "signal_quantity", + "signal_is_intensity_too_low", + "signal_does_overlap_other_signals", + "signal_start_frequency", + "signal_end_frequency", + "signal_relative_min_frequency_count", + "signal_relative_max_frequency_count", + "signal_steps_count", + "signal_has_harmonics", + "signal_trend", + "signal_sidebands", + "signal_subharmonics", + "signal_frequency_jumps", + "signal_deterministic_chaos", + "created_at_phase", +} + + +@dataclass +class FrequencyBounds: + """Class representing the frequency bounds of an annotation. + + Parameters + ---------- + min: int + Lower frequency bound. + max: int + Upper frequency bound. + + """ + + min: int + max: int + + def __post_init__(self) -> None: + """Check the validity of the frequency bounds.""" + error_msgs = [] + if self.min < 0: + error_msgs.append( + f"Min frequency must be greater than or equal to 0, got {self.min}.", + ) + if self.max < 0: + error_msgs.append( + f"Max frequency must be greater than or equal to 0, got {self.max}.", + ) + if self.min > self.max: + error_msgs.append( + f"Max frequency must be greater than min frequency, " + f"got ({self.min},{self.max}).", + ) + if error_msgs: + msg = "\n".join(error_msgs) + raise ValueError(msg) + + @property + def bandwidth(self) -> int: + """Bandwidth of the annotation.""" + return self.max - self.min + + +@dataclass +class AnnotatorInfo: + """Class representing an annotator info.""" + + name: str + expertise: Literal["NOVICE", "AVERAGE", "EXPERT"] | None = None + + def __hash__(self) -> int: + """Return a hash for the annotator.""" + return hash((self.name, self.expertise)) + + def __eq__(self, other: Self) -> bool: + """Return whether two annotators are equal.""" + return self.name == other.name and self.expertise == other.expertise + + +@dataclass +class SignalParameters: + """Class representing parameters of an annoted signal.""" + + is_itensity_too_low: bool | None = None + does_overlap_other_signals: bool | None = None + min_frequency: int | None = None + max_frequency: int | None = None + nb_relative_mins: int | None = None + nb_relative_maxes: int | None = None + nb_steps: int | None = None + trend: Literal["FLAT", "ASCENDING", "DESCENDING", "MODULATED"] | None = None + frequency_jumps: bool | int | None = None + has_harmonics: bool | None = None + has_sidebands: bool | None = None + has_subharmonics: bool | None = None + has_deterministic_chaos: bool | None = None + + +@dataclass +class ConfidenceIndicator: + """Class that represents an annotation confidence indicator. + + Parameters + ---------- + label: str + Name of the level of confidence. + level: int + Level of confidence of the annotation. + maximum_level: int + Maximum level of confidence authorized in the project. + + """ + + label: str + level: int + maximum_level: int + + def __post_init__(self) -> None: + """Check the validity of the level and maximum level values.""" + if self.level > self.maximum_level: + msg = ( + f"Confidence level {self.level} is higher than " + f"maximum level {self.maximum_level} authorized in the project." + ) + raise ValueError(msg) + + @classmethod + def from_relative_level_string(cls, label: str, relative_level_string: str) -> Self: + """Return a ``ConfidenceIndicator`` from a string representing its level. + + Parameters + ---------- + label: str + Name of the level of confidence. + relative_level_string: str + Level of confidence relative to the maximum level available. + Should be formatted as ``n/m``, where ``n`` is the level of confidence + of the annotation and ``m`` is the maximum level available in the project. + + Returns + ------- + ConfidenceIndicator + The confidence indicator parsed from the input string. + + """ + level, maximum_level = map(int, relative_level_string.split("/")) + + return cls(label=label, level=level, maximum_level=maximum_level) + + +@dataclass +class AnnotationMetaData: + """Class that represents the metadata of an annotation. + + Parameters + ---------- + project: str + Name of the project in which the annotation was made. + filename: str + Name of the file this annotation was made on. + annotation_id: int + ID of the annotation. + base_id: int + ID of the base annotation. + May differ from ``annotation_id`` if the annotation is an update/correction. + comments: str | None + Comments left by the annotator. + phase: Literal["ANNOTATION", "VERIFICATION"] + Phase during which the annotation was created. + + """ + + project: str + filename: str + annotation_id: int + base_id: int | None + comments: str | None + phase: Literal["ANNOTATION", "VERIFICATION"] + + +@dataclass +class Verification: + """Class that represents a verification of an annotation.""" + + verificator: str + is_validated: bool + + def __hash__(self) -> int: + """Return a hash of the verification.""" + return hash((self.verificator, self.is_validated)) + + def __eq__(self, other: Self) -> bool: + """Return whether the two verifications are equal.""" + return ( + self.verificator == other.verificator + and self.is_validated == other.is_validated + ) + + +class Annotation(Event): + """Class that represents an annotation made on APLOSE.""" + + def __init__( # noqa: PLR0913 + self, + metadata: AnnotationMetaData, + begin: Timestamp, + end: Timestamp, + frequency_bounds: FrequencyBounds, + label: str, + annotator_info: AnnotatorInfo, + annotation_type: Literal["WEAK", "POINT", "BOX"], + confidence_indicator: ConfidenceIndicator, + signal_quantity: Literal["SINGLE", "MULTIPLE"], + signal_parameters: SignalParameters | None, + verifications: set[Verification], + ) -> None: + """Initialize an Annotation object. + + Parameters + ---------- + metadata: AnnotationMetaData + Metadata on the annotation. + begin: Timestamp + Begin timestamp of the annotation. + end: Timestamp + End timestamp of the annotation. + frequency_bounds: FrequencyBounds + Frequency bounds of the annotation. + label: str + Label of the annotation. + annotator_info: AnnotatorInfo + Information on the annotator or detector. + annotation_type: Literal["WEAK", "POINT", "BOX"] + Type of the annotation. + ``WEAK``: Annotation made on the whole spectrogram. + ``POINT``: Annotation made on one pixel of the spectrogram. + ``BOX``: Annotation made on one box within the spectrogram. + confidence_indicator: ConfidenceIndicator + Indicator of the confidence of the annotator. + signal_quantity: Literal["SINGLE","MULTIPLE"] + Whether there is only one signal in the annotation or more. + signal_parameters: SignalParameters | None + Parameters of the annotated signal. + ```None`` if ``signal_quantity`` is ``MULTIPLE``. + verifications: set[Verification] + Verifications made on this annotation. + + """ + self.metadata = metadata + self.label = label + self.annotator_info = annotator_info + self.frequency_bounds = frequency_bounds + self.type = annotation_type + self.confidence_indicator = confidence_indicator + self.signal_quantity = signal_quantity + self.signal_parameters = signal_parameters + self.verifications = verifications + + super().__init__(begin=begin, end=end) + + def __repr__(self) -> str: + """Override the string representation of the annotation.""" + return str(self.metadata.annotation_id) + + @classmethod + def from_dict(cls, row: dict) -> Self: + """Deserialize an Annotation object.""" + metadata = AnnotationMetaData( + project=row["project"] if "project" in row else row["dataset"], + filename=str(row["filename"]), + annotation_id=row["annotation_id"], + base_id=row["is_update_of_id"], + comments=row["comments"], + phase=row["created_at_phase"], + ) + annotator_info = AnnotatorInfo( + name=row["annotator"], + expertise=row["annotator_expertise"], + ) + + min_frequency, max_frequency = row["min_frequency"], row["max_frequency"] + frequency_bounds = ( + FrequencyBounds(min=min_frequency, max=max_frequency) + if not any(m is None for m in (min_frequency, max_frequency)) + else None + ) + + confidence_indicator = ConfidenceIndicator.from_relative_level_string( + label=row["confidence_indicator_label"], + relative_level_string=row["confidence_indicator_level"], + ) + + signal_quantity = row["signal_quantity"] + signal_parameters = ( + SignalParameters( + does_overlap_other_signals=row["signal_is_intensity_too_low"], + frequency_jumps=row["signal_frequency_jumps"], + has_deterministic_chaos=row["signal_deterministic_chaos"], + has_harmonics=row["signal_has_harmonics"], + has_sidebands=row["signal_sidebands"], + has_subharmonics=row["signal_subharmonics"], + is_itensity_too_low=row["signal_is_intensity_too_low"], + max_frequency=row["signal_end_frequency"], + min_frequency=row["signal_start_frequency"], + nb_relative_maxes=row["signal_relative_max_frequency_count"], + nb_relative_mins=row["signal_relative_min_frequency_count"], + nb_steps=row["signal_steps_count"], + trend=row["signal_trend"], + ) + if signal_quantity == "SINGLE" + else None + ) + + verifications = { + Verification( + verificator=key, + is_validated=value, + ) + for key, value in row.items() + if key not in KNOWN_KEYS + } + + return cls( + metadata=metadata, + label=row["annotation"], + annotator_info=annotator_info, + begin=Timestamp(row["start_datetime"]), + end=Timestamp(row["end_datetime"]), + frequency_bounds=frequency_bounds, + annotation_type=row["type"], + confidence_indicator=confidence_indicator, + signal_quantity=row["signal_quantity"], + signal_parameters=signal_parameters, + verifications=verifications, + ) + + def to_rectangle(self, **kwargs: Any) -> Rectangle: + """Return a matplotlib Rectangle representing the annotation. + + Parameters + ---------- + kwargs: + Additional keyword arguments + + Returns + ------- + matplotlib.patches.Rectangle + Rectangle representing the annotation. + The coordinates of the rectangle are in time x frequency. + + + + """ + return Rectangle( + xy=( # type: ignore[arg-type] + self.begin, + self.frequency_bounds.min, + ), + width=self.duration, # type: ignore[arg-type] + height=self.frequency_bounds.bandwidth, + **kwargs, + ) + + @classmethod + def from_csv(cls, csv: Path) -> list[Self]: + """Deserialize a list of Annotation from an annotations csv file.""" + records = pd.read_csv(filepath_or_buffer=csv).to_dict( + orient="records", + ) + records = [ + { + key: None if type(value) is float and math.isnan(value) else value + for key, value in record.items() + } + for record in records + ] + return [cls.from_dict(record) for record in records] diff --git a/src/osekit/core/event.py b/src/osekit/core/event.py index 575599df..60e1d5e3 100644 --- a/src/osekit/core/event.py +++ b/src/osekit/core/event.py @@ -7,6 +7,8 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, TypeVar +from osekit.utils.timestamp import localize_timestamp + if TYPE_CHECKING: from pandas import Timedelta, Timestamp @@ -63,6 +65,23 @@ def __repr__(self) -> str: """Overwrite repr.""" return f"{self.begin} - {self.end}" + def localize(self, timezone: str | None) -> None: + """Localize the event begin and end in a timezone. + + If the event is already tz-aware, it will be converted + to the target timezone. + + Parameters + ---------- + timezone: str | None + Target timezone + + """ + # We use the private fields here because we can't compare + # naive and aware timestamps in the begin and end setters + self._begin = localize_timestamp(timestamp=self._begin, timezone=timezone) + self._end = localize_timestamp(timestamp=self._end, timezone=timezone) + def overlaps(self, other: type[Event] | Event) -> bool: """Return ``True`` if the other event shares time with the current event. diff --git a/src/osekit/core/spectro_data.py b/src/osekit/core/spectro_data.py index 127b122e..bf5d8c15 100644 --- a/src/osekit/core/spectro_data.py +++ b/src/osekit/core/spectro_data.py @@ -404,12 +404,12 @@ def plot( ax: plt.Axes | None = None, sx: np.ndarray | None = None, scale: Scale | None = None, - ) -> None: + ) -> plt.Axes: """Plot the spectrogram on a specific ``Axes``. Parameters ---------- - ax: plt.axes | None + ax: plt.Axes | None ``Axes`` on which the spectrogram should be plotted. Defaulted to ``osekit.utils.plot.get_default_axes()``. sx: np.ndarray | None @@ -417,6 +417,11 @@ def plot( scale: osekit.core.frequecy_scale.Scale Custom frequency scale to use for plotting the spectrogram. + Returns + ------- + plt.Axes + The ``Axes`` on which the spectrogram has been plotted. + """ ax = ax if ax is not None else get_default_axes() sx = self.get_value() if sx is None else sx @@ -439,6 +444,7 @@ def plot( interpolation="none", extent=(date2num(time[0]), date2num(time[-1]), freq[0], freq[-1]), ) + return ax def get_db_value(self, sx: np.ndarray | None = None) -> np.ndarray: """Return the ``Sx`` spectrum of the spectrogram expressed in ``dB``. diff --git a/src/osekit/utils/timestamp.py b/src/osekit/utils/timestamp.py index 9eb69ad4..560ee4fc 100644 --- a/src/osekit/utils/timestamp.py +++ b/src/osekit/utils/timestamp.py @@ -90,7 +90,7 @@ def normalize_datetime(datetime: tuple[str], template: str) -> tuple[str, str]: def localize_timestamp( timestamp: Timestamp, - timezone: str | pytz.timezone, + timezone: str | pytz.timezone | None, ) -> Timestamp: """Localize a timestamp in the given timezone. @@ -98,8 +98,9 @@ def localize_timestamp( ---------- timestamp: pandas.Timestamp The timestamp to localize. - timezone: str | pytz.timezone + timezone: str | pytz.timezone | None The timezone in which the timestamp is localized. + If None, the output timestamp is naive. Returns ------- @@ -109,7 +110,7 @@ def localize_timestamp( to the new timezone. """ - if not timestamp.tz: + if not timestamp.tz or timezone is None: return timestamp.tz_localize(timezone) if timestamp.utcoffset() != timestamp.tz_convert(timezone).utcoffset(): diff --git a/tests/_static/aplose_result.csv b/tests/_static/aplose_result.csv new file mode 100644 index 00000000..dd321bfe --- /dev/null +++ b/tests/_static/aplose_result.csv @@ -0,0 +1,9 @@ +dataset,filename,annotation_id,is_update_of_id,start_time,end_time,start_frequency,end_frequency,min_frequency,max_frequency,annotation,annotator,annotator_expertise,start_datetime,end_datetime,is_box,type,confidence_indicator_label,confidence_indicator_level,comments,signal_quantity,signal_is_intensity_too_low,signal_does_overlap_other_signals,signal_start_frequency,signal_end_frequency,signal_relative_min_frequency_count,signal_relative_max_frequency_count,signal_steps_count,signal_has_harmonics,signal_trend,signal_sidebands,signal_subharmonics,signal_frequency_jumps,signal_deterministic_chaos,created_at_phase,lookaftering,bunyan +great_tit,990694,586654,,0.0,20.0,0.0,24000.0,0.0,24000.0,bird,vashti,NOVICE,2021-01-01T00:00:00.000+00:00,2021-01-01T00:00:20.000+00:00,0,WEAK,Sure,1/1,great tits |- vashti,MULTIPLE,,,,,,,,,,,,,,ANNOTATION,True,False +great_tit,990694,586655,,1.412,3.651,2512.0,15661.0,2512.0,15661.0,bird,vashti,NOVICE,2021-01-01T00:00:01.412+00:00,2021-01-01T00:00:03.651+00:00,1,BOX,Not sure,0/1,fluffy-backed tit-babbler |- vashti,MULTIPLE,,,,,,,,,,,,,,ANNOTATION,False,False +great_tit,990694,586656,,0.0,20.0,0.0,24000.0,0.0,24000.0,rain,heartleap,,2021-01-01T00:00:00.000+00:00,2021-01-01T00:00:20.000+00:00,0,WEAK,Sure,1/1,,MULTIPLE,,,,,,,,,,,,,,ANNOTATION,True,True +great_tit,990694,586657,,3.53,4.71,11137.0,13997.0,11137.0,13997.0,rain,heartleap,,2021-01-01T00:00:03.530+00:00,2021-01-01T00:00:04.710+00:00,1,BOX,Sure,1/1,,SINGLE,,True,12000.0,13000.0,3,2,4,True,MOD,True,,True,True,ANNOTATION,True,True +great_tit,990694,586669,586655,1.412,3.651,2512.0,15660.0,2512.0,15660.0,bird,bunyan,EXPERT,2021-01-01T00:00:01.412+00:00,2021-01-01T00:00:03.651+00:00,1,BOX,Not sure,0/1,,MULTIPLE,,,,,,,,,,,,,,VERIFICATION,, +great_tit,990694,586710,586655,1.412,3.651,2512.0,15660.0,2512.0,15660.0,bird,lookaftering,EXPERT,2021-01-01T00:00:01.412+00:00,2021-01-01T00:00:03.651+00:00,1,BOX,Not sure,0/1,,MULTIPLE,,,,,,,,,,,,,,VERIFICATION,, +great_tit,994410,586671,,0.0,20.0,0.0,24000.0,0.0,24000.0,car,bunyan,EXPERT,2021-01-01T00:01:18.218+00:00,2021-01-01T00:01:38.218+00:00,0,WEAK,Sure,1/1,,MULTIPLE,,,,,,,,,,,,,,VERIFICATION,, +great_tit,994410,586672,,0.0,20.0,0.0,24000.0,0.0,24000.0,bird,bunyan,EXPERT,2021-01-01T00:01:18.218+00:00,2021-01-01T00:01:38.218+00:00,0,WEAK,Sure,1/1,,MULTIPLE,,,,,,,,,,,,,,VERIFICATION,, diff --git a/tests/test_annotation.py b/tests/test_annotation.py new file mode 100644 index 00000000..0460453a --- /dev/null +++ b/tests/test_annotation.py @@ -0,0 +1,331 @@ +from contextlib import AbstractContextManager, nullcontext +from pathlib import Path + +import numpy as np +import pytest +from pandas import Timestamp + +from osekit.core.annotation import ( + Annotation, + AnnotationMetaData, + AnnotatorInfo, + ConfidenceIndicator, + FrequencyBounds, + SignalParameters, + Verification, +) + + +@pytest.fixture +def sample_annotation() -> Annotation: + return Annotation( + metadata=AnnotationMetaData( + annotation_id=35173, + base_id=None, + comments="He's a sneaky, sneaky dog friend", + filename="its_teasy", + phase="ANNOTATION", + project="mockasin", + ), + begin=Timestamp("2013-11-05 00:00:00"), + end=Timestamp("2013-11-05 00:00:10"), + frequency_bounds=FrequencyBounds( + min=1_000, + max=3_000, + ), + label="Connan", + annotator_info=AnnotatorInfo( + name="Mockasin", + expertise="EXPERT", + ), + annotation_type="BOX", + confidence_indicator=ConfidenceIndicator( + label="Sure", + level=2, + maximum_level=2, + ), + signal_quantity="SINGLE", + signal_parameters=SignalParameters( + does_overlap_other_signals=False, + frequency_jumps=True, + has_deterministic_chaos=True, + has_harmonics=True, + has_sidebands=True, + has_subharmonics=False, + is_itensity_too_low=False, + max_frequency=2_800, + min_frequency=1_300, + nb_relative_maxes=2, + nb_relative_mins=3, + nb_steps=4, + trend="MODULATED", + ), + verifications={ + Verification( + verificator="soft_hair", + is_validated=True, + ), + }, + ) + + +@pytest.mark.parametrize( + ("min_frequency", "max_frequency", "expectation"), + [ + pytest.param( + 0, + 1000, + nullcontext(1000), + id="box_from_bottom", + ), + pytest.param( + 300, + 1000, + nullcontext(700), + id="box_bandwidth_from_higher_than_0", + ), + pytest.param( + -10, + 1000, + pytest.raises(ValueError, match=r"Min frequency.*-10"), + id="negative_min_frequency_raises", + ), + pytest.param( + 0, + -5, + pytest.raises(ValueError, match=r"Max frequency.*-5"), + id="negative_max_frequency_raises", + ), + pytest.param( + 80, + 50, + pytest.raises( + ValueError, + match=r"Max frequency.*greater.*min frequency.*\(80,50\)", + ), + id="min_greater_than_max_raises", + ), + pytest.param( + -20, + -30, + pytest.raises( + ValueError, + match=r"(?s)" # Activates the DOTALL mode: includes \n in regex .* + r"(?=.*Min frequency.*got -20)" + r"(?=.*Max frequency.*got -30)" + r"(?=.*Max frequency.*greater.*min frequency.*\(-20,-30\))", + ), + id="errors_concatenation", + ), + ], +) +def test_frequency_bounds( + min_frequency: int, + max_frequency: int, + expectation: AbstractContextManager, +) -> None: + with expectation as e: + frequency_bounds = FrequencyBounds(min=min_frequency, max=max_frequency) + assert frequency_bounds.bandwidth == e + + +def test_annotator_info() -> None: + annotators = [ + AnnotatorInfo(name="ruby", expertise="NOVICE"), + AnnotatorInfo(name="ruby", expertise="NOVICE"), + AnnotatorInfo(name="haunt", expertise="EXPERT"), + AnnotatorInfo(name="haunt", expertise="EXPERT"), + AnnotatorInfo(name="nevada", expertise="EXPERT"), + AnnotatorInfo(name="nevada", expertise="EXPERT"), + AnnotatorInfo(name="haunt", expertise=None), + ] + + nb_unique_annotators = 4 + + assert sum(1 for _ in set(annotators)) == nb_unique_annotators + + +@pytest.mark.parametrize( + ("label", "level", "max_level", "expectation"), + [ + pytest.param( + "Sure", + 1, + 1, + nullcontext(), + id="max_level_is_ok", + ), + pytest.param( + "Not sure", + 0, + 1, + nullcontext(), + id="level_0_is_ok", + ), + pytest.param( + "Moderate", + 1, + 2, + nullcontext(), + id="between_0_and_max_is_ok", + ), + pytest.param( + "Moderate", + 3, + 2, + pytest.raises(ValueError, match=r"level 3.*higher.*maximum level 2"), + id="higher_than_max_raises", + ), + ], +) +def test_confidence_indicator_value_check( + label: str, + level: int, + max_level: int, + expectation: AbstractContextManager, +) -> None: + with expectation: + ConfidenceIndicator( + label=label, + level=level, + maximum_level=max_level, + ) + + +@pytest.mark.parametrize( + ("label", "relative_level_string", "expectation"), + [ + pytest.param( + "cool", + "1/6", + nullcontext( + ConfidenceIndicator( + label="cool", + level=1, + maximum_level=6, + ), + ), + id="correct_levels", + ), + pytest.param( + "cool", + "4/2", + pytest.raises(ValueError, match=r"level 4.*higher.*maximum level 2"), + id="incorrect_levels_should_raise", + ), + ], +) +def test_confidence_indicator_from_relative_level_string( + label: str, + relative_level_string: str, + expectation: AbstractContextManager, +) -> None: + with expectation as e: + ci = ConfidenceIndicator.from_relative_level_string( + label=label, + relative_level_string=relative_level_string, + ) + + assert ci.label == e.label + assert ci.level == e.level + assert ci.maximum_level == e.maximum_level + + +def test_annotations_from_csv() -> None: + annotations = Annotation.from_csv( + csv=Path(__file__).parent / "_static" / "aplose_result.csv", + ) + + # All records should be loaded + assert len(annotations) == 8 + assert all(a.metadata.project == "great_tit" for a in annotations) + + # Two distinct annotated files + filenames = {a.metadata.filename for a in annotations} + assert filenames == {"990694", "994410"} + + # Types + types = {a.type for a in annotations} + assert types == {"WEAK", "BOX"} + + # Phases + phases = {a.metadata.phase for a in annotations} + assert phases == {"ANNOTATION", "VERIFICATION"} + + # Single signal parameters + single = next(a for a in annotations if a.metadata.annotation_id == 586657) + assert single.signal_quantity == "SINGLE" + assert single.signal_parameters is not None + assert not single.signal_parameters.is_itensity_too_low + assert not single.signal_parameters.does_overlap_other_signals + assert single.signal_parameters.min_frequency == 12000 + assert single.signal_parameters.max_frequency == 13000 + assert single.signal_parameters.nb_relative_mins == 3 + assert single.signal_parameters.nb_relative_maxes == 2 + assert single.signal_parameters.nb_steps == 4 + assert single.signal_parameters.trend == "MOD" + assert single.signal_parameters.frequency_jumps + assert single.signal_parameters.has_harmonics + assert single.signal_parameters.has_sidebands + assert not single.signal_parameters.has_subharmonics + assert single.signal_parameters.has_deterministic_chaos + + # Multiple signal quantity: parameters should be None + multiple = next(a for a in annotations if a.metadata.annotation_id == 586654) + assert multiple.signal_quantity == "MULTIPLE" + assert multiple.signal_parameters is None + + # Annotation update + update = next(a for a in annotations if a.metadata.annotation_id == 586669) + assert update.metadata.base_id == 586655 + + # Annotation without base + base = next(a for a in annotations if a.metadata.annotation_id == 586655) + assert base.metadata.base_id is None + + # Annotator parsing + annotators = { + AnnotatorInfo(name="vashti", expertise="NOVICE"), + AnnotatorInfo(name="heartleap", expertise=None), + AnnotatorInfo(name="bunyan", expertise="EXPERT"), + AnnotatorInfo(name="lookaftering", expertise="EXPERT"), + } + assert np.array_equal( + annotators, + {a.annotator_info for a in annotations}, + ) + + # Verification parsing + verificated = next(a for a in annotations if a.metadata.annotation_id == 586654) + verification = { + Verification( + verificator="lookaftering", + is_validated=True, + ), + Verification( + verificator="bunyan", + is_validated=False, + ), + } + assert np.array_equal(verification, verificated.verifications) + + # Repr should be the annotation ID + annotation = annotations[0] + assert str(annotation) == str(annotation.metadata.annotation_id) + + +def test_annotation_to_rectangle(sample_annotation: Annotation) -> None: + rectangle = sample_annotation.to_rectangle() + + t1, t2 = sample_annotation.begin, sample_annotation.end + + f_box = sample_annotation.frequency_bounds + f1, f2 = f_box.min, f_box.max + + x, y = rectangle.xy + + assert x == t1 + assert y == f1 + + assert x + rectangle.get_width() == t2 + assert y + rectangle.get_height() == f2 diff --git a/tests/test_event.py b/tests/test_event.py index e9d21bdb..a5c02632 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -446,3 +446,49 @@ def test_repr() -> None: ) == "1990-09-12 12:00:00 - 1990-09-12 12:00:10" ) + + +@pytest.mark.parametrize( + ("event", "timezone", "expected"), + [ + pytest.param( + Event( + begin=Timestamp("18-02-1954 00:00:00"), + end=Timestamp("26-05-2022 00:00:00"), + ), + "UTC+0100", + Event( + begin=Timestamp("18-02-1954 00:00:00+0100"), + end=Timestamp("26-05-2022 00:00:00+0100"), + ), + id="naive_to_aware", + ), + pytest.param( + Event( + begin=Timestamp("18-02-1954 00:00:00+0100"), + end=Timestamp("26-05-2022 00:00:00+0100"), + ), + None, + Event( + begin=Timestamp("18-02-1954 00:00:00"), + end=Timestamp("26-05-2022 00:00:00"), + ), + id="aware_to_naive", + ), + pytest.param( + Event( + begin=Timestamp("18-02-1954 00:00:00+0100"), + end=Timestamp("26-05-2022 00:00:00+0100"), + ), + "UTC+0300", + Event( + begin=Timestamp("18-02-1954 02:00:00+0300"), + end=Timestamp("26-05-2022 02:00:00+0300"), + ), + id="aware_to_aware_converts_timezones", + ), + ], +) +def test_localize(event: Event, timezone: str | None, expected: Event) -> None: + event.localize(timezone) + assert event == expected diff --git a/tests/test_spectro.py b/tests/test_spectro.py index 9fc38ae3..4e4e309f 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -1425,7 +1425,8 @@ def mock_imshow( monkeypatch.setattr(plt.Axes, "imshow", mock_imshow) - sd.plot() + _, ax = plt.subplots() + sd_ax = sd.plot(ax=ax) assert (plot_kwargs["vmin"], plot_kwargs["vmax"]) == sd.v_lim assert plot_kwargs["cmap"] == sd.colormap @@ -1441,6 +1442,8 @@ def mock_imshow( assert f1 == sd.fft.f[0] assert f2 == sd.fft.f[-1] + assert sd_ax == ax + def test_spectro_default_v_lim(audio_files: pytest.fixture) -> None: files, _ = audio_files diff --git a/tests/test_timestamp_utils.py b/tests/test_timestamp_utils.py index 674ba46f..1ee533b9 100644 --- a/tests/test_timestamp_utils.py +++ b/tests/test_timestamp_utils.py @@ -583,6 +583,12 @@ def test_reformat_timestamp( Timestamp("2024-10-17T10:14:11.000+0000", tz="UTC"), id="negative_zero_UTC_offset_timezone", ), + pytest.param( + Timestamp("2024-10-17 10:14:11+0200"), + None, + Timestamp("2024-10-17T10:14:11"), + id="aware_to_naive", + ), ], ) def test_localize_timestamp(