diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 08efdb08..588a1ae8 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -40,6 +40,8 @@ jobs: matrix: python-version: ['3.10', '3.11', '3.12'] platform: ["linux-64", "osx-arm64", "osx-64"] + extras: ["none", "df"] + rosetta: ["false"] include: # specify additional fields for all configs - python-version: "3.10" pyver-short: "310" @@ -53,6 +55,13 @@ jobs: runs-on: "macos-26-intel" - platform: "linux-64" runs-on: "ubuntu-latest" + # single test for running intel build on arm64 macos runner with rosetta + - python-version: "3.10" + pyver-short: "310" + platform: "osx-64" # intel build + runs-on: "macos-latest" # defaults to arm64 runner with M1 chip + extras: "df" + rosetta: "true" runs-on: ${{ matrix.runs-on }} steps: - uses: actions/checkout@v4 @@ -100,6 +109,9 @@ jobs: - name: run unittests id: run_unittests shell: bash -l -e {0} + env: + GK_EXTRAS: ${{ matrix.extras }} + ROSETTA: ${{ matrix.rosetta }} run: | set -x micromamba activate test @@ -108,7 +120,15 @@ jobs: if [ ! -e "${files[0]}" ]; then echo "No files matched for py${{ matrix.pyver-short }}" exit 1 - fi - conda mambabuild --croot /tmp/conda-bld -t $files --extra-deps python=${{ matrix.python-version }} + fi + extra_deps=(python=${{ matrix.python-version }}) + # run command with optional extra deps + if [ "$GK_EXTRAS" = "df" ]; then + extra_deps+=("polars=1.39.3") + fi + if [ "$ROSETTA" = "true" ]; then + extra_deps+=("polars-runtime-compat=1.39.3") + fi + conda mambabuild --croot /tmp/conda-bld -t "${files[@]}" --extra-deps "${extra_deps[@]}" conda clean -it - set +x + set +x \ No newline at end of file diff --git a/docs-src/df.rst b/docs-src/df.rst new file mode 100644 index 00000000..d7fb90dd --- /dev/null +++ b/docs-src/df.rst @@ -0,0 +1,67 @@ +.. 
_df: + +DataFrame Utilities +=================== + +The :py:mod:`genome_kit.df` subpackage contains utilities for working with Polars DataFrames that contain GenomeKit objects. This includes utilities for serializing DataFrames with GenomeKit objects to Parquet and deserializing them back to GenomeKit objects. This is useful when sharing tabular data sets, or when saving intermediate DataFrames to disk during data processing. + +.. important:: + + ``genome_kit.df`` depends on optional ``polars`` dependencies, which are not installed by default. These can be installed with the ``[df]`` extra: + + .. code-block:: bash + + pip install "genomekit[df]" + + The ``[df]`` extra is not included in the default ``genomekit`` installation. + + If you are running an x86 version of Python on an Apple Silicon Mac (e.g. M1 chip), this will also install the ``polars-runtime-compat`` package, which is required to run Polars on Apple Silicon due to AVX features compatibility issues. + + +Quickstart +----------- +The serialization and deserialization entry points are :py:func:`~genome_kit.df.read_parquet` and :py:func:`~genome_kit.df.write_parquet`: + +.. code-block:: python + + import polars as pl + import genome_kit as gk + + genome = gk.Genome("ncbi_refseq.v110") + df = pl.DataFrame( + { + "gene": [genome.genes[0], genome.genes[1]], + "score": [0.1, 0.8], + } + ) + + gk.write_parquet(df, "genes.parquet") + ... + ... + restored_df = gk.read_parquet("genes.parquet") + + +.. note:: + + The written parquet files can be read by any software that supports the parquet format, but the GenomeKit objects will only be restored when read with :py:func:`genome_kit.df.read_parquet`. 
def require_polars():
    """Import Polars if available, otherwise raise a helpful ImportError.

    Also checks for compatibility on MacOS with Apple Silicon, which requires the
    additional ``polars-runtime-compat`` package when Python runs under Rosetta
    translation.

    Returns:
        The imported ``polars`` module.

    Raises:
        ImportError: If Polars is not installed, or if Python is running under
            Rosetta translation and ``polars-runtime-compat`` is not installed.
    """
    # keep the try body limited to the import so only a missing module is handled
    try:
        import polars as pl
    except ModuleNotFoundError as e:
        raise ImportError(
            "Optional dependency 'polars' is required for this functionality. Please "
            "install with `pip install genomekit[df]`.\n"
            "If you are running an x86_64 build of Python on MacOS with Apple "
            "Silicon, the `df` extra also installs the polars-runtime-compat "
            "package required for Rosetta translation."
        ) from e

    if check_under_rosetta() and not check_rtcompat():
        raise ImportError(
            "Polars is not compatible with Rosetta translation on Apple Silicon "
            "without the polars-runtime-compat package.\n"
            "Please install polars-runtime-compat, e.g. by reinstalling with "
            "`pip install genomekit[df]`."
        )

    return pl


