Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions datashare-python/datashare_python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
"join": {"type": "join", "relations": {"Document": "NamedEntity"}},
"contentType": {"type": "keyword"},
"content": {"type": "text"},
"contentTranslated": {"type": "text"},
}
}
}
Expand Down Expand Up @@ -197,10 +196,7 @@ def index_docs_ops(
docs: list[Document], index_name: str
) -> Generator[dict, None, None]:
for doc in docs:
op = {
"_op_type": "index",
"_index": index_name,
}
op = {"_op_type": "index", "_index": index_name}
doc = doc.model_dump(by_alias=True) # noqa: PLW2901
op.update(doc)
if "path" in op:
Expand Down
266 changes: 163 additions & 103 deletions datashare-python/datashare_python/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from enum import StrEnum, unique
from io import BytesIO
from pathlib import Path
from typing import Annotated, Any, Literal, Self, TypeVar, cast
from typing import Annotated, Any, ClassVar, Literal, Self, TypeVar, cast

from pydantic_core import PydanticCustomError, ValidationError, core_schema
from pydantic_extra_types.language_code import LanguageName
from temporalio import workflow

from .constants import TIKA_METADATA_RESOURCENAME
Expand All @@ -31,7 +33,13 @@
merge_configs,
no_enum_values_config,
)
from pydantic import AfterValidator, Field
from pydantic import (
AfterValidator,
Field,
GetCoreSchemaHandler,
TypeAdapter,
model_validator,
)
from pydantic import BaseModel as _BaseModel
from pydantic.main import IncEx

Expand All @@ -50,110 +58,45 @@ class DatashareModel(BaseModel):
model_config = merge_configs(BaseModel.model_config, lowercamel_case_config())


@unique
class TaskState(StrEnum):
CREATED = "CREATED"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
ERROR = "ERROR"
DONE = "DONE"
CANCELLED = "CANCELLED"


READY_STATES = frozenset({TaskState.DONE, TaskState.ERROR, TaskState.CANCELLED})

class DatashareLanguage(str):
_language_type_adapter: ClassVar[TypeAdapter] = TypeAdapter(LanguageName)

class StacktraceItem(DatashareModel):
name: str
file: str
lineno: int


class Message(DatashareModel):
type: str = Field(frozen=True, alias="@type")
@classmethod
def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> Self:
    """Validate a Datashare language string and wrap it in ``cls``.

    The value must already be fully uppercase, and its title-cased form must
    be accepted by the class-level ``LanguageName`` type adapter.

    Raises:
        PydanticCustomError: with error type ``datashare_language`` when the
            value is not uppercase or is rejected by the type adapter.
    """
    if __input_value != __input_value.upper():
        raise PydanticCustomError(
            "datashare_language", "Invalid Datashare language, expected uppercase"
        )
    try:
        # Use pydantic provided validation (the adapter is fed the
        # title-cased form of the uppercase input).
        cls._language_type_adapter.validate_python(__input_value.title())
    except ValidationError as e:
        raise PydanticCustomError(
            "datashare_language", "Unknown Datashare language"
        ) from e
    return cls(__input_value)

