Skip to content
30 changes: 18 additions & 12 deletions cognite/client/_api/data_modeling/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import random
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from copy import copy
from datetime import datetime, timedelta, timezone
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -894,6 +895,7 @@ async def subscribe(
>>> subscription_context.cancel()

"""
query = query._get_query_with_defaults_applied()
subscription_context = SubscriptionContext()

async def _poll_loop() -> None:
Expand Down Expand Up @@ -968,20 +970,20 @@ def _create_other_params(
f"Received in `sources` argument for views: {with_properties}."
)
if sort:
if isinstance(sort, (InstanceSort, dict)):
other_params["sort"] = [cls._dump_instance_sort(sort)]
else:
other_params["sort"] = [cls._dump_instance_sort(s) for s in sort]
sorts_seq = [sort] if isinstance(sort, (InstanceSort, dict)) else list(sort)
result = []
for s in sorts_seq:
if isinstance(s, InstanceSort):
s = copy(s)._apply_defaults_or_maybe_warn().dump(camel_case=True)
result.append(s)
other_params["sort"] = result

if instance_type:
other_params["instanceType"] = instance_type
if debug:
other_params["debug"] = debug.dump()
return other_params

@staticmethod
def _dump_instance_sort(sort: InstanceSort | dict) -> dict:
return sort.dump(camel_case=True) if isinstance(sort, InstanceSort) else sort

async def apply(
self,
nodes: NodeApply | Sequence[NodeApply] | None = None,
Expand Down Expand Up @@ -1304,11 +1306,14 @@ async def search(
body["targetUnits"] = [unit.dump(camel_case=True) for unit in target_units]
if sort:
sorts = sort if isinstance(sort, Sequence) else [sort]
for sort_spec in sorts:
nulls_first = sort_spec.get("nullsFirst") if isinstance(sort_spec, dict) else sort_spec.nulls_first
if nulls_first is not None:
result = []
for s in sorts:
if isinstance(s, InstanceSort):
s = s.dump(camel_case=True)
if s.get("nullsFirst") is not None:
raise ValueError("nulls_first argument is not supported when sorting on instance search")
body["sort"] = [self._dump_instance_sort(s) for s in sorts]
result.append(s)
body["sort"] = result

semaphore = self._get_semaphore("search")
res = await self._post(url_path=self._RESOURCE_PATH + "/search", json=body, semaphore=semaphore)
Expand Down Expand Up @@ -1758,6 +1763,7 @@ async def _query_or_sync(
include_typing: bool,
debug: DebugParameters | None,
) -> QueryResult:
query = query._get_query_with_defaults_applied()
headers: None | dict[str, str] = None
body = query.dump(camel_case=True)
if include_typing:
Expand Down
2 changes: 1 addition & 1 deletion cognite/client/_sync_api/data_modeling/instances.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
===============================================================================
c0af4f83cd8ffa0aaed6fa641bde9a98
15c69385f1ab10e31adbde3483ee7588
This file is auto-generated from the Async API modules, - do not edit manually!
===============================================================================
"""
Expand Down
105 changes: 102 additions & 3 deletions cognite/client/data_classes/data_modeling/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
overload,
)

from typing_extensions import Self
from typing_extensions import Self, override