def check_under_rosetta():
    """Check if program is running under Rosetta translation on Apple Silicon.

    The default version of Polars is incompatible with Rosetta, and requires
    polars-runtime-compat to be installed.

    Can be checked with the sysctl.proc_translated flag in sysctl.
    See https://developer.apple.com/documentation/apple-silicon/about-the-rosetta-translation-environment#Determine-Whether-Your-App-Is-Running-as-a-Translated-Binary

    Returns:
        True if ``sysctl`` reports the process as translated, False otherwise
        (including on non-macOS platforms and when the flag cannot be read).
    """
    import subprocess
    import sys

    # Rosetta only exists on macOS, so avoid spawning a process everywhere else
    if sys.platform != "darwin":
        return False

    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # sysctl.proc_translated won't exist on non-Apple Silicon machines
        return False

    # output will be 0 if running natively on Apple Silicon, and 1 if running under
    # Rosetta translation
    return result.stdout.strip() == "1"


def check_rtcompat():
    """Check if polars-runtime-compat is installed.

    Required for Polars to run on MacOS machines under Rosetta translation.

    Returns:
        True if package metadata for ``polars-runtime-compat`` can be found,
        False otherwise.
    """
    try:
        version("polars-runtime-compat")
    except PackageNotFoundError:
        return False
    return True
def get_structs() -> dict[GkDfType, pl.Struct]:
    """Return a mapping of GkDfType to their corresponding Polars Struct definitions.

    Every struct starts with a ``schema_version`` string field. Annotation-table
    element types (transcript, gene, exon, intron, CDS, UTR) store the index of
    the element within its table plus the name of the annotation genome.

    Returns:
        A dictionary with one Polars Struct dtype per GkDfType member.
    """
    pl = require_polars()

    def _table_struct(index_field: str, *extra_fields: pl.Field) -> pl.Struct:
        """Build the struct layout shared by all annotation-table element types."""
        return pl.Struct(
            [
                pl.Field("schema_version", pl.Utf8),
                *extra_fields,
                # index of the element within its annotation genome table
                # Int32 matches index type in C++ backend (see src/table.h:22)
                pl.Field(index_field, pl.Int32),
                pl.Field("anno", pl.Utf8),  # annotation genome
            ]
        )

    GenomeStruct = pl.Struct(
        [
            pl.Field("schema_version", pl.Utf8),
            pl.Field("genome_name", pl.Utf8),  # reference or annotation genome
        ]
    )

    IntervalStruct = pl.Struct(
        [
            pl.Field("schema_version", pl.Utf8),
            pl.Field("chromosome", pl.Utf8),
            pl.Field("strand", pl.Utf8),
            pl.Field("start", pl.Int32),
            pl.Field("end", pl.Int32),
            pl.Field("refg", pl.Utf8),  # reference genome
        ]
    )

    # a UTR lives in either the 5' or the 3' table, so the table is recorded too
    UtrType = pl.Enum(["5prime", "3prime"])

    return {
        GkDfType.GENOME: GenomeStruct,
        GkDfType.INTERVAL: IntervalStruct,
        GkDfType.TRANSCRIPT: _table_struct("transcript_table_index"),
        GkDfType.GENE: _table_struct("gene_table_index"),
        GkDfType.EXON: _table_struct("exon_table_index"),
        GkDfType.INTRON: _table_struct("intron_table_index"),
        GkDfType.CDS: _table_struct("cds_table_index"),
        # the UTR index uses Int32 like every other table index
        GkDfType.UTR: _table_struct("utr_table_index", pl.Field("utr_type", UtrType)),
    }