def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx | None = None,
exclude: IncEx | None = None,
context: Any | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
return super().model_dump(
by_alias=True,
mode=mode,
include=include,
exclude=exclude,
context=context,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
fallback=fallback,
serialize_as_any=serialize_as_any,
@classmethod
def __get_pydantic_core_schema__(
    cls, source: type[Any], handler: GetCoreSchemaHandler
) -> core_schema.AfterValidatorFunctionSchema:
    """Build the pydantic core schema for this string subtype.

    The input is first validated as a plain string, then passed to
    ``cls._validate``; serialization converts the value to a string.
    ``source`` and ``handler`` are accepted but not used.
    """
    return core_schema.with_info_after_validator_function(
        cls._validate,
        core_schema.str_schema(),
        serialization=core_schema.to_string_ser_schema(),
    )

@property
def as_language_name(self) -> LanguageName:
    """Return this value, title-cased, as a ``LanguageName``."""
    return LanguageName(self.title())

class TaskResult(Message):
type: str = Field(frozen=True, alias="@type", default="TaskResult")
value: object


class TaskError(Message):
type: str = Field(frozen=True, alias="@type", default="TaskError")
name: str
message: str
cause: str | None = None
stacktrace: list[StacktraceItem] = Field(default_factory=list)


def _datetime_now() -> datetime:
return datetime.now(UTC)


class User(Message):
type: str = Field(
frozen=True, alias="@type", default="org.icij.datashare.user.User"
)
id: str
name: str | None = None
email: str | None = None
provider: str | None = None
details: dict = dict()


class Task(Message):
type: str = Field(frozen=True, alias="@type", default="Task")
id: str
name: str
args: dict[str, object] | None = None
state: TaskState = TaskState.CREATED
result: TaskResult | None = None
error: TaskError | None = None
progress: float | None = None
created_at: datetime = Field(default_factory=_datetime_now)
completed_at: datetime | None = None
retries_left: int | None = None
max_retries: int | None = None


@dataclass(frozen=True)
class TaskGroup:
name: str
@property
def alpha2(self) -> str | None:
    """Return the ``alpha2`` attribute of ``as_language_name`` (may be ``None``)."""
    return self.as_language_name.alpha2

@property
@classmethod
def python(cls) -> Self:
return cls(name="PYTHON")
def alpha3(self) -> str:
return self.as_language_name.alpha3


@unique
Expand Down Expand Up @@ -197,27 +140,38 @@ def locate(

class Document(DatashareModel):
id: str
language: str
language: DatashareLanguage
index: str | None = None
root_document: str | None = None
content: str | None = None
content_text_length: int | None = None
content_type: str | None = None
path: Path | None = None
tags: list[str] = Field(default_factory=list)
content_translated: dict[str, str] = Field(
default_factory=dict, alias="content_translated"
content_translated: list[dict[str, str]] = Field(
default_factory=list, alias="content_translated"
)
metadata: dict[str, Any] | None = None
type: str = Field(default="Document", frozen=True)

@model_validator(mode="before")
@classmethod
def _initialize_content_length_from_content(cls, data: Any) -> Any:
    """Fill in ``content_text_length`` from the content when it is missing.

    Runs before field validation. Input that is not a dict is returned
    untouched, as is a dict that already carries a non-``None``
    ``content_text_length`` or has no (or empty) content under
    ``DOC_CONTENT``. Otherwise a shallow copy of the dict is returned with
    ``content_text_length`` set to ``len(content)``, so the caller's dict is
    never modified.
    """
    if not isinstance(data, dict):
        return data
    if data.get("content_text_length") is not None:
        return data
    content = data.get(DOC_CONTENT)
    if not content:
        return data
    # Copy instead of assigning into ``data``: the raw input belongs to the
    # caller and may be reused after validation.
    return {**data, "content_text_length": len(content)}

@classmethod
def from_es(cls, es_doc: dict) -> Self:
sources = es_doc[SOURCE]
return cls(
id=es_doc[ID_],
index=es_doc.get(INDEX_),
content=sources.get(DOC_CONTENT),
content_translated=sources.get(DOC_CONTENT_TRANSLATED, dict()),
content_translated=sources.get(DOC_CONTENT_TRANSLATED, []),
content_text_length=sources.get("content_text_length"),
language=sources[DOC_LANGUAGE],
root_document=sources.get(DOC_ROOT_ID),
tags=sources.get("tags", []),
Expand Down Expand Up @@ -265,3 +219,109 @@ class DocArtifact:
artifact: bytes | BytesIO
filename: str
metadata_key: str


@unique
class TaskState(StrEnum):
CREATED = "CREATED"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
ERROR = "ERROR"
DONE = "DONE"
CANCELLED = "CANCELLED"


READY_STATES = frozenset({TaskState.DONE, TaskState.ERROR, TaskState.CANCELLED})


class StacktraceItem(DatashareModel):
    """One entry of the ``stacktrace`` list carried by ``TaskError``."""

    # NOTE(review): presumably the function/frame name — confirm with producers.
    name: str
    # Source file reported for this entry.
    file: str
    # Line number reported for this entry.
    lineno: int


class Message(DatashareModel):
    """Base model with a frozen ``type`` field exposed under the ``@type`` alias.

    ``model_dump`` is overridden so that dumps always use field aliases.
    """

    type: str = Field(frozen=True, alias="@type")

    def model_dump(
        self,
        *,
        mode: Literal["json", "python"] | str = "python",
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        context: Any | None = None,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        fallback: Callable[[Any], Any] | None = None,
        serialize_as_any: bool = False,
    ) -> dict[str, Any]:
        """Dump the model to a dict, always keyed by alias.

        All keyword arguments are forwarded unchanged to the parent
        ``model_dump``. ``by_alias`` is deliberately not a parameter here: it
        is fixed to ``True`` so the ``@type`` alias is always emitted.
        """
        return super().model_dump(
            by_alias=True,
            mode=mode,
            include=include,
            exclude=exclude,
            context=context,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            fallback=fallback,
            serialize_as_any=serialize_as_any,
        )


class TaskResult(Message):
    """Message wrapping a task's result value; ``@type`` defaults to ``TaskResult``."""

    type: str = Field(frozen=True, alias="@type", default="TaskResult")
    # Arbitrary result payload; no structure is enforced here.
    value: object


class TaskError(Message):
    """Message describing a task failure; ``@type`` defaults to ``TaskError``."""

    type: str = Field(frozen=True, alias="@type", default="TaskError")
    # NOTE(review): presumably the error/exception name — confirm with producers.
    name: str
    message: str
    # Optional cause, as a string.
    cause: str | None = None
    # Stacktrace entries; a fresh empty list is created per instance.
    stacktrace: list[StacktraceItem] = Field(default_factory=list)


def _datetime_now() -> datetime:
return datetime.now(UTC)


class User(Message):
    """Message describing a user.

    ``@type`` defaults to ``org.icij.datashare.user.User``.
    """

    type: str = Field(
        frozen=True, alias="@type", default="org.icij.datashare.user.User"
    )
    id: str
    name: str | None = None
    email: str | None = None
    provider: str | None = None
    # Built per instance via a factory rather than a shared ``dict()`` literal,
    # in line with the ``default_factory`` usage elsewhere in this module.
    details: dict = Field(default_factory=dict)


class Task(Message):
    """Message describing a task and its current status; ``@type`` defaults to ``Task``."""

    type: str = Field(frozen=True, alias="@type", default="Task")
    id: str
    name: str
    args: dict[str, object] | None = None
    # New tasks start in the CREATED state.
    state: TaskState = TaskState.CREATED
    result: TaskResult | None = None
    error: TaskError | None = None
    # NOTE(review): range/units of progress are not visible here — confirm.
    progress: float | None = None
    # Defaults to the current UTC time at instantiation.
    created_at: datetime = Field(default_factory=_datetime_now)
    completed_at: datetime | None = None
    retries_left: int | None = None
    max_retries: int | None = None


@dataclass(frozen=True)
class TaskGroup:
name: str

@property
@classmethod
def python(cls) -> Self:
return cls(name="PYTHON")
1 change: 0 additions & 1 deletion datashare-python/datashare_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def __init__(self):
async def update_progress(self, signal: ProgressSignal) -> None:
async with self._update_lock:
# TODO: remove this log
workflow.logger.debug("recording progress signal %s", signal)
key = (signal.run_id, signal.activity_id)
self._progress[key] = signal.to_progress()
progress = sum(p.current for p in self._progress.values())
Expand Down
1 change: 1 addition & 0 deletions datashare-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"hatchling~=1.27",
"pyyaml~=6.0",
"orjson~=3.11",
"pydantic-extra-types[pycountry]>=2.11.1",
]

[project.urls]
Expand Down
Loading
Loading