from cognite.client._constants import OMITTED, Omitted
from cognite.client.data_classes._base import (
Expand Down Expand Up @@ -1344,13 +1344,112 @@ class InstancesApply:


class InstanceSort(DataModelingSort):
"""Sort order for an instance query.

Args:
property (list[str] | tuple[str, str] | tuple[str, str, str]): The property to sort by, given as a path, e.g.
``("mySpace", "myView/v1", "myProperty")`` or ``["node", "externalId"]``.
direction (Literal['ascending', 'descending']): Sort direction. Case-insensitive. Defaults to ``"ascending"``.
nulls_first (bool | None): Where to place ``null`` values. Defaults to ``None`` (auto). See tip below.

Tip:
For the backend database to use an index when sorting nullable properties, the ``nulls_first`` setting
must match the sort direction:

- ``ascending`` → nulls last (``nulls_first=False``)
- ``descending`` → nulls first (``nulls_first=True``)

When ``nulls_first=None`` (the default), the correct value is chosen automatically. Passing the
opposite combination is still accepted and sent to the API as-is, but may trigger a warning
for API endpoints that support index utilization if an unsupported combination is used.

Examples:

Sort by a view property ascending (default):

>>> from cognite.client.data_classes.data_modeling import InstanceSort
>>> sort = InstanceSort(("mySpace", "myView/v1", "myProperty"))

Can also use a ViewId to simplify the property path:

>>> from cognite.client.data_classes.data_modeling import ViewId
>>> view_id = ViewId("mySpace", "myView", "v1")
>>> sort = InstanceSort(view_id.as_property_ref("myProperty"))

Sort descending:

>>> sort = InstanceSort(
... view_id.as_property_ref("myProperty"),
... direction="descending",
... )

Sort by a base property:

>>> sort = InstanceSort(["node", "externalId"], direction="ascending")

Force a specific null placement (first/last). A UserWarning will fire at relevant API call
sites when this conflicts with index alignment:

>>> sort = InstanceSort(
... ("mySpace", "myView/v1", "myProperty"),
... direction="descending",
... nulls_first=True,
... )
"""

def __init__(
self,
property: list[str] | tuple[str, ...],
property: list[str] | tuple[str, str] | tuple[str, str, str],
direction: Literal["ascending", "descending"] = "ascending",
nulls_first: bool | None = None,
) -> None:
super().__init__(property, direction, nulls_first)
normalized = direction.casefold()
if normalized not in ("ascending", "descending"):
raise ValueError(f"direction must be 'ascending' or 'descending', got {direction!r}")

super().__init__(property, normalized, nulls_first) # type: ignore [arg-type]

# We override _load to get the more strict __init__ validation on 'direction' because we need it to
# be valid for the possible later automatic choice of nulls_first:
@override
@classmethod
def _load(cls, resource: dict[str, Any]) -> Self:
if not isinstance(resource, dict):
raise TypeError(f"Resource must be mapping, not {type(resource)}")

return cls(
property=resource["property"],
direction=resource.get("direction", "ascending"),
nulls_first=resource.get("nullsFirst"),
)

@property
def is_index_aligned(self) -> bool:
"""True when nulls_first matches the direction for PostgreSQL index utilization (None counts as aligned)."""
if self.nulls_first is None:
return True
return self.nulls_first is (self.direction == "descending")

def _apply_defaults_or_maybe_warn(self) -> Self:
"""Resolve nulls_first for database index alignment, warning if the explicit value is misaligned.

When nulls_first is None, sets it to the index-compatible value. When explicitly set but
misaligned, emits a UserWarning.
"""
if self.nulls_first is None:
self.nulls_first = self.direction == "descending"

elif not self.is_index_aligned:
import warnings

warnings.warn(
f"InstanceSort: direction={self.direction!r} with nulls_first={self.nulls_first} is not "
f"index-aligned and will likely prevent database index utilization. "
f"Use nulls_first={self.direction == 'descending'} (or omit it) for optimal performance.",
UserWarning,
stacklevel=3,
)
return self


@dataclass
Expand Down
36 changes: 35 additions & 1 deletion cognite/client/data_classes/data_modeling/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from abc import ABC
from collections import UserDict
from collections.abc import Mapping, MutableMapping, Sequence
from collections.abc import Iterator, Mapping, MutableMapping, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

Expand Down Expand Up @@ -85,6 +86,9 @@ def _load_list(
class SelectBase(CogniteResource, ABC):
sources: list[SourceSelector] = field(default_factory=list)

def _iter_sorts(self) -> Iterator[InstanceSort]:
yield from ()

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output: dict[str, Any] = {}
if self.sources:
Expand Down Expand Up @@ -133,6 +137,9 @@ class Select(SelectBase):
sort: list[InstanceSort] = field(default_factory=list)
limit: int | None = None

def _iter_sorts(self) -> Iterator[InstanceSort]:
yield from self.sort

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = super().dump(camel_case)
if self.sort:
Expand Down Expand Up @@ -179,6 +186,20 @@ def instance_type_by_result_expression(self) -> dict[str, type[NodeListWithCurso
for k, v in self.with_.items()
}

def _iter_sorts(self) -> Iterator[InstanceSort]:
for expr in self.with_.values():
yield from expr._iter_sorts()
for sel in self.select.values():
yield from sel._iter_sorts()

def _get_query_with_defaults_applied(self) -> Self:
"""TODO: We could verify (or just warn), when Query and Sync-versions are mixed or used in the wrong setting."""
# We don't want to mutate the user's original query, so we make a deepcopy and apply defaults to that:
query = deepcopy(self)
for sort in query._iter_sorts():
sort._apply_defaults_or_maybe_warn()
return query

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output: dict[str, Any] = {
"with": {k: v.dump(camel_case) for k, v in self.with_.items()},
Expand Down Expand Up @@ -281,6 +302,9 @@ class ResultSetExpressionBase(CogniteResource, ABC):
def _load_sort(resource: dict[str, Any], name: str) -> list[InstanceSort]:
return [InstanceSort.load(sort) for sort in resource.get(name, [])]

def _iter_sorts(self) -> Iterator[InstanceSort]:
yield from ()

@staticmethod
def _init_through(through: list[str] | tuple[str, str, str] | PropertyId | None) -> PropertyId | None:
def error() -> Never:
Expand Down Expand Up @@ -336,6 +360,9 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return type(self) is type(other) and self.dump() == other.dump()

def _iter_sorts(self) -> Iterator[InstanceSort]:
yield from self.sort


@dataclass(eq=False) # Prevents @dataclass from generating its own __eq__, so the parent's is used
class NodeResultSetExpression(NodeOrEdgeResultSetExpression):
Expand Down Expand Up @@ -429,6 +456,10 @@ class EdgeResultSetExpression(NodeOrEdgeResultSetExpression):
limit_each: int | None = None
post_sort: list[InstanceSort] = field(default_factory=list)

def _iter_sorts(self) -> Iterator[InstanceSort]:
yield from self.sort
yield from self.post_sort

@classmethod
def _load(cls, resource: dict[str, Any]) -> Self:
query_edge = resource["edges"]
Expand Down Expand Up @@ -494,6 +525,9 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return type(self) is type(other) and self.dump() == other.dump()

def _iter_sorts(self) -> Iterator[InstanceSort]:
yield from self.backfill_sort

@classmethod
def _load(cls, resource: dict[str, Any]) -> ResultSetExpressionSync:
if "nodes" in resource:
Expand Down
Loading
Loading