+ """ + pl = require_polars() + gkdf_structs = get_structs() + + def _serialize_genome(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Genome objects by genome name.""" + return pl.Series( + name=s.name, + values=[ + ( + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + # config gives annotation genome name if applicable + "genome_name": genome.config, + } + if genome is not None + else None + ) + for genome in s + ], + dtype=gkdf_structs[GkDfType.GENOME], + ) + + def _deserialize_genome(s: pl.Series) -> pl.Series: + """Deserialize a Series of GenomeStruct back into GenomeKit Genome objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Genome(struct["genome_name"]) if struct is not None else None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_interval(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Interval objects.""" + return pl.Series( + name=s.name, + values=[ + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + "chromosome": interval.chromosome, + "strand": interval.strand, + "start": interval.start, + "end": interval.end, + # intervals related to reference genome only + "refg": interval.reference_genome, + } + if interval is not None + else None + for interval in s + ], + dtype=gkdf_structs[GkDfType.INTERVAL], + ) + + def _deserialize_interval(s: pl.Series) -> pl.Series: + """Deserialize a Series of IntervalStruct back into GenomeKit Interval objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Interval( + chromosome=struct["chromosome"], + strand=struct["strand"], + start=struct["start"], + end=struct["end"], + reference_genome=struct["refg"], + ) + if struct is not None + else None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_transcript(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Transcript objects.""" + return pl.Series( + name=s.name, + values=[ + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + "transcript_table_index": 
transcript.annotation_genome.transcripts.index_of( + transcript + ), + "anno": transcript.annotation_genome.config, + } + if transcript is not None + else None + for transcript in s + ], + dtype=gkdf_structs[GkDfType.TRANSCRIPT], + ) + + def _deserialize_transcript(s: pl.Series) -> pl.Series: + """Deserialize a Series of TranscriptStruct back into GenomeKit Transcript objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Genome(struct["anno"]).transcripts[struct["transcript_table_index"]] + if struct is not None + else None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_gene(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Gene objects.""" + return pl.Series( + name=s.name, + values=[ + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + "gene_table_index": gene.annotation_genome.genes.index_of(gene), + "anno": gene.annotation_genome.config, + } + if gene is not None + else None + for gene in s + ], + dtype=gkdf_structs[GkDfType.GENE], + ) + + def _deserialize_gene(s: pl.Series) -> pl.Series: + """Deserialize a Series of GeneStruct back into GenomeKit Gene objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Genome(struct["anno"]).genes[struct["gene_table_index"]] + if struct is not None + else None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_exon(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Exon objects.""" + return pl.Series( + name=s.name, + values=[ + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + "exon_table_index": exon.annotation_genome.exons.index_of(exon), + "anno": exon.annotation_genome.config, + } + if exon is not None + else None + for exon in s + ], + dtype=gkdf_structs[GkDfType.EXON], + ) + + def _deserialize_exon(s: pl.Series) -> pl.Series: + """Deserialize a Series of ExonStruct back into GenomeKit Exon objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Genome(struct["anno"]).exons[struct["exon_table_index"]] + if struct is not None + else 
None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_intron(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Intron objects.""" + return pl.Series( + name=s.name, + values=[ + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + "intron_table_index": intron.annotation_genome.introns.index_of( + intron + ), + "anno": intron.annotation_genome.config, + } + if intron is not None + else None + for intron in s + ], + dtype=gkdf_structs[GkDfType.INTRON], + ) + + def _deserialize_intron(s: pl.Series) -> pl.Series: + """Deserialize a Series of IntronStruct back into GenomeKit Intron objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Genome(struct["anno"]).introns[struct["intron_table_index"]] + if struct is not None + else None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_cds(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Cds objects.""" + return pl.Series( + name=s.name, + values=[ + { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + "cds_table_index": cds.annotation_genome.cdss.index_of(cds), + "anno": cds.annotation_genome.config, + } + if cds is not None + else None + for cds in s + ], + dtype=gkdf_structs[GkDfType.CDS], + ) + + def _deserialize_cds(s: pl.Series) -> pl.Series: + """Deserialize a Series of CDSStruct back into GenomeKit Cds objects.""" + return pl.Series( + name=s.name, + values=[ + gk.Genome(struct["anno"]).cdss[struct["cds_table_index"]] + if struct is not None + else None + for struct in s + ], + dtype=pl.Object, + ) + + def _serialize_utr(s: pl.Series) -> pl.Series: + """Serialize a Series of GenomeKit Utr objects.""" + values = [] + for utr in s: + if utr is None: + values.append(None) + continue + ser_dict = { + _SCHEMA_VERSION_FIELD: GkDfVersion.V1.value, + } + genome = utr.annotation_genome + try: + ser_dict["utr_table_index"] = genome.utr5s.index_of(utr) + ser_dict["utr_type"] = "5prime" + except ValueError: + ser_dict["utr_table_index"] = genome.utr3s.index_of(utr) 
+ ser_dict["utr_type"] = "3prime" + + ser_dict["anno"] = genome.config + values.append(ser_dict) + + # pl.Series constructor doesn't strictly enforce nested struct types. Errors + # with "utr_type" field may silently manifest as null values. + # TODO: check polars implementation, related issue #18841 + return pl.Series( + name=s.name, + values=values, + dtype=gkdf_structs[GkDfType.UTR], + ) + + def _deserialize_utr(s: pl.Series) -> pl.Series: + """Deserialize a Series of UtrStruct back into GenomeKit Utr objects.""" + return pl.Series( + name=s.name, + values=[ + ( + gk.Genome(struct["anno"]).utr5s[struct["utr_table_index"]] + if struct["utr_type"] == "5prime" + else gk.Genome(struct["anno"]).utr3s[struct["utr_table_index"]] + ) + if struct is not None + else None + for struct in s + ], + dtype=pl.Object, + ) + + REGISTRY: dict[GkDfVersion, dict[GkDfType, GKTypeEntry]] = { + GkDfVersion.V1: { + GkDfType.GENOME: GKTypeEntry( + struct=gkdf_structs[GkDfType.GENOME], + serializer=_serialize_genome, + deserializer=_deserialize_genome, + ), + GkDfType.INTERVAL: GKTypeEntry( + struct=gkdf_structs[GkDfType.INTERVAL], + serializer=_serialize_interval, + deserializer=_deserialize_interval, + ), + GkDfType.TRANSCRIPT: GKTypeEntry( + struct=gkdf_structs[GkDfType.TRANSCRIPT], + serializer=_serialize_transcript, + deserializer=_deserialize_transcript, + ), + GkDfType.GENE: GKTypeEntry( + struct=gkdf_structs[GkDfType.GENE], + serializer=_serialize_gene, + deserializer=_deserialize_gene, + ), + GkDfType.EXON: GKTypeEntry( + struct=gkdf_structs[GkDfType.EXON], + serializer=_serialize_exon, + deserializer=_deserialize_exon, + ), + GkDfType.INTRON: GKTypeEntry( + struct=gkdf_structs[GkDfType.INTRON], + serializer=_serialize_intron, + deserializer=_deserialize_intron, + ), + GkDfType.CDS: GKTypeEntry( + struct=gkdf_structs[GkDfType.CDS], + serializer=_serialize_cds, + deserializer=_deserialize_cds, + ), + GkDfType.UTR: GKTypeEntry( + struct=gkdf_structs[GkDfType.UTR], + 
def _map_batches_safe(fn: Callable) -> Callable:
    """Helper function to wrap a UDF and run safely with polars map_batches.

    Polars has a bug in map_batches that incorrectly forwards the return_dtype argument
    to the UDF. See https://github.com/pola-rs/polars/issues/24840.

    Args:
        fn: The user defined function to wrap.

    Returns:
        A wrapped version of the UDF that can be safely used with map_batches.
    """
    accepted = signature(fn).parameters

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # drop every keyword the UDF does not declare (e.g. a forwarded return_dtype)
        supported = {name: value for name, value in kwargs.items() if name in accepted}
        return fn(*args, **supported)

    return wrapper


