From 446893bee5116be0cf6efa2af33233e2687452d8 Mon Sep 17 00:00:00 2001 From: Avishai Weissberg Date: Fri, 21 Nov 2025 11:07:47 +0200 Subject: [PATCH] feat: add DisjointIntervalSequence API --- docs-src/api.rst | 11 + docs-src/diseq.rst | 400 +++++++++++++++++ docs-src/index.rst | 1 + genome_kit/__init__.py | 1 + genome_kit/diseq.py | 692 +++++++++++++++++++++++++++++ tests/test_diseq.py | 957 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 2062 insertions(+) create mode 100644 docs-src/diseq.rst create mode 100644 genome_kit/diseq.py create mode 100644 tests/test_diseq.py diff --git a/docs-src/api.rst b/docs-src/api.rst index 28ae6a50..fcacb5e2 100644 --- a/docs-src/api.rst +++ b/docs-src/api.rst @@ -11,6 +11,17 @@ Interval :special-members: :members: +DisjointIntervalSequence +======================== + +.. autoclass:: genome_kit.diseq.IndexDirection + :members: + +.. autoclass:: genome_kit.diseq.DisjointIntervalSequence + :special-members: + :members: + :exclude-members: shift, expand, upstream_of, dnstream_of, within + Variant ======= diff --git a/docs-src/diseq.rst b/docs-src/diseq.rst new file mode 100644 index 00000000..b9d79e0c --- /dev/null +++ b/docs-src/diseq.rst @@ -0,0 +1,400 @@ +.. _diseq: + +------------------------------- +Disjoint Interval Sequences +------------------------------- + +Motivation +========== + +When working with transcripts, the situation may arise where we want to ignore the +introns (or other features) of the transcript. If you were to represent the transcript +with those parts removed, you would be left with a disjoint series of :py:class:`~genome_kit.Interval` +objects that are difficult to work with directly (for example, creating an interval on +this disjoint space, or querying the position of a specific sequence within a +CDS, relative to the spliced RNA sequence). + +For this reason, the :py:class:`~genome_kit.diseq.DisjointIntervalSequence` class was +introduced. :py:class:`~genome_kit.diseq.DisjointIntervalSequence` (DIS) simplifies +working with intervals that exist on a disjoint coordinate space. + + +Overview +======== + +A :py:class:`~genome_kit.diseq.DisjointIntervalSequence` (DIS) represents +a flattened coordinate system over a sequence of disjoint genomic intervals. +For example, the exons of a transcript form a disjoint set of genomic intervals that, +when concatenated, represent the spliced RNA sequence. + +A DIS has two layers: + +- A **coordinate space**: the underlying genomic + :py:class:`~genome_kit.Interval` objects (e.g. exons) that define the + flattened index system. These intervals are sorted 5'→3' and must not overlap. +- An **interval**: a sub-range within that coordinate space, defined by + a start and end index, where start <= end. + +To explain how the coordinate space and interval layers interact, let's ignore code for +now, and just use some diagrams to illustrate the concepts. + +Say we have a transcript on the + strand represented by the diagram below: +:: + Genomic Coordinates: 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 + DNA Sequence: A T G C A T G C A T G C A T + |<------->| |<------->| |<--->| |<----------->| |<--->| + Exon1 Intron1 Exon2 Intron2 Exon3 + +If we were to take only the exons, we would have the following disjoint intervals: +:: + Genomic Coordinates: 153 154 155 159 160 165 166 167 + DNA Sequence: A T G G C A T + |<------->| |<--->| |<--->| + Exon1 Exon2 Exon3 + +Let's say we want to create an interval on this series of disjoint exon intervals, +spanning from the start of Exon1 to the end of Exon3. We can start by converting our +list of exons into a DIS coordinate space +:: + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: A T G G C A T + |<----->| |<->| |<->| + Exon1 Exon2 Exon3 + +Now let's place the interval on the DIS coordinate space +:: + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: A T G G C A T + |<--------------------->| + end5 Interval end3 + Start Index: 0 + End Index: 7 + +We see that the interval spans the full length of the coordinate space, and is defined by +a start index of 0 and an end index of 7. + +.. note:: + Why 7 and not 6? The disjoint interval follows the convention of + :py:class:`~genome_kit.Interval` where intervals are half-open + (the end index is exclusive). + +The above example illustrates the basics of how a DIS works. However, it is possible to +do more. We can instead define an interval within the DIS on the strand opposite that +of the coordinate space. Let's start with the DIS coordinate space from above +:: + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: A T G G C A T + |<----->| |<->| |<->| + Exon1 Exon2 Exon3 + +Now let's add the negative (opposite) strand to the diagram +:: + Plus Strand + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence (+): A T G G C A T + ----------------------------------------------------- + DNA Sequence (-): T A C C G T A + DIS Coordinates: 0 1 2 3 4 5 6 7 + Minus Strand + +Importantly, we see the DIS coordinates are the same on both strands. This simplifies +things when you want to get the complement of a given interval, as you can use the same +indices and just flip the strand. To illustrate this, let's now define the same interval +as before (spanning the entire coordinate space) but on the negative strand +:: + Plus Strand + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence (+): A T G G C A T + ----------------------------------------------------- + DNA Sequence (-): T A C C G T A + DIS Coordinates: 0 1 2 3 4 5 6 7 + Minus Strand + |<--------------------->| + end3 Interval end5 + Start Index: 0 + End Index: 7 + On Coordinate Strand: False + +You will notice "On Coordinate Strand: False" has been added to the diagram. Since we +aren't able to determine which strand the interval is on just from the indices, this +variable is used to let us know the strandedness of the interval. + +Thus far we have only defined a DIS from intervals on the + strand. When defining +a DIS from intervals on the negative strand, much remains the same. However, there is +one important difference from a regular Interval: On a DIS created from negative-strand +intervals, the indices still increase in the 5'→3' direction of the transcript. Let's +take a look at an example: + +Say we have the following transcript on the negative strand represented by the diagram +below: +:: + Reminder: On the - strand, the 3' end is on the left and the 5' end is on the right! + + Exon3 Intron2 Exon2 Intron1 Exon1 + |<------->| |<------->| |<--->| |<----------->| |<--->| + DNA Sequence (-): G T C A G T C A G T C A G T + Genomic Coordinates: 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 + Minus Strand +Taking just the exons: +:: + Exon3 Exon2 Exon1 + |<------->| |<--->| |<--->| + DNA Sequence (-): G T C C A G T + Genomic Coordinates: 153 154 155 159 160 165 166 167 + Minus Strand + +Now let's create a DIS from these exons: +:: + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: T G A C C T G + |<->| |<->| |<----->| + Exon1 Exon2 Exon3 + +Notice that the sequence flips relative to the direction of the indices. What has +happened is that the DIS coordinate space is defined in the 5'→3' direction of the +transcript, regardless of genomic strand. In a DIS, 0 always corresponds to the +DIS coordinate's 5' end, and the largest index corresponds to the DIS coordinate's 3' +end. + +Let's now define an interval on this DIS +:: + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: T G A C C T G + |<--------------------->| + end5 Interval end3 + Start Index: 0 + End Index: 7 + On Coordinate Strand: True + +We see that despite creating the DIS from the negative strand, the full-length interval +on the coordinate strand still looks the same as in the + strand example. When working +with DIS objects, you only need to think of things in terms of "same strand" or +"opposite strand". + +To complete the example, let's define an interval on this DIS that is on the opposite strand of the coordinate space +:: + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence (-): T G A C C T G + ----------------------------------------------------- + DNA Sequence (+): A C T G G A C + DIS Coordinates: 0 1 2 3 4 5 6 7 + |<--------------------->| + end3 Interval end5 + Start Index: 0 + End Index: 7 + On Coordinate Strand: False + +Now that you understand how a DIS works conceptually, you can read on to see how to +manipulate them in code. + +Construction +============ + +From a Transcript +~~~~~~~~~~~~~~~~~ + +The most common way to create a DIS is from a +:py:class:`~genome_kit.Transcript`:: + + >>> from genome_kit import Genome + >>> from genome_kit.diseq import DisjointIntervalSequence + >>> genome = Genome("gencode.v29") + >>> transcript = genome.transcripts[100] + >>> dis = DisjointIntervalSequence.from_transcript(transcript) + +By default, the coordinate space is built from the transcript's exons. +You can also specify a region to use CDS or UTR intervals:: + + >>> dis_cds = DisjointIntervalSequence.from_transcript(transcript, region="cds") + >>> dis_utr5 = DisjointIntervalSequence.from_transcript(transcript, region="utr5") + >>> dis_utr3 = DisjointIntervalSequence.from_transcript(transcript, region="utr3") + +The ``coord_id`` and ``interval_id`` default to ``transcript.id`` but can +be overridden:: + + >>> dis = DisjointIntervalSequence.from_transcript( + ... transcript, coord_id="my_coord", interval_id="my_interval") + +From Intervals +~~~~~~~~~~~~~~ + +You can also construct a DIS from any sequence of +:py:class:`~genome_kit.Interval` objects (or annotation objects like +:py:class:`~genome_kit.Exon` that have an ``.interval`` attribute):: + + >>> from genome_kit import Interval + >>> exon_intervals = [e.interval for e in transcript.exons] + >>> dis = DisjointIntervalSequence.from_intervals(exon_intervals, coord_id="my_coord") + +The intervals must all share the same chromosome, strand, and reference +genome. They are automatically sorted 5'→3' and checked for overlaps. + +Coordinate Space +================ + +The coordinate space is defined by the underlying genomic intervals, which +are accessible as a tuple:: + + >>> dis.coordinate_intervals + (Interval("chr1", "+", 100, 200, "hg38"), Interval("chr1", "+", 300, 450, "hg38")) + >>> dis.coordinate_length + 250 + +Metadata about the coordinate space is available through properties:: + + >>> dis.chromosome + 'chr1' + >>> dis.coord_transcript_strand + '+' + >>> dis.reference_genome + 'hg38' + >>> dis.coord_id + 'ENST00000...' + +Interval Start and End +====================== + +The interval within the coordinate space is defined by ``start`` and ``end`` +indices, following the same half-open convention as :py:class:`~genome_kit.Interval` +(``start <= end`` always):: + + >>> dis.start + 0 + >>> dis.end + 250 + >>> dis.length + 250 + >>> len(dis) + 250 + +By default, the interval spans the full coordinate space (``start=0``, +``end=coordinate_length``). Indices can extend beyond ``[0, coordinate_length]``, but +the DNA sequence returned by ``genome.dna()`` will be N-padded. + +End5 and End3 +~~~~~~~~~~~~~ + +The ``end5_index`` and ``end3_index`` properties give the 5' and 3' positions +of the interval. These are derived from ``start`` and ``end`` based on the +interval's strand:: + On coordinate strand (on_coordinate_strand=True): + Start Index: 1 + End Index: 6 + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: T G A C C T G + |<------------->| + ----------------------------------------------------- + DNA Sequence: A C T G G A C + DIS Coordinates: 0 1 2 3 4 5 6 7 + Opposite Strand + + >>> dis = DisjointIntervalSequence.from_transcript(transcript) + >>> dis.end5_index # same as start when on coordinate strand + 1 + >>> dis.end3_index # same as end when on coordinate strand + 6 + + + Off coordinate strand (on_coordinate_strand=False): + Start Index: 3 + End Index: 7 + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence: T G A C C T G + ----------------------------------------------------- + DNA Sequence: A C T G G A C + DIS Coordinates: 0 1 2 3 4 5 6 7 + |<--------->| + Opposite Strand + + >>> opp = dis.as_opposite_strand() + >>> opp.end5_index # same as end when off coordinate strand + 7 + >>> opp.end3_index # same as start when off coordinate strand + 3 + +Boundary Properties +~~~~~~~~~~~~~~~~~~~ + +Zero-length DIS objects at the interval and coordinate boundaries are +available as properties:: + + >>> dis.end5 # 0-length DIS at the interval's 5' boundary + >>> dis.end3 # 0-length DIS at the interval's 3' boundary + >>> dis.coord_end5 # 0-length DIS at the coordinate space's 5' boundary + >>> dis.coord_end3 # 0-length DIS at the coordinate space's 3' boundary + +Strand Methods +============== + +A DIS interval can sit on either strand independently of the coordinate +intervals. The ``on_coordinate_strand`` property indicates whether the +interval is on the same strand as the coordinate intervals:: + On Coordinate Strand: True + Start Index: 1 + End Index: 6 + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence (+): A T C C G A C + |<------------->| + ----------------------------------------------------- + DNA Sequence (-): T A G G C T G + DIS Coordinates: 0 1 2 3 4 5 6 7 + Opposite Strand + + >>> dis.on_coordinate_strand + True + >>> dis.is_positive_strand() + True + +``as_opposite_strand()`` creates a new DIS with the interval on the other +strand. The ``start`` and ``end`` indices are preserved — only +``on_coordinate_strand`` is flipped:: + + Before (on_coordinate_strand=False): + Start Index: 1 + End Index: 6 + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence (+): T A A C C C T + ----------------------------------------------------- + DNA Sequence (-): A T T G G G A + DIS Coordinates: 0 1 2 3 4 5 6 7 + |<------------->| + Opposite Strand + + After as_opposite_strand() (on_coordinate_strand=True): + Start Index: 1 + End Index: 6 + DIS Coordinates: 0 1 2 3 4 5 6 7 + DNA Sequence (+): T A A C C C T + |<------------->| + ----------------------------------------------------- + DNA Sequence (-): A T T G G G A + DIS Coordinates: 0 1 2 3 4 5 6 7 + Opposite Strand + +In code:: + >>> dis.on_coordinate_strand + False + >>> dis.is_positive_strand() + False + >>> opposite = dis.as_opposite_strand() + >>> opposite.on_coordinate_strand + True + >>> opposite.is_positive_strand() + True + >>> opposite.start == dis.start # start/end unchanged + True + >>> opposite.end == dis.end + True + >>> opposite.coordinate_intervals == dis.coordinate_intervals + True + +The ``as_positive_strand()`` and ``as_negative_strand()`` methods return +``self`` if the interval is already on the requested strand:: + + >>> dis.as_positive_strand() is dis + True + +.. note:: + + Strand methods only affect the interval layer. The coordinate + intervals always remain unchanged. diff --git a/docs-src/index.rst b/docs-src/index.rst index 46b7b4b6..c6feb3fc 100644 --- a/docs-src/index.rst +++ b/docs-src/index.rst @@ -70,6 +70,7 @@ Contents: :maxdepth: 2 quickstart + diseq anchors api genomes diff --git a/genome_kit/__init__.py b/genome_kit/__init__.py index dc1f1cb3..4d4d22e3 100644 --- a/genome_kit/__init__.py +++ b/genome_kit/__init__.py @@ -5,6 +5,7 @@ from . import gk_data from .data_manager import DataManager, DefaultDataManager, GCSDataManager +from .diseq import DisjointIntervalSequence from .genome import Genome, ApprisNotAvailableError, ManeNotAvailableError from .genome_annotation import ( Cds, diff --git a/genome_kit/diseq.py b/genome_kit/diseq.py new file mode 100644 index 00000000..430bbf19 --- /dev/null +++ b/genome_kit/diseq.py @@ -0,0 +1,692 @@ +import enum +from dataclasses import dataclass +from typing import Sequence, Literal + +from .interval import Interval +from .genome_annotation import Transcript + + +class IndexDirection(enum.Enum): + """Controls how indices are assigned to the 5' and 3' ends of a coordinate space. + Changing this value will result in undefined behaviour for + :py:class:`DisjointIntervalSequence` objects created under a different convention. + + ``TRANSCRIPT_FIVE_TO_THREE`` + Index 0 is always at the coordinate transcript's 5' end, regardless of genomic strand. + + ``POSITIVE_STRAND_LEFT_TO_RIGHT`` + Index 0 is always at the leftmost genomic position relative to the positive strand. + On the negative strand, this means the 3' end is at index 0. + """ + + TRANSCRIPT_FIVE_TO_THREE = "transcript_five_to_three" + POSITIVE_STRAND_LEFT_TO_RIGHT = "positive_strand_left_to_right" + + +@dataclass(frozen=True) +class _CoordinateMetadata: + id: str | None + reference_genome: str + chromosome: str + transcript_strand: Literal["+", "-"] + + +@dataclass(frozen=True) +class _IntervalMetadata: + id: str | None + on_coordinate_strand: bool + + +class DisjointIntervalSequence: + """A flattened coordinate system over a sequence of disjoint genomic Intervals. + + A DIS represents two layers: + + - A **coordinate space** defined by a sequence of non-overlapping genomic + :py:class:`~genome_kit.Interval` objects (e.g. the exons of a transcript), + which are flattened into a contiguous 0-based index space. Indices for the + coordinate space are assigned according to the current :py:class:`IndexDirection` + value. + - An **interval** within that coordinate space, defined by a 5' and 3' index. + The interval may lie on the same, or opposite, strand as the coordinate space. + + Use :py:meth:`from_transcript` or :py:meth:`from_intervals` to construct + instances rather than calling the constructor directly. + """ + + _index_direction: IndexDirection = IndexDirection.TRANSCRIPT_FIVE_TO_THREE + + @classmethod + def set_index_direction(cls, direction: IndexDirection) -> None: + """Set the index direction convention for all DIS instances. + + Parameters + ---------- + direction : :py:class:`IndexDirection` + The index direction convention to use. + """ + cls._index_direction = direction + + @classmethod + def get_index_direction(cls) -> IndexDirection: + """Get the current index direction convention. + + Returns + ------- + :py:class:`IndexDirection` + """ + return cls._index_direction + + def __init__( + self, + coordinate_intervals: Sequence[Interval], + *, + coord_id: str | None = None, + interval_id: str | None = None, + on_coordinate_strand: bool = True, + start: int | None = None, + end: int | None = None, + ): + """Low-level constructor. + + Prefer :py:meth:`from_transcript` or :py:meth:`from_intervals` for + public construction. + + Parameters + ---------- + coordinate_intervals : Sequence[:py:class:`~genome_kit.Interval`] + Non-empty sequence of non-overlapping Intervals on the same + chromosome, strand, and reference genome. + coord_id : :py:class:`str` or None + Optional identifier for the coordinate space. + interval_id : :py:class:`str` or None + Optional identifier for the interval. + on_coordinate_strand : :py:class:`bool` + Whether the interval is on the same strand as the coordinate + intervals. + start : :py:class:`int` or None + start index of the interval in the coordinate space. Defaults to 0 + end : :py:class:`int` or None + end index of the interval in the coordinate space. Defaults to the length + of the coordinate space. + + Raises + ------ + ValueError + If intervals are empty, inconsistent, overlapping, or if start + is greater than end. + TypeError + If any element is not an Interval. + """ + if len(coordinate_intervals) == 0: + raise ValueError("coordinate_intervals must be non-empty") + + for i, iv in enumerate(coordinate_intervals): + if not isinstance(iv, Interval): + raise TypeError( + f"coordinate_intervals[{i}] is {type(iv).__name__}, expected Interval" + ) + + # Consistent chromosome, strand, reference_genome + iv0 = coordinate_intervals[0] + for iv in coordinate_intervals[1:]: + if iv.chromosome != iv0.chromosome: + raise ValueError( + f"All intervals must share the same chromosome, " + f"got {iv0.chromosome!r} and {iv.chromosome!r}" + ) + if iv.strand != iv0.strand: + raise ValueError( + f"All intervals must share the same strand, " + f"got {iv0.strand!r} and {iv.strand!r}" + ) + if iv.reference_genome != iv0.reference_genome: + raise ValueError( + f"All intervals must share the same reference genome, " + f"got {iv0.reference_genome!r} and {iv.reference_genome!r}" + ) + + # Sort 5'->3' + if iv0.strand == "+": + sorted_intervals = sorted(coordinate_intervals, key=lambda iv: iv.start) + else: + # On negative strand end is the 5' end since start < end. + # Sort by -end to get 5'->3' order. + sorted_intervals = sorted(coordinate_intervals, key=lambda iv: -iv.end) + + # No overlaps (adjacent/touching OK) + for i in range(len(sorted_intervals) - 1): + cur_iv, next_iv = sorted_intervals[i], sorted_intervals[i + 1] + plus_strand_overlap = iv0.strand == "+" and cur_iv.end > next_iv.start + minus_strand_overlap = iv0.strand == "-" and cur_iv.start < next_iv.end + if plus_strand_overlap or minus_strand_overlap: + raise ValueError( + f"Intervals must not overlap: [{cur_iv.start}, {cur_iv.end}) and [{next_iv.start}, {next_iv.end})" + ) + + self._coordinate_intervals: tuple[Interval, ...] = tuple(sorted_intervals) + + self._coord_metadata = _CoordinateMetadata( + id=coord_id, + reference_genome=iv0.reference_genome, + chromosome=iv0.chromosome, + transcript_strand=iv0.strand, + ) + self._interval_metadata = _IntervalMetadata( + id=interval_id, + on_coordinate_strand=on_coordinate_strand, + ) + + # Default interval start/end to span the full coordinate + if start is None: + start = 0 + if end is None: + end = self.coordinate_length + + # Validate that start is less than or equal to end + if start > end: + raise ValueError( + f"start index {start} cannot be greater than end index {end}" + ) + + self._start: int = start + self._end: int = end + + @classmethod + def from_intervals( + cls, + intervals: Sequence[Interval], + *, + coord_id: str | None = None, + interval_id: str | None = None, + ) -> "DisjointIntervalSequence": + """Construct a DIS from a sequence of Intervals + (or :py:class:`~genome_kit.Exon`/:py:class:`~genome_kit.Cds`/:py:class:`~genome_kit.Utr` objects). + + If elements have an ``.interval`` attribute (e.g. Exon, Cds, Utr), + the plain Interval is extracted automatically. + + Parameters + ---------- + intervals : Sequence[:py:class:`~genome_kit.Interval`] + Sequence of Interval or annotation objects with ``.interval``. + coord_id : :py:class:`str` or None + Optional identifier for the coordinate space. + interval_id : :py:class:`str` or None + Optional identifier for the interval. + + Returns + ------- + :py:class:`DisjointIntervalSequence` + """ + # Extract .interval if items are Exon/Cds/Utr + coord_intervals = [] + for iv in intervals: + if type(iv) is not Interval and hasattr(iv, "interval"): + coord_intervals.append(iv.interval) + else: + coord_intervals.append(iv) + return cls(coord_intervals, coord_id=coord_id, interval_id=interval_id) + + @classmethod + def from_transcript( + cls, + transcript: Transcript, + *, + region: Literal["exons", "cds", "utr5", "utr3"] = "exons", + coord_id: str | None = None, + interval_id: str | None = None, + ) -> "DisjointIntervalSequence": + """Construct a DIS from a transcript's exons, CDS, or UTR regions. + + Parameters + ---------- + transcript : :py:class:`~genome_kit.Transcript` + The source Transcript object. + region : :py:class:`str` + Which region to extract — ``"exons"``, ``"cds"``, + ``"utr5"``, or ``"utr3"``. + coord_id : :py:class:`str` or None + Optional coordinate ID. Defaults to ``transcript.id``. + interval_id : :py:class:`str` or None + Optional interval ID. Defaults to ``transcript.id``. + + Returns + ------- + :py:class:`DisjointIntervalSequence` + + Raises + ------ + ValueError + If region is not one of the allowed values. + """ + match region: + case "exons": + region_elements = transcript.exons + case "cds": + region_elements = transcript.cdss + case "utr5": + region_elements = transcript.utr5s + case "utr3": + region_elements = transcript.utr3s + case _: + raise ValueError(f"Invalid region: {region!r}") + coord_intervals = [element.interval for element in region_elements] + if coord_id is None: + coord_id = transcript.id + if interval_id is None: + interval_id = transcript.id + + return cls(coord_intervals, coord_id=coord_id, interval_id=interval_id) + + @property + def coord_id(self) -> str | None: + """Identifier for the coordinate space, or None.""" + return self._coord_metadata.id + + @property + def reference_genome(self) -> str: + """Reference genome of the coordinate intervals.""" + return self._coord_metadata.reference_genome + + @property + def chromosome(self) -> str: + """Chromosome of the coordinate intervals.""" + return self._coord_metadata.chromosome + + @property + def coord_transcript_strand(self) -> Literal["+", "-"]: + """Strand of the coordinate intervals (the transcript strand).""" + return self._coord_metadata.transcript_strand + + @property + def id(self) -> str | None: + """Identifier for the interval, or None.""" + return self._interval_metadata.id + + @property + def on_coordinate_strand(self) -> bool: + """True if the interval is on the same strand as the coordinate intervals.""" + return self._interval_metadata.on_coordinate_strand + + @property + def transcript_strand(self) -> Literal["+", "-"]: + """Effective strand of the interval, accounting for on_coordinate_strand.""" + if self.on_coordinate_strand: + return self.coord_transcript_strand + # Interval is on opposite strand + if self.coord_transcript_strand == "+": + return "-" + return "+" + + @property + def coordinate_end5_index(self) -> int: + """5' index of the coordinate space.""" + if self._index_direction == IndexDirection.TRANSCRIPT_FIVE_TO_THREE: + return 0 + if self.coord_transcript_strand == "+": + return 0 + return self.coordinate_length + + @property + def coordinate_end3_index(self) -> int: + """3' index of the coordinate space.""" + if self._index_direction == IndexDirection.TRANSCRIPT_FIVE_TO_THREE: + return self.coordinate_length + if self.coord_transcript_strand == "+": + return self.coordinate_length + return 0 + + @property + def end5_index(self) -> int: + """5' index of the interval.""" + if self._upstream_index_step() == -1: + return self._start + return self._end + + @property + def end3_index(self) -> int: + """3' index of the interval.""" + if self._upstream_index_step() == -1: + return self._end + return self._start + + @property + def start(self) -> int: + """Start index of the interval in the coordinate space. Not necessarily the 5' end.""" + return self._start + + @property + def end(self) -> int: + """End index of the interval in the coordinate space. Not necessarily the 3' end.""" + return self._end + + def _at_index( + self, idx: int, on_coordinate_strand: bool + ) -> "DisjointIntervalSequence": + """Return a 0-length DIS at the given index position.""" + return DisjointIntervalSequence( + self._coordinate_intervals, + coord_id=self._coord_metadata.id, + on_coordinate_strand=on_coordinate_strand, + start=idx, + end=idx, + ) + + @property + def end5(self) -> "DisjointIntervalSequence": + """0-length DIS at the interval's 5' end.""" + return self._at_index( + self.end5_index, on_coordinate_strand=self.on_coordinate_strand + ) + + @property + def end3(self) -> "DisjointIntervalSequence": + """0-length DIS at the interval's 3' end.""" + return self._at_index( + self.end3_index, on_coordinate_strand=self.on_coordinate_strand + ) + + @property + def coord_end5(self) -> "DisjointIntervalSequence": + """0-length DIS at the coordinate space's 5' end.""" + return self._at_index(self.coordinate_end5_index, on_coordinate_strand=True) + + @property + def coord_end3(self) -> "DisjointIntervalSequence": + """0-length DIS at the coordinate space's 3' end.""" + return self._at_index(self.coordinate_end3_index, on_coordinate_strand=True) + + @property + def coordinate_intervals(self) -> tuple[Interval, ...]: + """The underlying genomic intervals of the coordinate-space, sorted 5'->3'.""" + return self._coordinate_intervals + + @property + def coordinate_length(self) -> int: + """Total length of the coordinate space in bases.""" + return sum(len(iv) for iv in self._coordinate_intervals) + + @property + def length(self) -> int: + """Length of the interval on the coordinate space.""" + return self.end - self.start + + def _set_end5(self, end5: int) -> "DisjointIntervalSequence": + """Convenience method to update start/end based on a new end5 index.""" + if end5 == self.end5_index: + return self # No change + new_start, new_end = self._start, self._end + end5_difference = end5 - self.end5_index + is_moved_upstream = end5_difference * self._upstream_index_step() > 0 + if is_moved_upstream and self._upstream_index_step() == -1: + new_start = new_start - abs(end5_difference) + elif is_moved_upstream and self._upstream_index_step() == 1: + new_end = new_end + abs(end5_difference) + elif not is_moved_upstream and self._upstream_index_step() == -1: + new_start = new_start + abs(end5_difference) + elif not is_moved_upstream and self._upstream_index_step() == 1: + new_end = new_end - abs(end5_difference) + if new_start > new_end: + raise ValueError( + f"Invalid end5 update: end5 index {end5} would be downstream of end3 index {self.end3_index}" + ) + return self._from_end_indices(new_start, new_end) + + def _set_end3(self, end3: int) -> "DisjointIntervalSequence": + """Convenience method to update start/end based on a new end3 index.""" + if end3 == self.end3_index: + return self # No change + new_start, new_end = self._start, self._end + end3_difference = end3 - self.end3_index + is_moved_upstream = end3_difference * self._upstream_index_step() > 0 + if is_moved_upstream and self._upstream_index_step() == -1: + new_end = new_end - abs(end3_difference) + elif is_moved_upstream and self._upstream_index_step() == 1: + new_start = new_start + abs(end3_difference) + elif not is_moved_upstream and self._upstream_index_step() == -1: + new_end = new_end + abs(end3_difference) + elif not is_moved_upstream and self._upstream_index_step() == 1: + new_start = new_start - abs(end3_difference) + if new_start > new_end: + raise ValueError( + f"Invalid end3 update: end3 index {end3} would be upstream of end5 index {self.end5_index}" + ) + return self._from_end_indices(new_start, new_end) + + def _upstream_index_step(self, on_coordinate_strand: bool | None = None) -> int: + """Return +1 or -1 indicating the upstream direction in index space. + + Args: + on_coordinate_strand: Override for which strand to compute the step for. + Defaults to this interval's on_coordinate_strand. + """ + if on_coordinate_strand is None: + on_coordinate_strand = self.on_coordinate_strand + if self._index_direction == IndexDirection.TRANSCRIPT_FIVE_TO_THREE: + return -1 if on_coordinate_strand else 1 + # POSITIVE_STRAND_LEFT_TO_RIGHT: effective strand determines direction + return -1 if self.transcript_strand == "+" else 1 + + def _validate_same_coordinate_space( + self, other: "DisjointIntervalSequence" + ) -> None: + """Raise if other does not share the same coordinate space.""" + if not isinstance(other, DisjointIntervalSequence): + raise TypeError( + f"Expected DisjointIntervalSequence, got {type(other).__name__}" + ) + if self._coordinate_intervals != other._coordinate_intervals: + raise ValueError("DIS objects must share the same coordinate intervals") + + def _from_end_indices(self, end5: int, end3: int) -> "DisjointIntervalSequence": + """Return a new DIS with the same coordinate space but different interval indices.""" + # Validate end5 is upstream of or equal to end3 + if self._upstream_index_step() == -1: + if end5 > end3: + raise ValueError( + f"Invalid indices: end5 index {end5} is downstream of end3 index {end3}" + ) + if self._upstream_index_step() == 1: + if end5 < end3: + raise ValueError( + f"Invalid indices: end5 index {end5} is downstream of end3 index {end3}" + ) + return DisjointIntervalSequence( + self._coordinate_intervals, + coord_id=self._coord_metadata.id, + interval_id=self._interval_metadata.id, + on_coordinate_strand=self.on_coordinate_strand, + start=min(end5, end3), + end=max(end5, end3), + ) + + def shift(self, amount: int) -> "DisjointIntervalSequence": + """Shift the interval downstream by amount (negative shifts upstream). + + The coordinate space is unchanged. Only the interval indices move. + """ + downstream_step = -self._upstream_index_step() + delta = amount * downstream_step + return self._from_end_indices( + self.end5_index + delta, + self.end3_index + delta, + ) + + def expand( + self, upstream: int, dnstream: int | None = None + ) -> "DisjointIntervalSequence": + """Expand the interval upstream and/or downstream. + + Negative values contract the interval. Raises ValueError if contraction + would result in end5 being downstream of end3. + + Args: + upstream: Bases to expand (or contract if negative) toward the 5' end. + dnstream: Bases to expand (or contract if negative) toward the 3' end. + Defaults to upstream (symmetric). + """ + if dnstream is None: + dnstream = upstream + up_step = self._upstream_index_step() + down_step = -up_step + new_end5 = self.end5_index + (upstream * up_step) + new_end3 = self.end3_index + (dnstream * down_step) + # Validate end5 is still upstream of or equal to end3 + if (new_end5 - new_end3) * up_step < 0: + raise ValueError( + "Invalid expansion: end5 would be downstream of end3 " + f"(end5={new_end5}, end3={new_end3})" + ) + return self._from_end_indices(new_end5, new_end3) + + def upstream_of(self, other: "DisjointIntervalSequence") -> bool: + """True if self is strictly upstream of other (no overlap). + + Requires the same coordinate space and same on_coordinate_strand. + """ + self._validate_same_coordinate_space(other) + if self.on_coordinate_strand != other.on_coordinate_strand: + raise ValueError("Cannot compare: intervals are on different strands") + if self.length == 0 and other.length == 0 and self.start == other.start: + return False + if self._upstream_index_step() == -1: + return self._end <= other.start + return self._start >= other.end + + def dnstream_of(self, other: "DisjointIntervalSequence") -> bool: + """True if self is strictly downstream of other (no overlap). + + Requires the same coordinate space and same on_coordinate_strand. + """ + self._validate_same_coordinate_space(other) + if self.on_coordinate_strand != other.on_coordinate_strand: + raise ValueError("Cannot compare: intervals are on different strands") + if self.length == 0 and other.length == 0 and self.start == other.start: + return False + if self._upstream_index_step() == -1: + return self._start >= other.end + return self._end <= other.start + + def within(self, other: "DisjointIntervalSequence") -> bool: + """True if self's interval is contained within other's interval. + + Requires the same coordinate space and same on_coordinate_strand. + """ + self._validate_same_coordinate_space(other) + if self.on_coordinate_strand != other.on_coordinate_strand: + raise ValueError("Cannot compare: intervals are on different strands") + return self._start >= other.start and self._end <= other.end + + def is_positive_strand(self) -> bool: + """If the interval is on the positive strand. + + Returns + ------- + :py:class:`bool` + """ + if self.transcript_strand == "+": + return True + return False + + def as_positive_strand(self) -> "DisjointIntervalSequence": + """Return a DIS with the interval on the positive strand. + + Returns ``self`` if already on the positive strand. The coordinate + intervals are unchanged; only the interval strand is affected. + + Returns + ------- + :py:class:`DisjointIntervalSequence` + """ + if self.is_positive_strand(): + return self + return self.as_opposite_strand() + + def as_negative_strand(self) -> "DisjointIntervalSequence": + """Return a DIS with the interval on the negative strand. + + Returns ``self`` if already on the negative strand. The coordinate + intervals are unchanged; only the interval strand is affected. + + Returns + ------- + :py:class:`DisjointIntervalSequence` + """ + if not self.is_positive_strand(): + return self + return self.as_opposite_strand() + + def as_opposite_strand(self) -> "DisjointIntervalSequence": + """Return a new DIS with the interval on the opposite strand. + + The coordinate intervals are unchanged. The interval's + ``on_coordinate_strand`` is flipped. + + Returns + ------- + :py:class:`DisjointIntervalSequence` + """ + return DisjointIntervalSequence( + self._coordinate_intervals, + coord_id=self._coord_metadata.id, + interval_id=self._interval_metadata.id, + on_coordinate_strand=not self.on_coordinate_strand, + start=self._start, + end=self._end, + ) + + def genomic_span(self) -> Interval: + """Smallest single Interval spanning all coordinate intervals. + + Returns + ------- + :py:class:`~genome_kit.Interval` + An interval from the minimum ``start`` to the maximum ``end`` + across all coordinate intervals. + """ + ivs = self._coordinate_intervals + return Interval( + ivs[0].chromosome, + ivs[0].strand, + min(iv.start for iv in ivs), + max(iv.end for iv in ivs), + ivs[0].reference_genome, + ) + + def __len__(self) -> int: + """Return the length of the interval.""" + return self.length + + def __repr__(self) -> str: + """Return a human-readable representation.""" + return ( + f"DisjointIntervalSequence(" + f"coord_id={self._coord_metadata.id!r}, " + f"id={self._interval_metadata.id!r}, " + f"{self.chromosome}:{self.coord_transcript_strand}, " + f"len={self.length}, " + f"coord_intervals={self._coordinate_intervals})" + f"start={self._start}, " + f"end={self._end}, " + f"end5={self.end5_index}, " + f"end3={self.end3_index})" + ) + + def __eq__(self, other: object) -> bool: + """Equality based on coordinate intervals, metadata, and index values.""" + if not isinstance(other, DisjointIntervalSequence): + return NotImplemented + try: + # If refg mismatch, ValueError is raised + self._coordinate_intervals == other._coordinate_intervals + except ValueError: + breakpoint() + return False + return ( + self._coord_metadata == other._coord_metadata + and self._interval_metadata == other._interval_metadata + and self._start == other._start + and self._end == other._end + and self._coordinate_intervals == other._coordinate_intervals + ) diff --git a/tests/test_diseq.py b/tests/test_diseq.py new file mode 100644 index 00000000..3d38fa3d --- /dev/null +++ b/tests/test_diseq.py @@ -0,0 +1,957 @@ +import unittest +from genome_kit import Interval, Genome +from genome_kit.diseq import DisjointIntervalSequence + +REFG = "hg19" + + +def _make_intervals(specs, refg=REFG): + """Helper: specs is list of (chrom, strand, start, end).""" + return [ + Interval(chrom, strand, start, end, refg) for chrom, strand, start, end in specs + ] + + +class TestInit(unittest.TestCase): + + def test_non_interval_raises(self): + with self.assertRaises(TypeError): + DisjointIntervalSequence(["not an interval"]) + with self.assertRaises(TypeError): + DisjointIntervalSequence([42]) + iv = Interval("chr1", "+", 100, 200, REFG) + with self.assertRaises(TypeError): + DisjointIntervalSequence([iv, "bad"]) + + def test_empty_list_raises(self): + with self.assertRaises(ValueError): + DisjointIntervalSequence([]) + + def test_mixed_chromosomes_raises(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr2", "+", 300, 400)]) + with self.assertRaises(ValueError): + DisjointIntervalSequence(ivs) + + def test_mixed_strands_raises(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "-", 300, 400)]) + with self.assertRaises(ValueError): + DisjointIntervalSequence(ivs) + + def test_mixed_reference_genomes_raises(self): + ivs = [ + Interval("chr1", "+", 100, 200, "hg19"), + Interval("chr1", "+", 300, 400, "hg38"), + ] + with self.assertRaises(ValueError): + DisjointIntervalSequence(ivs) + + def test_overlapping_intervals_raises(self): + ivs = _make_intervals([("chr1", "+", 100, 250), ("chr1", "+", 200, 400)]) + with self.assertRaises(ValueError): + DisjointIntervalSequence(ivs) + + def test_overlapping_intervals_negative_strand_raises(self): + ivs = _make_intervals([("chr1", "-", 100, 250), ("chr1", "-", 200, 400)]) + with self.assertRaises(ValueError): + DisjointIntervalSequence(ivs) + + def test_adjacent_intervals_ok(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 200, 300)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coordinate_length, 200) + self.assertEqual(dis.coordinate_intervals, tuple(ivs)) + + def test_sorts_out_of_order_positive(self): + ivs = _make_intervals([("chr1", "+", 300, 400), ("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coordinate_intervals[0].start, 100) + self.assertEqual(dis.coordinate_intervals[1].start, 300) + + def test_sorts_out_of_order_negative(self): + ivs = _make_intervals([("chr1", "-", 100, 200), ("chr1", "-", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coordinate_intervals[0].end, 400) + self.assertEqual(dis.coordinate_intervals[1].end, 200) + + def test_coordinate_length(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 450)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coordinate_length, 250) + + def test_start_end_default_to_full_interval(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.start, 0) + self.assertEqual(dis.end, 200) + + def test_out_of_bounds_indices_allowed(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) # coord_length=100 + dis = DisjointIntervalSequence(ivs, start=-10, end=50) + self.assertEqual(dis.start, -10) + dis2 = DisjointIntervalSequence(ivs, start=10, end=300) + self.assertEqual(dis2.end, 300) + dis3 = DisjointIntervalSequence(ivs, start=-100, end=300) + self.assertEqual(dis3.start, -100) + self.assertEqual(dis3.end, 300) + dis4 = DisjointIntervalSequence(ivs, start=300, end=300) + self.assertEqual(dis4.start, 300) + self.assertEqual(dis4.end, 300) + dis5 = DisjointIntervalSequence(ivs, start=-300, end=-300) + self.assertEqual(dis5.start, -300) + self.assertEqual(dis5.end, -300) + + def test_start_greater_than_end_raises(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + with self.assertRaises(ValueError): + DisjointIntervalSequence(ivs, start=80, end=10) + + def test_start_equals_end_allowed(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, start=50, end=50) + self.assertEqual(dis.length, 0) + + def test_custom_interval_indices(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, start=10, end=50) + self.assertEqual(dis.start, 10) + self.assertEqual(dis.end, 50) + self.assertEqual(dis.length, 40) + + +class TestFromIntervals(unittest.TestCase): + + def setUp(self): + self.genome = Genome("gencode.v41") + self.transcript = self.genome.transcripts[2002] + + def test_happy_path(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence.from_intervals(ivs, coord_id="mycoord") + self.assertEqual(dis.coord_id, "mycoord") + self.assertEqual(dis.coordinate_length, 200) + self.assertEqual(dis.coordinate_intervals, tuple(ivs)) + + def test_coord_and_interval_id_independent(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence.from_intervals( + ivs, coord_id="c1", interval_id="i1" + ) + self.assertEqual(dis.coord_id, "c1") + self.assertEqual(dis.id, "i1") + + def test_single_interval(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence.from_intervals(ivs) + self.assertEqual(len(dis.coordinate_intervals), 1) + self.assertEqual(dis.coordinate_length, 100) + self.assertEqual(dis.coordinate_intervals, tuple(ivs)) + + def test_adjacent_intervals(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 200, 300)]) + dis = DisjointIntervalSequence.from_intervals(ivs) + self.assertEqual(dis.coordinate_length, 200) + self.assertEqual(dis.coordinate_intervals, tuple(ivs)) + + def test_extracts_interval_from_exon(self): + exons = list(self.transcript.exons) + dis = DisjointIntervalSequence.from_intervals(exons) + expected = tuple(e.interval for e in self.transcript.exons) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_extracts_interval_from_cds(self): + cdss = list(self.transcript.cdss) + dis = DisjointIntervalSequence.from_intervals(cdss) + expected = tuple(c.interval for c in self.transcript.cdss) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_extracts_interval_from_utr5(self): + utr5s = list(self.transcript.utr5s) + dis = DisjointIntervalSequence.from_intervals(utr5s) + expected = tuple(u.interval for u in self.transcript.utr5s) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_extracts_interval_from_utr3(self): + utr3s = list(self.transcript.utr3s) + dis = DisjointIntervalSequence.from_intervals(utr3s) + expected = tuple(u.interval for u in self.transcript.utr3s) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_metadata_defaults_to_none(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence.from_intervals(ivs) + self.assertEqual(dis.coord_id, None) + self.assertEqual(dis.id, None) + self.assertEqual(dis.reference_genome, REFG) + self.assertEqual(dis.chromosome, "chr1") + self.assertEqual(dis.coord_transcript_strand, "+") + + +class TestFromTranscript(unittest.TestCase): + + def setUp(self): + self.genome = Genome("gencode.v41") + self.transcript = self.genome.transcripts[2002] + + def test_exons_region(self): + dis = DisjointIntervalSequence.from_transcript(self.transcript, region="exons") + expected = tuple(e.interval for e in self.transcript.exons) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_cds_region(self): + dis = DisjointIntervalSequence.from_transcript(self.transcript, region="cds") + expected = tuple(c.interval for c in self.transcript.cdss) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_utr5_region(self): + dis = DisjointIntervalSequence.from_transcript(self.transcript, region="utr5") + expected = tuple(u.interval for u in self.transcript.utr5s) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_utr3_region(self): + dis = DisjointIntervalSequence.from_transcript(self.transcript, region="utr3") + expected = tuple(u.interval for u in self.transcript.utr3s) + self.assertEqual(dis.coordinate_intervals, expected) + + def test_metadata_defaults_to_transcript_id(self): + dis = DisjointIntervalSequence.from_transcript(self.transcript) + self.assertEqual(dis.coord_id, self.transcript.id) + self.assertEqual(dis.id, self.transcript.id) + self.assertEqual(dis.reference_genome, self.transcript.reference_genome) + self.assertEqual(dis.chromosome, self.transcript.chromosome) + self.assertEqual(dis.coord_transcript_strand, self.transcript.strand) + + def test_invalid_region_raises(self): + with self.assertRaises(ValueError): + DisjointIntervalSequence.from_transcript(self.transcript, region="invalid") + + def test_custom_id_overrides(self): + dis = DisjointIntervalSequence.from_transcript( + self.transcript, coord_id="custom_coord", interval_id="custom_iv" + ) + self.assertEqual(dis.coord_id, "custom_coord") + self.assertEqual(dis.id, "custom_iv") + + +class TestProperties(unittest.TestCase): + + def test_metadata_getters_positive(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, coord_id="c", interval_id="i") + self.assertEqual(dis.coord_id, "c") + self.assertEqual(dis.id, "i") + self.assertEqual(dis.reference_genome, REFG) + self.assertEqual(dis.chromosome, "chr1") + self.assertEqual(dis.coord_transcript_strand, "+") + self.assertTrue(dis.on_coordinate_strand) + + def test_on_coordinate_strand_false(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertFalse(dis.on_coordinate_strand) + + def test_coord_transcript_strand_positive(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coord_transcript_strand, "+") + + def test_coord_transcript_strand_negative(self): + ivs = _make_intervals([("chr1", "-", 100, 200)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coord_transcript_strand, "-") + + def test_coord_transcript_strand_unaffected_by_on_coordinate_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertEqual(dis.coord_transcript_strand, "+") + + def test_transcript_strand_on_coordinate_strand_positive(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + self.assertEqual(dis.transcript_strand, "+") + + def test_transcript_strand_on_coordinate_strand_negative(self): + ivs = _make_intervals([("chr1", "-", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + self.assertEqual(dis.transcript_strand, "-") + + def test_transcript_strand_off_coordinate_strand_positive(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertEqual(dis.transcript_strand, "-") + + def test_transcript_strand_off_coordinate_strand_negative(self): + ivs = _make_intervals([("chr1", "-", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertEqual(dis.transcript_strand, "+") + + def test_start_and_end(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, start=10, end=150) + self.assertEqual(dis.start, 10) + self.assertEqual(dis.end, 150) + + def test_coordinate_intervals(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, start=10, end=150) + self.assertEqual(dis.coordinate_intervals, tuple(ivs)) + + def test_coordinate_length(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.coordinate_length, 200) + + def test_length(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, start=10, end=150) + self.assertEqual(dis.length, 140) + + def test_length_zero(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, start=50, end=50) + self.assertEqual(dis.length, 0) + + +class TestStrandMethods(unittest.TestCase): + + def test_is_positive_strand_plus_on_coord(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + self.assertTrue(dis.is_positive_strand()) + + def test_is_positive_strand_minus_off_coord(self): + ivs = _make_intervals([("chr1", "-", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertTrue(dis.is_positive_strand()) + + def test_is_positive_strand_false_plus_off_coord(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertFalse(dis.is_positive_strand()) + + def test_is_positive_strand_false_minus_on_coord(self): + ivs = _make_intervals([("chr1", "-", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + self.assertFalse(dis.is_positive_strand()) + + def test_as_positive_strand_already_positive(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + result = dis.as_positive_strand() + self.assertIs(result, dis) + + def test_as_positive_strand_flips(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence( + ivs, on_coordinate_strand=False, start=10, end=80 + ) + expected = DisjointIntervalSequence( + ivs, on_coordinate_strand=True, start=10, end=80 + ) + result = dis.as_positive_strand() + self.assertTrue(result.is_positive_strand()) + self.assertTrue(result.on_coordinate_strand) + self.assertEqual(result.start, 10) + self.assertEqual(result.end, 80) + self.assertEqual(result, expected) + + def test_as_negative_strand_already_negative(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + result = dis.as_negative_strand() + self.assertIs(result, dis) + + def test_as_negative_strand_flips(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True, start=10, end=80) + expected = DisjointIntervalSequence( + ivs, on_coordinate_strand=False, start=10, end=80 + ) + result = dis.as_negative_strand() + self.assertFalse(result.is_positive_strand()) + self.assertFalse(result.on_coordinate_strand) + self.assertEqual(result.start, 10) + self.assertEqual(result.end, 80) + self.assertEqual(result, expected) + + def test_as_opposite_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + opp = dis.as_opposite_strand() + self.assertFalse(opp.is_positive_strand()) + opp2 = opp.as_opposite_strand() + self.assertTrue(opp2.is_positive_strand()) + + def test_strand_flip_preserves_coordinate_intervals(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + flipped = dis.as_opposite_strand() + self.assertEqual(flipped.coordinate_intervals, dis.coordinate_intervals) + + def test_idempotency(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + self.assertIs(dis.as_positive_strand().as_positive_strand(), dis) + + def test_as_opposite_strand_preserves_start_end(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, start=10, end=80) + opp = dis.as_opposite_strand() + self.assertEqual(opp.start, 10) + self.assertEqual(opp.end, 80) + + def test_end5_end3_swap_on_opposite_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, start=10, end=80) + # On coordinate strand: end5 at start, end3 at end + self.assertEqual(dis.end5_index, 10) + self.assertEqual(dis.end3_index, 80) + # Off coordinate strand: end5 at end, end3 at start + opp = dis.as_opposite_strand() + self.assertEqual(opp.end5_index, 80) + self.assertEqual(opp.end3_index, 10) + + +class TestEndProperties(unittest.TestCase): + + def test_end5_default(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, coord_id="c", interval_id="i") + # On coordinate strand: end5_index == start (0), end3_index == end (200) + self.assertEqual(dis.end5_index, 0) + self.assertEqual(dis.end3_index, 200) + e5 = dis.end5 + self.assertEqual(len(e5), 0) + self.assertEqual(e5.start, 0) + self.assertEqual(e5.end, 0) + self.assertEqual(e5.coord_id, "c") + self.assertEqual(e5.id, None) + expected = DisjointIntervalSequence( + ivs, coord_id="c", interval_id=None, on_coordinate_strand=True, start=0, end=0 + ) + self.assertEqual(e5, expected) + + def test_end3_default(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + e3 = dis.end3 + self.assertEqual(len(e3), 0) + self.assertEqual(e3.start, 200) + self.assertEqual(e3.end, 200) + self.assertEqual(dis.end3_index, 200) + expected = DisjointIntervalSequence( + ivs, on_coordinate_strand=True, start=200, end=200 + ) + self.assertEqual(e3, expected) + + def test_end5_end3_with_custom_indices(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, start=30, end=150) + # On coordinate strand: end5_index == start, end3_index == end + e5 = dis.end5 + e3 = dis.end3 + self.assertEqual(e5.start, 30) + self.assertEqual(e5.end, 30) + self.assertEqual(e3.start, 150) + self.assertEqual(e3.end, 150) + self.assertEqual(dis.end5_index, 30) + self.assertEqual(dis.end3_index, 150) + expected_e5 = DisjointIntervalSequence( + ivs, on_coordinate_strand=True, start=30, end=30 + ) + expected_e3 = DisjointIntervalSequence( + ivs, on_coordinate_strand=True, start=150, end=150 + ) + self.assertEqual(e5, expected_e5) + self.assertEqual(e3, expected_e3) + + def test_coord_end5(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs) + ce5 = dis.coord_end5 + self.assertEqual(len(ce5), 0) + self.assertEqual(ce5.start, 0) + self.assertEqual(ce5.end, 0) + self.assertEqual(ce5.end5_index, 0) + self.assertEqual(ce5.end3_index, 0) + expected = DisjointIntervalSequence( + ivs, on_coordinate_strand=True, start=0, end=0 + ) + self.assertEqual(ce5, expected) + + def test_coord_end3(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs) + ce3 = dis.coord_end3 + self.assertEqual(len(ce3), 0) + self.assertEqual(ce3.start, 100) + self.assertEqual(ce3.end, 100) + self.assertEqual(ce3.end5_index, 100) + self.assertEqual(ce3.end3_index, 100) + expected = DisjointIntervalSequence( + ivs, on_coordinate_strand=True, start=100, end=100 + ) + self.assertEqual(ce3, expected) + + def test_end_preserves_coordinate_intervals(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(dis.end5.coordinate_intervals, dis.coordinate_intervals) + self.assertEqual(dis.end3.coordinate_intervals, dis.coordinate_intervals) + self.assertEqual(dis.coord_end5.coordinate_intervals, dis.coordinate_intervals) + self.assertEqual(dis.coord_end3.coordinate_intervals, dis.coordinate_intervals) + + def test_end_preserves_on_coordinate_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertFalse(dis.end5.on_coordinate_strand) + self.assertFalse(dis.end3.on_coordinate_strand) + + def test_coord_end_independent_of_on_coordinate_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertTrue(dis.coord_end5.on_coordinate_strand) + self.assertTrue(dis.coord_end3.on_coordinate_strand) + + def test_end5_end3_opposite_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence( + ivs, start=30, end=150, on_coordinate_strand=False + ) + # Off coord strand: end5 at end (150), end3 at start (30) + self.assertEqual(dis.end5_index, 150) + self.assertEqual(dis.end3_index, 30) + e5 = dis.end5 + e3 = dis.end3 + self.assertEqual(e5.start, 150) + self.assertEqual(e5.end, 150) + self.assertEqual(e3.start, 30) + self.assertEqual(e3.end, 30) + + +class TestDunderMethods(unittest.TestCase): + + def test_len(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs) + self.assertEqual(len(dis), 200) + + def test_len_with_custom_indices(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs, start=10, end=80) + self.assertEqual(len(dis), 70) + + def test_repr(self): + ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + dis = DisjointIntervalSequence(ivs, coord_id="ENST0001", interval_id="IV1") + r = repr(dis) + self.assertIn("DisjointIntervalSequence(", r) + self.assertIn("coord_id='ENST0001'", r) + self.assertIn("id='IV1'", r) + self.assertIn("chr1:+", r) + self.assertIn("len=200", r) + self.assertIn('coord_intervals=(Interval("chr1", "+", 100, 200, "hg19"), Interval("chr1", "+", 300, 400, "hg19"))', r) + self.assertIn("start=0", r) + self.assertIn("end=200", r) + self.assertIn("end5=0", r) + self.assertIn("end3=200", r) + + def test_eq_same(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + a = DisjointIntervalSequence(ivs, coord_id="x", interval_id="i") + b = DisjointIntervalSequence(ivs, coord_id="x", interval_id="i") + self.assertEqual(a, b) + + def test_eq_different_coord_id(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + a = DisjointIntervalSequence(ivs, coord_id="x") + b = DisjointIntervalSequence(ivs, coord_id="y") + self.assertNotEqual(a, b) + + def test_eq_different_interval_id(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + a = DisjointIntervalSequence(ivs, interval_id="x") + b = DisjointIntervalSequence(ivs, interval_id="y") + self.assertNotEqual(a, b) + + def test_eq_different_on_coordinate_strand(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + a = DisjointIntervalSequence(ivs, on_coordinate_strand=True) + b = DisjointIntervalSequence(ivs, on_coordinate_strand=False) + self.assertNotEqual(a, b) + + def test_eq_different_chrom(self): + a = DisjointIntervalSequence(_make_intervals([("chr1", "+", 100, 200)])) + b = DisjointIntervalSequence(_make_intervals([("chr2", "+", 100, 200)])) + self.assertNotEqual(a, b) + + def test_eq_different_refg(self): + a = DisjointIntervalSequence([Interval("chr1", "+", 100, 200, "hg19")]) + b = DisjointIntervalSequence([Interval("chr1", "+", 100, 200, "hg38")]) + self.assertNotEqual(a, b) + + def test_eq_different_strand(self): + a = DisjointIntervalSequence(_make_intervals([("chr1", "+", 100, 200)])) + b = DisjointIntervalSequence(_make_intervals([("chr1", "-", 100, 200)])) + self.assertNotEqual(a, b) + + def test_eq_different_coordinate_intervals(self): + a = DisjointIntervalSequence(_make_intervals([("chr1", "+", 100, 200)])) + b = DisjointIntervalSequence(_make_intervals([("chr1", "+", 100, 300)])) + self.assertNotEqual(a, b) + + def test_eq_different_start(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + a = DisjointIntervalSequence(ivs, start=0, end=50) + b = DisjointIntervalSequence(ivs, start=10, end=50) + self.assertNotEqual(a, b) + + def test_eq_different_end(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + a = DisjointIntervalSequence(ivs, start=0, end=50) + b = DisjointIntervalSequence(ivs, start=0, end=60) + self.assertNotEqual(a, b) + + def test_eq_non_dis(self): + ivs = _make_intervals([("chr1", "+", 100, 200)]) + dis = DisjointIntervalSequence(ivs) + self.assertNotEqual(dis, "not a DIS") + self.assertNotEqual(dis, 42) + + +# Helper for shift/expand/relational tests +# 2 exons on chr1+: [100,200) and [300,400), coordinate_length=200 +_COORD_IVS = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 300, 400)]) + + +def _dis( + start=0, end=200, on_coordinate_strand=True, coord_id="c", interval_id="i", ivs=None +): + """Quick DIS factory for tests.""" + return DisjointIntervalSequence( + ivs or _COORD_IVS, + coord_id=coord_id, + interval_id=interval_id, + on_coordinate_strand=on_coordinate_strand, + start=start, + end=end, + ) + + +class TestShift(unittest.TestCase): + + def test_shift_positive(self): + dis = _dis(start=30, end=150) + shifted = dis.shift(10) + self.assertEqual(shifted.start, 40) + self.assertEqual(shifted.end, 160) + + def test_shift_negative(self): + dis = _dis(start=30, end=150) + shifted = dis.shift(-10) + self.assertEqual(shifted.start, 20) + self.assertEqual(shifted.end, 140) + + def test_shift_zero(self): + dis = _dis(start=30, end=150) + shifted = dis.shift(0) + self.assertEqual(shifted, dis) + + def test_shift_beyond_coordinate(self): + dis = _dis(start=30, end=150) + shifted = dis.shift(60) + self.assertEqual(shifted.start, 90) + self.assertEqual(shifted.end, 210) + + def test_shift_negative_beyond(self): + dis = _dis(start=30, end=150) + shifted = dis.shift(-40) + self.assertEqual(shifted.start, -10) + self.assertEqual(shifted.end, 110) + + def test_shift_zero_length(self): + dis = _dis(start=50, end=50) + shifted = dis.shift(5) + self.assertEqual(shifted.start, 55) + self.assertEqual(shifted.end, 55) + self.assertEqual(shifted.length, 0) + + def test_shift_opposite_strand(self): + # on_coordinate_strand=False: upstream_step=+1, downstream=-1 + # shift(10) downstream → subtract 10 from both + dis = _dis(start=30, end=150, on_coordinate_strand=False) + shifted = dis.shift(10) + self.assertEqual(shifted.start, 20) + self.assertEqual(shifted.end, 140) + + def test_shift_preserves_metadata(self): + dis = _dis(start=30, end=150, coord_id="mycoord", interval_id="myiv") + shifted = dis.shift(10) + self.assertEqual(shifted.coord_id, "mycoord") + self.assertEqual(shifted.id, "myiv") + self.assertTrue(shifted.on_coordinate_strand) + + def test_shift_preserves_coordinate_intervals(self): + dis = _dis(start=30, end=150) + shifted = dis.shift(10) + self.assertEqual(shifted.coordinate_intervals, dis.coordinate_intervals) + + +class TestExpand(unittest.TestCase): + + def test_expand_symmetric(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(5) + self.assertEqual(expanded.start, 25) + self.assertEqual(expanded.end, 155) + + def test_expand_asymmetric(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(5, 10) + self.assertEqual(expanded.start, 25) + self.assertEqual(expanded.end, 160) + + def test_expand_upstream_only(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(5, 0) + self.assertEqual(expanded.start, 25) + self.assertEqual(expanded.end, 150) + + def test_expand_downstream_only(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(0, 10) + self.assertEqual(expanded.start, 30) + self.assertEqual(expanded.end, 160) + + def test_expand_zero(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(0) + self.assertEqual(expanded, dis) + + def test_expand_negative_contracts(self): + dis = _dis(start=30, end=150) + contracted = dis.expand(-5, -10) + self.assertEqual(contracted.start, 35) + self.assertEqual(contracted.end, 140) + + def test_expand_contract_to_zero_length(self): + dis = _dis(start=30, end=150) # length=120 + contracted = dis.expand(-60, -60) + self.assertEqual(contracted.start, 90) + self.assertEqual(contracted.end, 90) + self.assertEqual(contracted.length, 0) + + def test_expand_over_contraction_raises(self): + dis = _dis(start=30, end=150) # length=120 + with self.assertRaises(ValueError): + dis.expand(-70, -70) + + def test_expand_opposite_strand(self): + # on_coordinate_strand=False: upstream_step=+1 + # end5=150, end3=30. expand(5): end5 moves to 155, end3 moves to 25 + # start=min(155,25)=25, end=max(155,25)=155 + dis = _dis(start=30, end=150, on_coordinate_strand=False) + expanded = dis.expand(5) + self.assertEqual(expanded.start, 25) + self.assertEqual(expanded.end, 155) + + def test_expand_zero_length_interval(self): + dis = _dis(start=50, end=50) + expanded = dis.expand(5) + self.assertEqual(expanded.start, 45) + self.assertEqual(expanded.end, 55) + self.assertEqual(expanded.length, 10) + + def test_expand_beyond_coordinate(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(50, 0) + self.assertEqual(expanded.start, -20) + + def test_expand_preserves_metadata(self): + dis = _dis(start=30, end=150, coord_id="c", interval_id="i") + expanded = dis.expand(5) + self.assertEqual(expanded.coord_id, "c") + self.assertEqual(expanded.id, "i") + self.assertTrue(expanded.on_coordinate_strand) + + def test_expand_preserves_coordinate_intervals(self): + dis = _dis(start=30, end=150) + expanded = dis.expand(5) + self.assertEqual(expanded.coordinate_intervals, dis.coordinate_intervals) + + +class TestUpstreamOf(unittest.TestCase): + + def test_upstream_of_true(self): + a = _dis(start=10, end=30) + b = _dis(start=50, end=80) + self.assertTrue(a.upstream_of(b)) + + def test_upstream_of_false_overlap(self): + a = _dis(start=10, end=60) + b = _dis(start=50, end=80) + self.assertFalse(a.upstream_of(b)) + + def test_upstream_of_adjacent(self): + a = _dis(start=10, end=50) + b = _dis(start=50, end=80) + self.assertTrue(a.upstream_of(b)) + + def test_upstream_of_same_false(self): + a = _dis(start=30, end=50) + self.assertFalse(a.upstream_of(a)) + + def test_upstream_of_zero_length(self): + a = _dis(start=30, end=30) + b = _dis(start=50, end=80) + self.assertTrue(a.upstream_of(b)) + + def test_upstream_of_zero_length_same_pos(self): + a = _dis(start=50, end=50) + b = _dis(start=50, end=80) + self.assertTrue(a.upstream_of(b)) + + def test_upstream_of_both_zero_length_same_pos(self): + a = _dis(start=50, end=50) + b = _dis(start=50, end=50) + self.assertFalse(a.upstream_of(b)) + + def test_upstream_of_opposite_strand(self): + # on_coordinate_strand=False: upstream_step=+1, upstream = higher indices + # a.start(100) >= b.end(80) → True + a = _dis(start=100, end=150, on_coordinate_strand=False) + b = _dis(start=50, end=80, on_coordinate_strand=False) + self.assertTrue(a.upstream_of(b)) + + def test_different_coord_space_raises(self): + a = _dis(start=10, end=30) + other_ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 500, 600)]) + b = _dis(start=10, end=30, ivs=other_ivs) + with self.assertRaises(ValueError): + a.upstream_of(b) + + def test_different_coord_id_allowed(self): + a = _dis(start=10, end=30, coord_id="a") + b = _dis(start=50, end=80, coord_id="b") + self.assertTrue(a.upstream_of(b)) + + def test_different_on_coord_strand_raises(self): + a = _dis(start=10, end=30, on_coordinate_strand=True) + b = _dis(start=50, end=80, on_coordinate_strand=False) + with self.assertRaises(ValueError): + a.upstream_of(b) + + def test_non_dis_raises(self): + a = _dis(start=10, end=30) + with self.assertRaises(TypeError): + a.upstream_of("not a DIS") + + +class TestDnstreamOf(unittest.TestCase): + + def test_dnstream_of_true(self): + a = _dis(start=50, end=80) + b = _dis(start=10, end=30) + self.assertTrue(a.dnstream_of(b)) + + def test_dnstream_of_false(self): + a = _dis(start=10, end=30) + b = _dis(start=50, end=80) + self.assertFalse(a.dnstream_of(b)) + + def test_dnstream_of_adjacent(self): + a = _dis(start=50, end=80) + b = _dis(start=10, end=50) + self.assertTrue(a.dnstream_of(b)) + + def test_dnstream_of_same_false(self): + a = _dis(start=30, end=50) + self.assertFalse(a.dnstream_of(a)) + + def test_dnstream_of_both_zero_length_same_pos(self): + a = _dis(start=50, end=50) + b = _dis(start=50, end=50) + self.assertFalse(a.dnstream_of(b)) + + def test_dnstream_of_opposite_strand(self): + # on_coordinate_strand=False: upstream_step=+1 + # downstream = lower indices. a.end(80) <= b.start(100) → True + a = _dis(start=50, end=80, on_coordinate_strand=False) + b = _dis(start=100, end=150, on_coordinate_strand=False) + self.assertTrue(a.dnstream_of(b)) + + def test_different_coord_space_raises(self): + a = _dis(start=50, end=80) + other_ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 500, 600)]) + b = _dis(start=10, end=30, ivs=other_ivs) + with self.assertRaises(ValueError): + a.dnstream_of(b) + + def test_different_on_coord_strand_raises(self): + a = _dis(start=50, end=80, on_coordinate_strand=True) + b = _dis(start=30, end=60, on_coordinate_strand=False) + with self.assertRaises(ValueError): + a.dnstream_of(b) + + +class TestWithin(unittest.TestCase): + + def test_within_true(self): + a = _dis(start=30, end=50) + b = _dis(start=10, end=80) + self.assertTrue(a.within(b)) + + def test_within_false(self): + a = _dis(start=10, end=80) + b = _dis(start=30, end=50) + self.assertFalse(a.within(b)) + + def test_within_self(self): + a = _dis(start=30, end=50) + self.assertTrue(a.within(a)) + + def test_within_zero_length(self): + a = _dis(start=50, end=50) + b = _dis(start=10, end=80) + self.assertTrue(a.within(b)) + + def test_within_at_boundary(self): + a = _dis(start=10, end=80) + b = _dis(start=10, end=80) + self.assertTrue(a.within(b)) + + def test_within_zero_length_at_boundary(self): + a = _dis(start=10, end=10) + b = _dis(start=10, end=80) + self.assertTrue(a.within(b)) + + def test_within_zero_length_outside(self): + a = _dis(start=5, end=5) + b = _dis(start=10, end=80) + self.assertFalse(a.within(b)) + + def test_within_opposite_strand(self): + a = _dis(start=80, end=120, on_coordinate_strand=False) + b = _dis(start=50, end=150, on_coordinate_strand=False) + self.assertTrue(a.within(b)) + + def test_different_coord_space_raises(self): + a = _dis(start=30, end=50) + other_ivs = _make_intervals([("chr1", "+", 100, 200), ("chr1", "+", 500, 600)]) + b = _dis(start=10, end=80, ivs=other_ivs) + with self.assertRaises(ValueError): + a.within(b) + + def test_different_on_coord_strand_raises(self): + a = _dis(start=30, end=50, on_coordinate_strand=True) + b = _dis(start=10, end=80, on_coordinate_strand=False) + with self.assertRaises(ValueError): + a.within(b) + + def test_non_dis_raises(self): + a = _dis(start=30, end=50) + with self.assertRaises(TypeError): + a.within("not a DIS") + + +if __name__ == "__main__": + unittest.main()