def _detect_gk_cols(
    lf: pl.LazyFrame, infer_schema_length: int = 100
) -> dict[str, ColumnInfo]:
    """Detect columns in the LazyFrame that contain GenomeKit objects.

    Args:
        lf: The LazyFrame to inspect.
        infer_schema_length: The number of rows to use for schema inference when
            detecting GenomeKit columns.

    Returns:
        A dictionary mapping column names to the ColumnInfo dataclass containing the
        GkDfType and CellType for the column. Columns whose type cannot be inferred
        from the sampled rows, or that hold no GenomeKit type, are left out.

    Raises:
        ValueError: If the sampled rows of a column mix list and non-list cells, or
            hold more than one element type.
    """
    pl = require_polars()

    lf_cols = lf.collect_schema().names()

    target_cols = {}

    # datatype inference done on first n=infer_schema_length rows. Follows inference
    # logic from Polars DataFrames when rows are provided.
    # see https://github.com/pola-rs/polars/blob/1cd236c60c01572c5ec6fdd252d8b20218d7b440/py-polars/src/polars/dataframe/frame.py#L248-L251
    head = lf.head(infer_schema_length).collect()

    for col in lf_cols:
        # remove nulls for type inference, list/scalar cols depend on first non-null value
        vals = head.get_column(col).drop_nulls()  # removes scalar nulls

        # column only contains null values in the first infer_schema_length rows
        if len(vals) == 0:
            warnings.warn(
                f"Column {col} contains only null values in the first {infer_schema_length} rows, "
                "unable to infer type for serialization. Please ensure this column "
                "contains non-null values for accurate serialization."
            )
            continue

        first = vals[0]
        head_types = {type(v) for v in vals}

        if isinstance(first, list):
            if head_types != {list}:
                raise ValueError(
                    f"Column {col} contains mixed data types: {list(itertools.islice(head_types, 3))}.\n"
                    "Please ensure all cells are the same type before serialization."
                )
            cell_type = CellType.LIST
            col_types = {type(item) for v in vals for item in v if item is not None}

            # every sampled list is empty or holds only nulls: there is nothing to
            # infer from, which is not the same situation as mixed element types
            if not col_types:
                warnings.warn(
                    f"Column {col} contains only empty lists in the first {infer_schema_length} rows, "
                    "unable to infer type for serialization. Please ensure this column "
                    "contains non-null values for accurate serialization."
                )
                continue
        else:
            cell_type = CellType.SCALAR
            col_types = set(vals.map_elements(type, return_dtype=pl.Object))

        if len(col_types) != 1:
            raise ValueError(
                f"Column {col} contains mixed data types: {list(itertools.islice(col_types, 3))}.\n"
                "Please ensure all cells are the same type before serialization."
            )

        col_type = GK_TO_GKDF_TYPE.get(col_types.pop(), None)

        if col_type is None:
            # column is not a genomekit type, so no serialization needed
            continue

        target_cols[col] = ColumnInfo(cell_type=cell_type, gkdf_type=col_type)

    return target_cols


def _list_serializer(
    serializer: Callable[[pl.Series], pl.Series], return_dtype: Any
) -> Callable[[pl.Series], pl.Series]:
    """Helper function to convert a serializer to accept lists of GenomeKit objects.

    Args:
        serializer: A serializer function for a series of GenomeKit objects
        return_dtype: The return data type for the serialized series

    Returns:
        A serializer function for a series of lists of GenomeKit objects.
    """
    pl = require_polars()

    def _serialize_list(s: pl.Series) -> pl.Series:
        # serialize each cell's list on its own; null cells stay null
        return pl.Series(
            name=s.name,
            values=[
                serializer(pl.Series(values=cell)).to_list()
                if cell is not None
                else None
                for cell in s
            ],
            dtype=return_dtype,
        )

    return _serialize_list
def _validate_gkdf_metadata(metadata: dict[str, str]) -> None:
    """Validate the parquet metadata for a gkdf parquet file.

    Args:
        metadata: The parquet metadata to validate.

    Raises:
        ValueError: If ``gkdf_version`` is missing, unknown, or differs from the
            current version, or if ``target_cols`` or ``gk_version`` is missing.

    Warns:
        UserWarning: If the file was written with a different GenomeKit version.
    """
    # gkdf version
    metadata_version = metadata.get("gkdf_version")
    try:
        version = (
            GkDfVersion(metadata_version) if metadata_version is not None else None
        )
    except ValueError:
        # unknown version string: fall through to the descriptive error below
        # instead of surfacing the bare enum lookup failure
        version = None

    if version != CURRENT_VERSION:
        raise ValueError(
            f"Invalid or missing gkdf_version in Parquet metadata, unable to deserialize GenomeKit objects. "
            f"Expected GkDfVersion {CURRENT_VERSION}, but found {metadata_version}."
        )

    # target cols
    if metadata.get("target_cols") is None:
        raise ValueError(
            "Missing target_cols in Parquet metadata, unable to deserialize GenomeKit objects."
        )

    # gk version
    gk_version = metadata.get("gk_version")
    if gk_version is None:
        raise ValueError("Missing gk_version in Parquet metadata.")
    elif gk_version != gk.__version__:
        warnings.warn(
            f"Parquet file was written with GenomeKit version {gk_version}, but current version is {gk.__version__}. "
            "Deserializing GenomeKit objects may not be consistent across versions."
        )


def _list_deserializer(
    deserializer: Callable[[pl.Series], pl.Series],
) -> Callable[[pl.Series], pl.Series]:
    """Helper function to convert a deserializer to accept lists of serialized GenomeKit objects.

    Args:
        deserializer: A deserializer function for a series of serialized GenomeKit objects

    Returns:
        A deserializer function for a series of lists of serialized GenomeKit objects.
    """
    pl = require_polars()

    def _deserialize_list(s: pl.Series) -> pl.Series:
        # deserialize each cell's list on its own; null cells stay null
        return pl.Series(
            name=s.name,
            values=[
                deserializer(pl.Series(values=cell)).to_list()
                if cell is not None
                else None
                for cell in s
            ],
            dtype=pl.Object,
        )

    return _deserialize_list
# TODO: add union of pd.DataFrame
def write_parquet(
    df: pl.DataFrame | pl.LazyFrame, path: str | Path, infer_schema_length: int = 100
) -> None:
    """Write a DataFrame holding GenomeKit objects to a Parquet file.

    GenomeKit columns are detected from the first ``infer_schema_length`` rows and
    converted to struct columns. The gkdf version, the GenomeKit version and the
    detected column information are stored in the Parquet metadata.

    Args:
        df: A Polars DataFrame or LazyFrame with columns containing GenomeKit objects.
        path: The file path to write the Parquet file to.
        infer_schema_length: The number of rows to use for schema inference when writing the Parquet file.
    """
    pl = require_polars()

    out_path = Path(path)
    lf = df.lazy() if isinstance(df, pl.DataFrame) else df

    # mapping from column name to ColumnInfo dataclass
    target_cols = _detect_gk_cols(lf, infer_schema_length=infer_schema_length)

    if not target_cols:
        warnings.warn(
            "No GenomeKit columns detected for serialization, writing DataFrame as is."
        )
        lf.sink_parquet(out_path)
        return

    entries = get_registry()[CURRENT_VERSION]

    def _build_serialization_expr(col: str) -> pl.Expr:
        info = target_cols[col]  # ColumnInfo dataclass
        entry = entries[info.gkdf_type]
        if info.cell_type == CellType.LIST:
            return_dtype = pl.List(inner=entry.struct)
            serializer = _list_serializer(entry.serializer, return_dtype=return_dtype)
        else:
            return_dtype = entry.struct
            serializer = entry.serializer

        wrapped = _map_batches_safe(serializer)
        return pl.col(col).map_batches(wrapped, return_dtype=return_dtype).alias(col)

    serialized = lf.with_columns(
        [_build_serialization_expr(col) for col in target_cols]
    )

    # convert ColumnInfo dataclass to a serializable format
    target_col_metadata = {col: info.to_dict() for col, info in target_cols.items()}

    serialized.sink_parquet(
        out_path,
        metadata={
            "gkdf_version": CURRENT_VERSION.value,
            "gk_version": gk.__version__,
            "target_cols": json.dumps(target_col_metadata),
        },
    )


def read_parquet(path: str | Path, lazy: bool = False) -> pl.DataFrame | pl.LazyFrame:
    """Read a Parquet file written by ``write_parquet`` and restore GenomeKit objects.

    Args:
        path: The file path to read the Parquet file from.
        lazy: If True, return a LazyFrame. Otherwise, return a DataFrame.

    Returns:
        A Polars DataFrame or LazyFrame with deserialized GenomeKit objects.
    """
    pl = require_polars()

    src = Path(path)
    metadata = pl.read_parquet_metadata(src)
    _validate_gkdf_metadata(metadata)  # guarantees target_cols is present
    target_cols = json.loads(metadata["target_cols"])

    lf = pl.scan_parquet(src)

    # collect unique genome strings in the file and initialize, prevents race conditions
    # on opening dganno files.
    # the returned list is bound to a local only to keep the references alive until
    # this function returns
    # NOTE(review): with lazy=True the list is released before the caller collects
    # the LazyFrame; confirm whether that still prevents the race.
    _keep_alive = _init_gk_annotations(lf, target_cols)

    lf = _deserialize_gk_cols(lf, target_cols)

    if lazy:
        return lf
    return lf.collect()
from . import MiniGenome

HAS_POLARS = importlib.util.find_spec("polars") is not None
if HAS_POLARS:
    import polars as pl


@unittest.skipUnless(HAS_POLARS, "Polars is required for this genome_kit.df tests")
class TestGkdfRoundTrip(unittest.TestCase):
    """Round-trip tests for ``genome_kit.df.write_parquet`` / ``read_parquet``.

    Each test builds a small Polars DataFrame, writes it into a shared
    temporary directory, reads it back eagerly and compares the result with
    the input. The whole class is skipped when Polars is not installed.
    """

    # Annotation builds that every per-element round-trip test is run against.
    ANNOTATIONS = ("gencode.v41", "ucsc_refseq.2017-06-25")

    @classmethod
    def setUpClass(cls):
        """Create one temporary directory shared by all tests in the class."""
        cls.tmp_dir = tempfile.TemporaryDirectory()
        cls.addClassCleanup(cls.tmp_dir.cleanup)
        cls.tmp_dir_path = Path(cls.tmp_dir.name)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _round_trip(self, df, filename):
        """Write ``df`` to ``filename`` in the temp dir and read it back eagerly."""
        path = self.tmp_dir_path / filename
        write_parquet(df, path)
        return read_parquet(path, lazy=False)

    def _assert_item_round_trip(self, df, filename):
        """Assert that a single-cell ``df`` survives a write/read cycle."""
        re_df = self._round_trip(df, filename)
        self.assertEqual(re_df.item(), df.item())

    def _check_per_annotation(self, column, pick):
        """Round-trip ``pick(genome)`` in a column named ``column`` for each annotation.

        Every annotation runs in its own subTest so that a failure for one
        build does not hide the outcome for the others.
        """
        for genome_str in self.ANNOTATIONS:
            with self.subTest(genome=genome_str):
                df = pl.DataFrame({column: [pick(Genome(genome_str))]})
                self._assert_item_round_trip(df, f"{genome_str}_{column}.parquet")

    def _assert_read_raises_value_error(self, filename, metadata):
        """Write a plain Parquet file with ``metadata`` and expect read_parquet to reject it."""
        df = pl.DataFrame({"genome": ["hg38.p12"]})
        path = self.tmp_dir_path / filename
        df.write_parquet(path, metadata=metadata)
        with self.assertRaises(ValueError):
            read_parquet(path, lazy=False)

    # ------------------------------------------------------------------
    # scalar round trips
    # ------------------------------------------------------------------
    @unittest.skip("MiniGenome and Genome type mismatch")
    def test_genome(self):
        # plain reference genome as well as gencode and refseq annotations
        for genome_str in ("hg38.p12", "gencode.v41", "ucsc_refseq.2017-06-25"):
            with self.subTest(genome=genome_str):
                df = pl.DataFrame({"genome": [MiniGenome(genome_str)]})
                self._assert_item_round_trip(df, f"{genome_str}.parquet")

    def test_interval(self):
        interval = Interval("chr5", "+", 2000, 3000, "hg19")
        df = pl.DataFrame({"interval": [interval]})
        self._assert_item_round_trip(df, "interval.parquet")

    def test_transcript(self):
        self._check_per_annotation("transcript", lambda g: g.genes[0].transcripts[0])

    def test_gene(self):
        self._check_per_annotation("gene", lambda g: g.genes[0])

    def test_exon(self):
        self._check_per_annotation("exon", lambda g: g.exons[0])

    def test_intron(self):
        self._check_per_annotation("intron", lambda g: g.introns[0])

    def test_cds(self):
        self._check_per_annotation("cds", lambda g: g.cdss[0])

    def test_utr3(self):
        self._check_per_annotation("utr3", lambda g: g.utr3s[0])

    def test_utr5(self):
        self._check_per_annotation("utr5", lambda g: g.utr5s[0])

    # ------------------------------------------------------------------
    # list-valued cells
    # ------------------------------------------------------------------
    def test_list_of_intervals(self):
        intervals = [
            Interval("chr1", "+", 2000, 3000, "hg19"),
            Interval("chr4", "-", 5000, 6000, "hg19"),
        ]
        df = pl.DataFrame({"intervals": [intervals]}, schema={"intervals": pl.Object})
        self._assert_item_round_trip(df, "list_of_intervals.parquet")

    def test_list_of_genomes(self):
        genomes = [Genome("hg38.p12"), Genome("gencode.v41")]
        df = pl.DataFrame({"genomes": [genomes]}, schema={"genomes": pl.Object})
        self._assert_item_round_trip(df, "list_of_genomes.parquet")

    def test_list_of_transcripts(self):
        g = Genome("gencode.v41")
        transcripts = list(g.transcripts)[:10]
        df = pl.DataFrame(
            {"transcripts": [transcripts]}, schema={"transcripts": pl.Object}
        )
        self._assert_item_round_trip(df, "list_of_transcripts.parquet")

    def test_list_of_gk_with_null(self):
        g = Genome("gencode.v41")
        transcripts = list(g.transcripts)[:10]
        transcripts[:3] = [None] * 3
        df = pl.DataFrame(
            {"transcripts": [transcripts]}, schema={"transcripts": pl.Object}
        )
        self._assert_item_round_trip(df, "list_of_transcripts_with_null.parquet")

    # ------------------------------------------------------------------
    # multi-column / multi-genome frames
    # ------------------------------------------------------------------
    def test_multiple_types(self):
        g = Genome("gencode.v41")
        df = pl.DataFrame(
            {
                "interval": [Interval("chr5", "+", 2000, 3000, "hg19")],
                "transcript": [g.genes[0].transcripts[0]],
                "gene": [g.genes[0]],
                "exon": [g.exons[0]],
            }
        )

        re_df = self._round_trip(df, "multiple_types.parquet")
        for column in ("interval", "transcript", "gene", "exon"):
            with self.subTest(column=column):
                self.assertEqual(re_df[column].item(), df[column].item())

    def test_multiple_genomes(self):
        # test dataframe with multiple reference genomes in a single column
        g1 = Genome("gencode.v41")
        g2 = Genome("ucsc_refseq.2017-06-25")

        genes = [g1.genes[0], g2.genes[0]]
        df = pl.DataFrame({"genes": genes}, schema={"genes": pl.Object})

        re_df = self._round_trip(df, "multiple_genomes.parquet")
        for row in range(len(genes)):
            with self.subTest(row=row):
                self.assertEqual(re_df["genes"][row], df["genes"][row])

    # ------------------------------------------------------------------
    # write-side validation
    # ------------------------------------------------------------------
    def test_mismatch_types(self):
        # test that error is raised when cols have different types
        g = Genome("gencode.v41")
        gene = g.genes[0]
        interval = Interval("chr5", "+", 2000, 3000, "hg19")

        df = pl.DataFrame({"mixed": [gene, interval]}, schema={"mixed": pl.Object})
        path = self.tmp_dir_path / "mismatch_types.parquet"
        with self.assertRaises(ValueError):
            write_parquet(df, path)

    def test_mismatch_list_types(self):
        # test that error is raised when a column mixes scalar and list cells
        g = Genome("gencode.v41")
        gene = g.genes[0]

        df = pl.DataFrame({"mixed": [gene, [gene]]}, schema={"mixed": pl.Object})
        path = self.tmp_dir_path / "mismatch_list_types.parquet"
        with self.assertRaises(ValueError):
            write_parquet(df, path)

    # ------------------------------------------------------------------
    # read-side metadata validation
    # ------------------------------------------------------------------
    def test_no_gkdf_version(self):
        # test that error raised when no gkdf version is found in metadata
        self._assert_read_raises_value_error(
            "no_gkdf_version.parquet",
            {"some_other_key": "value"},
        )

    def test_no_target_cols(self):
        # test that error raised when no target_cols is found in metadata
        # Store the raw enum value, exactly as write_parquet does, so the file
        # carries a well-formed version and only target_cols is missing.
        self._assert_read_raises_value_error(
            "no_target_cols.parquet",
            {"gkdf_version": CURRENT_VERSION.value},
        )

    def test_no_gk_version(self):
        # test that error raised when no gk version is found in metadata
        target_cols = {"genome": {"cell_type": "scalar", "gkdf_type": "genome"}}
        self._assert_read_raises_value_error(
            "no_gk_version.parquet",
            {
                # raw enum value, matching what write_parquet stores
                "gkdf_version": CURRENT_VERSION.value,
                "target_cols": json.dumps(target_cols),
            },
        )


if __name__ == "__main__":
    unittest.main()