diff --git a/frontend/src/core/codemirror/cells/extensions.ts b/frontend/src/core/codemirror/cells/extensions.ts index 159a4505ee4..ddbeccc60e5 100644 --- a/frontend/src/core/codemirror/cells/extensions.ts +++ b/frontend/src/core/codemirror/cells/extensions.ts @@ -317,6 +317,16 @@ function cellKeymaps({ function cellCodeEditing(hotkeys: HotkeyProvider): Extension[] { const onChangePlugin = EditorView.updateListener.of((update) => { if (update.docChanged) { + // Skip changes that came from Loro sync (RTC) — these are + // already reflected in the shared LoroDoc and don't need to + // round-trip through the transaction middleware. + const isLoroSync = update.transactions.some( + (tr) => tr.annotation(loroSyncAnnotation) != null, + ); + if (isLoroSync) { + return; + } + // Check if the doc update was a formatting change // e.g. changing from python to markdown const isFormattingChange = update.transactions.some((tr) => diff --git a/marimo/_notebook/__init__.py b/marimo/_notebook/__init__.py new file mode 100644 index 00000000000..de4330833b5 --- /dev/null +++ b/marimo/_notebook/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Notebook document model — canonical representation of notebook structure.""" + +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) +from marimo._notebook.document import CellMeta, NotebookCell, NotebookDocument + +__all__ = [ + "CellMeta", + "CreateCell", + "DeleteCell", + "DocumentChange", + "MoveCell", + "NotebookCell", + "NotebookDocument", + "ReorderCells", + "SetCode", + "SetConfig", + "SetName", + "Transaction", +] diff --git a/marimo/_notebook/_loro.py b/marimo/_notebook/_loro.py new file mode 100644 index 00000000000..698b1dc79f3 --- /dev/null +++ b/marimo/_notebook/_loro.py @@ -0,0 +1,27 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Typed wrappers for ``loro`` APIs with incomplete stubs. + +The ``loro`` stubs omit return types on ``__new__`` and the +``ValueOrContainer`` union lacks a typed ``.container`` accessor. +These helpers isolate the ``type: ignore`` comments so the rest of +the codebase stays clean. +""" + +from __future__ import annotations + +from loro import LoroDoc, LoroText, ValueOrContainer + + +def create_doc() -> LoroDoc: + return LoroDoc() # type: ignore[no-untyped-call] + + +def create_text() -> LoroText: + return LoroText() # type: ignore[no-untyped-call] + + +def unwrap_text(val: ValueOrContainer) -> LoroText: + """Extract a ``LoroText`` from a ``ValueOrContainer``.""" + container = val.container # type: ignore[union-attr,attr-defined] + assert isinstance(container, LoroText) + return container diff --git a/marimo/_notebook/document.py b/marimo/_notebook/document.py new file mode 100644 index 00000000000..d5430c76c2b --- /dev/null +++ b/marimo/_notebook/document.py @@ -0,0 +1,420 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Canonical notebook document model. + +``NotebookDocument`` maintains an ordered list of ``CellMeta`` entries and +a ``LoroDoc`` that owns all cell source text as ``LoroText`` containers. +``NotebookCell`` is a read-only snapshot materialized on access. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import TYPE_CHECKING + +from marimo._utils.assert_never import assert_never + +if TYPE_CHECKING: + from collections.abc import Generator, Iterable, Iterator + +import msgspec +from msgspec.structs import replace as structs_replace + +from marimo._ast.cell import CellConfig +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) +from marimo._notebook._loro import create_doc, create_text, unwrap_text +from marimo._types.ids import CellId_t + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + from loro import LoroDoc, LoroText + + +class NotebookCell(msgspec.Struct, frozen=True): + """Read-only snapshot of a cell, materialized from CellMeta + LoroText. + + This is never stored by the document — ``get_cell()`` and ``.cells`` + construct fresh instances each time. + """ + + id: CellId_t + code: str + name: str + config: CellConfig + + +class CellMeta(msgspec.Struct): + """Mutable metadata for a cell. Owned by the document internally. + + Does *not* hold code — that lives in the ``LoroDoc``. + """ + + id: CellId_t + name: str + config: CellConfig + + +class NotebookDocument: + """Ordered collection of cells with transactional updates. + + Cell text is owned by a ``LoroDoc`` (one ``LoroText`` per cell under + ``LoroMap("codes")``). Structural metadata (name, config, ordering) + is stored in ``_cell_metas``. + + Usage:: + + from loro import LoroDoc + + doc = NotebookDocument(LoroDoc()) + doc.add_cell( + CellId_t("a"), code="x = 1", name="__", config=CellConfig() + ) + tx = Transaction( + changes=(SetCode(CellId_t("a"), "x = 2"),), source="kernel" + ) + applied = doc.apply(tx) + assert applied.version == 1 + assert doc.get_cell(CellId_t("a")).code == "x = 2" + """ + + def __init__(self, loro_doc: LoroDoc) -> None: + self._loro_doc = loro_doc + self._codes_map = loro_doc.get_map("codes") + self._cell_metas: list[CellMeta] = [] + self._version: int = 0 + + @classmethod + def from_cells(cls, cells: Iterable[NotebookCell]) -> NotebookDocument: + """Build a document from ``NotebookCell`` snapshots. + + Creates a fresh ``LoroDoc`` populated from the snapshot data. + Used at the kernel-process boundary where cells arrive as + serialized structs and need to be reconstructed into a live + document. + """ + doc = cls(create_doc()) + for c in cells: + doc.add_cell( + cell_id=c.id, code=c.code, name=c.name, config=c.config + ) + doc._loro_doc.commit() + return doc + + @property + def loro_doc(self) -> LoroDoc: + """The underlying Loro document owning cell text.""" + return self._loro_doc + + # ------------------------------------------------------------------ + # Bootstrap — populate the document from an external source + # ------------------------------------------------------------------ + + def add_cell( + self, + cell_id: CellId_t, + code: str, + name: str, + config: CellConfig, + ) -> None: + """Append a cell during initial document construction. + + This is *not* a transaction — it is used at session init to + populate the document from a ``CellManager``. + """ + text = create_text() + text.insert(0, code) + self._codes_map.insert_container(cell_id, text) + self._cell_metas.append(CellMeta(id=cell_id, name=name, config=config)) + + # ------------------------------------------------------------------ + # Read-only accessors + # ------------------------------------------------------------------ + + @property + def cells(self) -> list[NotebookCell]: + """Materialize and return a snapshot list of all cells.""" + return [self._snapshot(m) for m in self._cell_metas] + + @property + def cell_ids(self) -> list[CellId_t]: + """Cell IDs in document order.""" + return [m.id for m in self._cell_metas] + + @property + def version(self) -> int: + return self._version + + def get_cell(self, cell_id: CellId_t) -> NotebookCell: + """Lookup by ID. Raises ``KeyError`` if not found.""" + return self._snapshot(self._find_meta(cell_id)) + + def get(self, cell_id: CellId_t) -> NotebookCell | None: + """Lookup by ID, returning ``None`` if not found.""" + for m in self._cell_metas: + if m.id == cell_id: + return self._snapshot(m) + return None + + def __contains__(self, cell_id: object) -> bool: + return any(m.id == cell_id for m in self._cell_metas) + + def __len__(self) -> int: + return len(self._cell_metas) + + def __iter__(self) -> Iterator[CellId_t]: + return (m.id for m in self._cell_metas) + + # ------------------------------------------------------------------ + # Transaction application + # ------------------------------------------------------------------ + + def apply(self, tx: Transaction) -> Transaction: + """Validate and apply *tx*, return it with ``version`` assigned. + + Raises ``ValueError`` for validation failures and ``KeyError`` + when a change references a non-existent cell. + """ + if not tx.changes: + return structs_replace(tx, version=self._version) + + _validate(tx.changes, self._cell_metas) + + for change in tx.changes: + self._apply_change(change) + + # Commit all Loro mutations from this transaction as a single + # batch. This triggers one ``subscribe_local_update`` callback + # so RTC clients receive one update per transaction. + self._loro_doc.commit() + + self._version += 1 + return structs_replace(tx, version=self._version) + + def _apply_change(self, change: DocumentChange) -> None: + # TODO: refactor to use match/case (min Python is 3.10) once + # ruff target-version is bumped from py39. + if isinstance(change, CreateCell): + # Create LoroText in the shared doc + text = create_text() + text.insert(0, change.code) + self._codes_map.insert_container(change.cell_id, text) + + meta = CellMeta( + id=change.cell_id, name=change.name, config=change.config + ) + if change.after is not None: + idx = self._find_index(change.after) + self._cell_metas.insert(idx + 1, meta) + elif change.before is not None: + idx = self._find_index(change.before) + self._cell_metas.insert(idx, meta) + else: + self._cell_metas.append(meta) + + elif isinstance(change, DeleteCell): + idx = self._find_index(change.cell_id) + del self._cell_metas[idx] + self._codes_map.delete(change.cell_id) + + elif isinstance(change, MoveCell): + idx = self._find_index(change.cell_id) + meta = self._cell_metas.pop(idx) + if change.after is not None: + target = self._find_index(change.after) + self._cell_metas.insert(target + 1, meta) + elif change.before is not None: + target = self._find_index(change.before) + self._cell_metas.insert(target, meta) + else: + raise ValueError("MoveCell requires 'before' or 'after'") + + elif isinstance(change, ReorderCells): + by_id = {m.id: m for m in self._cell_metas} + seen: set[CellId_t] = set() + reordered: list[CellMeta] = [] + for cid in change.cell_ids: + if cid in by_id and cid not in seen: + reordered.append(by_id[cid]) + seen.add(cid) + for m in self._cell_metas: + if m.id not in seen: + reordered.append(m) + self._cell_metas = reordered + + elif isinstance(change, SetCode): + self._find_meta(change.cell_id) + text = self._get_loro_text(change.cell_id) + if text.len_unicode > 0: + text.delete(0, text.len_unicode) + if change.code: + text.insert(0, change.code) + + elif isinstance(change, SetName): + self._find_meta(change.cell_id).name = change.name + + elif isinstance(change, SetConfig): + meta = self._find_meta(change.cell_id) + meta.config = CellConfig( + column=change.column + if change.column is not None + else meta.config.column, + disabled=change.disabled + if change.disabled is not None + else meta.config.disabled, + hide_code=change.hide_code + if change.hide_code is not None + else meta.config.hide_code, + ) + else: + assert_never(change) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _snapshot(self, meta: CellMeta) -> NotebookCell: + """Build a read-only ``NotebookCell`` from metadata + Loro text.""" + code = self._get_loro_text(meta.id).to_string() + return NotebookCell( + id=meta.id, code=code, name=meta.name, config=meta.config + ) + + def _get_loro_text(self, cell_id: CellId_t) -> LoroText: + """Return the ``LoroText`` for *cell_id*.""" + val = self._codes_map.get(cell_id) + if val is None: + raise KeyError(f"No LoroText for cell {cell_id!r}") + return unwrap_text(val) + + def _find_index(self, cell_id: CellId_t) -> int: + for i, m in enumerate(self._cell_metas): + if m.id == cell_id: + return i + raise KeyError(f"Cell {cell_id!r} not found in document") + + def _find_meta(self, cell_id: CellId_t) -> CellMeta: + for m in self._cell_metas: + if m.id == cell_id: + return m + raise KeyError(f"Cell {cell_id!r} not found in document") + + def __repr__(self) -> str: + lines = [f"NotebookDocument({len(self._cell_metas)} cells):"] + for i, m in enumerate(self._cell_metas): + code_preview = ( + self._get_loro_text(m.id).to_string()[:40].replace("\n", "\\n") + ) + lines.append(f" {i}: {m.id} {code_preview!r}") + return "\n".join(lines) + + +# ------------------------------------------------------------------ +# Context variable +# ------------------------------------------------------------------ + +#: Document snapshot for the current scratchpad execution. Set by the +#: kernel before running code_mode so ``AsyncCodeModeContext`` can read +#: cell ordering, code, names, and configs without the kernel carrying +#: mutable document state. +_current_document: ContextVar[NotebookDocument | None] = ContextVar( + "_current_document", default=None +) + + +def get_current_document() -> NotebookDocument | None: + """Return the document for the current execution, if any.""" + return _current_document.get() + + +@contextmanager +def notebook_document_context( + doc: NotebookDocument | None, +) -> Generator[None, None, None]: + """Context manager for setting and resetting the current document.""" + token = _current_document.set(doc) + try: + yield + finally: + _current_document.reset(token) + + +# ------------------------------------------------------------------ +# Validation +# ------------------------------------------------------------------ + + +def _validate( + changes: tuple[DocumentChange, ...], metas: list[CellMeta] +) -> None: + """Check for conflicting changes. Raises ``ValueError``.""" + existing_ids = {m.id for m in metas} + created: set[CellId_t] = set() + deleted: set[CellId_t] = set() + updated: set[CellId_t] = set() + moved: set[CellId_t] = set() + + for change in changes: + if isinstance(change, CreateCell): + if change.cell_id in existing_ids or change.cell_id in created: + raise ValueError(f"Cell {change.cell_id!r} already exists") + if change.before is not None and change.after is not None: + raise ValueError( + "CreateCell cannot specify both 'before' and 'after'" + ) + created.add(change.cell_id) + + elif isinstance(change, DeleteCell): + if change.cell_id in deleted: + raise ValueError( + f"Cell {change.cell_id!r} is deleted more than once" + ) + if change.cell_id in updated: + raise ValueError( + f"Cannot delete cell {change.cell_id!r} that is also " + f"updated in the same transaction" + ) + if change.cell_id in moved: + raise ValueError( + f"Cannot delete cell {change.cell_id!r} that is also " + f"moved in the same transaction" + ) + deleted.add(change.cell_id) + + elif isinstance(change, MoveCell): + if change.cell_id in deleted: + raise ValueError( + f"Cannot move cell {change.cell_id!r} that is also " + f"deleted in the same transaction" + ) + if change.before is not None and change.after is not None: + raise ValueError( + "MoveCell cannot specify both 'before' and 'after'" + ) + if change.before is None and change.after is None: + raise ValueError("MoveCell requires 'before' or 'after'") + moved.add(change.cell_id) + + elif isinstance(change, ReorderCells): + pass # No conflicts — replaces full ordering + + elif isinstance(change, (SetCode, SetName, SetConfig)): + if change.cell_id in deleted: + raise ValueError( + f"Cannot update cell {change.cell_id!r} that is also " + f"deleted in the same transaction" + ) + updated.add(change.cell_id) + + else: + raise TypeError(f"Unknown change type: {type(change)!r}") diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 4ea3efce220..44a917b14c4 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -2296,7 +2296,7 @@ async def handle_execute_scratchpad( request: ExecuteScratchpadCommand, ) -> None: doc = ( - NotebookDocument(list(request.notebook_cells)) + NotebookDocument.from_cells(request.notebook_cells) if request.notebook_cells is not None else None ) diff --git a/marimo/_server/api/endpoints/ws/ws_kernel_ready.py b/marimo/_server/api/endpoints/ws/ws_kernel_ready.py index 409c1c6af14..595925fb1eb 100644 --- a/marimo/_server/api/endpoints/ws/ws_kernel_ready.py +++ b/marimo/_server/api/endpoints/ws/ws_kernel_ready.py @@ -64,9 +64,9 @@ def build_kernel_ready( """ codes, names, configs, cell_ids = _extract_cell_data(session, manager) - # Initialize RTC if needed + # Register session's LoroDoc for RTC if needed if _should_init_rtc(rtc_enabled, mode): - _try_init_rtc_doc(cell_ids, codes, file_key, doc_manager) + _try_init_rtc_doc(session, file_key, doc_manager) return KernelReadyNotification( codes=codes, @@ -158,18 +158,20 @@ def _should_init_rtc(rtc_enabled: bool, mode: SessionMode) -> bool: def _try_init_rtc_doc( - cell_ids: tuple[CellId_t, ...], - codes: tuple[str, ...], + session: Session, file_key: MarimoFileKey, doc_manager: LoroDocManager, ) -> None: - """Try to initialize RTC document with cell data. + """Register the session's LoroDoc with the RTC doc manager. + + The session's ``NotebookDocument`` already owns a ``LoroDoc`` that + holds all cell text. This makes that same doc available for RTC + client connections so there is a single source of truth. Logs a warning if Loro is not available but does not fail. Args: - cell_ids: Cell IDs to initialize - codes: Cell codes to initialize + session: Current session (owns the LoroDoc via its document) file_key: File key for the document doc_manager: LoroDoc manager """ @@ -180,4 +182,6 @@ def _try_init_rtc_doc( "RTC: Loro is not installed, disabling real-time collaboration" ) else: - asyncio.create_task(doc_manager.create_doc(file_key, cell_ids, codes)) + asyncio.create_task( + doc_manager.register_doc(file_key, session.document.loro_doc) + ) diff --git a/marimo/_server/api/endpoints/ws/ws_rtc_handler.py b/marimo/_server/api/endpoints/ws/ws_rtc_handler.py index 62e242b8465..7b6dfd36720 100644 --- a/marimo/_server/api/endpoints/ws/ws_rtc_handler.py +++ b/marimo/_server/api/endpoints/ws/ws_rtc_handler.py @@ -49,7 +49,7 @@ async def handle(self) -> None: # Get or create the LoroDoc and add the client to it LOGGER.debug("RTC: getting document") update_queue: asyncio.Queue[bytes] = asyncio.Queue() - doc = await self.doc_manager.get_or_create_doc(self.file_key) + doc = await self.doc_manager.get_doc(self.file_key) self.doc_manager.add_client_to_doc(self.file_key, update_queue) # Send initial sync to client diff --git a/marimo/_server/rtc/doc.py b/marimo/_server/rtc/doc.py index 7b42ba29335..d0ff65c4820 100644 --- a/marimo/_server/rtc/doc.py +++ b/marimo/_server/rtc/doc.py @@ -2,11 +2,10 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from marimo import _loggers from marimo._server.file_router import MarimoFileKey -from marimo._types.ids import CellId_t if TYPE_CHECKING: from loro import LoroDoc @@ -21,86 +20,62 @@ def __init__(self) -> None: self.loro_docs_clients: dict[ MarimoFileKey, set[asyncio.Queue[bytes]] ] = {} - self.loro_docs_cleaners: dict[ - MarimoFileKey, Optional[asyncio.Task[None]] - ] = {} + # Hold subscription references to prevent garbage collection + self._subscriptions: dict[MarimoFileKey, object] = {} - async def _clean_loro_doc( - self, file_key: MarimoFileKey, timeout: float = 60 - ) -> None: - """Clean up a loro doc if no clients are connected.""" - try: - await asyncio.sleep(timeout) - async with self.loro_docs_lock: - if ( - file_key in self.loro_docs_clients - and len(self.loro_docs_clients[file_key]) == 0 - ): - LOGGER.debug( - f"RTC: Removing loro doc for file {file_key} as it has no clients" - ) - # Clean up the document - await self._do_remove_doc(file_key) - except asyncio.CancelledError: - # Task was cancelled due to client reconnection - LOGGER.debug( - f"RTC: clean_loro_doc task cancelled for file {file_key} - likely due to reconnection" - ) - pass - - async def create_doc( + async def register_doc( self, file_key: MarimoFileKey, - cell_ids: tuple[CellId_t, ...], - codes: tuple[str, ...], - ) -> LoroDoc: - """Create a new loro doc.""" - from loro import LoroDoc, LoroText - - assert len(cell_ids) == len(codes), ( - "cell_ids and codes must be the same length" - ) - + doc: LoroDoc, + ) -> None: + """Register an existing LoroDoc (owned by a session's NotebookDocument). + + The session creates the LoroDoc at init time and the document + model owns it for the lifetime of the session. This method + makes the same doc available for RTC client connections and + subscribes to local updates so that server-originated changes + (e.g. SetCode from kernel or file-watch) are broadcast to all + connected RTC clients. + """ async with self.loro_docs_lock: if file_key in self.loro_docs: - return self.loro_docs[file_key] - - LOGGER.debug(f"RTC: Initializing LoroDoc for file {file_key}") - doc = LoroDoc() # type: ignore[no-untyped-call] - self.loro_docs[file_key] = doc - - # Add all cell code to the doc - doc_codes = doc.get_map("codes") + LOGGER.debug( + f"RTC: LoroDoc already registered for file {file_key}" + ) + return + # Ensure the languages map exists for the frontend doc.get_map("languages") - for cell_id, code in zip(cell_ids, codes): - cell_text = LoroText() # type: ignore[no-untyped-call] - cell_text.insert(0, code) - doc_codes.insert_container(cell_id, cell_text) + LOGGER.debug(f"RTC: Registered LoroDoc for file {file_key}") + self.loro_docs[file_key] = doc - # We don't set the language here because it will be set - # when the client connects for the first time. - return doc + # Broadcast server-side Loro mutations to RTC clients. + # The callback fires synchronously on doc.commit() — we + # enqueue directly into client queues (non-blocking). + def _on_local_update(update: bytes) -> bool: + clients = self.loro_docs_clients.get(file_key, set()) + for client in clients: + try: + client.put_nowait(update) + except asyncio.QueueFull: + LOGGER.warning( + "RTC: client queue full, dropping update" + ) + return True # keep subscription alive + + self._subscriptions[file_key] = doc.subscribe_local_update( + _on_local_update + ) - async def get_or_create_doc(self, file_key: MarimoFileKey) -> LoroDoc: - """Get or create a loro doc for a file key.""" - from loro import LoroDoc + async def get_doc(self, file_key: MarimoFileKey) -> LoroDoc: + """Get the LoroDoc registered for *file_key*. + Raises ``KeyError`` if no doc has been registered via + ``register_doc``. + """ async with self.loro_docs_lock: - if file_key in self.loro_docs: - doc = self.loro_docs[file_key] - # Cancel existing cleaner task if it exists - cleaner = self.loro_docs_cleaners.get(file_key, None) - if cleaner is not None: - LOGGER.debug( - f"RTC: Cancelling existing cleaner for file {file_key}" - ) - cleaner.cancel() - self.loro_docs_cleaners[file_key] = None - else: - LOGGER.warning(f"RTC: Expected loro doc for file {file_key}") - doc = LoroDoc() # type: ignore[no-untyped-call] - self.loro_docs[file_key] = doc - return doc + if file_key not in self.loro_docs: + raise KeyError(f"No LoroDoc registered for file {file_key!r}") + return self.loro_docs[file_key] def add_client_to_doc( self, file_key: MarimoFileKey, update_queue: asyncio.Queue[bytes] @@ -115,7 +90,7 @@ async def broadcast_update( self, file_key: MarimoFileKey, message: bytes, - exclude_queue: Optional[asyncio.Queue[bytes]] = None, + exclude_queue: asyncio.Queue[bytes] | None = None, ) -> None: """Broadcast an update to all clients except the excluded queue.""" clients = self.loro_docs_clients[file_key] @@ -129,39 +104,23 @@ async def remove_client( file_key: MarimoFileKey, update_queue: asyncio.Queue[bytes], ) -> None: - """Clean up a loro client and potentially the doc if no clients remain.""" - should_create_cleaner = False + """Remove an RTC client queue. + The LoroDoc itself is *not* cleaned up when clients disconnect — + it is owned by the session's ``NotebookDocument`` and lives for + the session's lifetime. Only the client tracking set is updated. + """ async with self.loro_docs_lock: if file_key not in self.loro_docs_clients: return - - self.loro_docs_clients[file_key].remove(update_queue) - # If no clients are connected, set up a cleaner task - if len(self.loro_docs_clients[file_key]) == 0: - # Remove any existing cleaner - cleaner = self.loro_docs_cleaners.get(file_key, None) - if cleaner is not None: - cleaner.cancel() - self.loro_docs_cleaners[file_key] = None - should_create_cleaner = True - - # Create the cleaner task outside the lock to avoid deadlocks - if should_create_cleaner: - self.loro_docs_cleaners[file_key] = asyncio.create_task( - self._clean_loro_doc(file_key, 60.0) - ) - - async def _do_remove_doc(self, file_key: MarimoFileKey) -> None: - """Actual implementation of removing a doc, separate from remove_doc to avoid deadlocks.""" - if file_key in self.loro_docs: - del self.loro_docs[file_key] - if file_key in self.loro_docs_clients: - del self.loro_docs_clients[file_key] - if file_key in self.loro_docs_cleaners: - del self.loro_docs_cleaners[file_key] + self.loro_docs_clients[file_key].discard(update_queue) async def remove_doc(self, file_key: MarimoFileKey) -> None: - """Remove a loro doc and all associated clients""" + """Unregister a LoroDoc and all associated client queues. + + Called when the session closes. + """ async with self.loro_docs_lock: - await self._do_remove_doc(file_key) + self.loro_docs.pop(file_key, None) + self.loro_docs_clients.pop(file_key, None) + self._subscriptions.pop(file_key, None) diff --git a/marimo/_session/file_change_handler.py b/marimo/_session/file_change_handler.py index 1e2310de425..472d2963d90 100644 --- a/marimo/_session/file_change_handler.py +++ b/marimo/_session/file_change_handler.py @@ -12,12 +12,21 @@ from marimo import _loggers from marimo._config.manager import MarimoConfigManager +from marimo._messaging.notebook import DocumentChange +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) from marimo._messaging.notification import ( + NotebookDocumentTransactionNotification, ReloadNotification, - UpdateCellCodesNotification, - UpdateCellIdsNotification, ) -from marimo._runtime.commands import DeleteCellCommand, SyncGraphCommand +from marimo._runtime.commands import SyncGraphCommand from marimo._session.model import SessionMode from marimo._types.ids import CellId_t from marimo._utils import async_path @@ -68,12 +77,9 @@ def handle_reload( self, session: Session, *, changed_cell_ids: set[CellId_t] ) -> None: """Handle reload in edit mode with optional auto-run.""" - # Get the latest codes, cell IDs, names, and configs cell_manager = session.app_file_manager.app.cell_manager - codes = list(cell_manager.codes()) cell_ids = list(cell_manager.cell_ids()) - names = list(cell_manager.names()) - configs = list(cell_manager.configs()) + codes = list(cell_manager.codes()) LOGGER.info( f"File changed: {session.app_file_manager.path}. " @@ -81,68 +87,74 @@ def handle_reload( f"changed_cell_ids: {changed_cell_ids}" ) - # Send the updated cell IDs to the frontend - session.notify( - UpdateCellIdsNotification(cell_ids=cell_ids), - from_consumer_id=None, - ) + # Build a transaction by diffing session.document vs new cell_manager. + doc = session.document + doc_ids = set(doc) + new_ids = set(cell_ids) + deleted = doc_ids - new_ids + + changes: list[DocumentChange] = [] + + # Deletes + for cid in deleted: + changes.append(DeleteCell(cell_id=cid)) + + # Creates and updates + for cd in cell_manager.cell_data(): + if cd.cell_id not in doc_ids: + changes.append( + CreateCell( + cell_id=cd.cell_id, + code=cd.code, + name=cd.name, + config=cd.config, + ) + ) + else: + doc_cell = doc.get_cell(cd.cell_id) + if cd.code != doc_cell.code: + changes.append(SetCode(cell_id=cd.cell_id, code=cd.code)) + if cd.name != doc_cell.name: + changes.append(SetName(cell_id=cd.cell_id, name=cd.name)) + if cd.config != doc_cell.config: + changes.append( + SetConfig( + cell_id=cd.cell_id, + column=cd.config.column, + disabled=cd.config.disabled, + hide_code=cd.config.hide_code, + ) + ) + + # Reorder + changes.append(ReorderCells(cell_ids=tuple(cell_ids))) + + # Broadcast transaction — session.notify() applies to + # session.document and stamps the version before forwarding. + if changes: + session.notify( + NotebookDocumentTransactionNotification( + transaction=Transaction( + changes=tuple(changes), source="file-watch" + ) + ), + from_consumer_id=None, + ) - # Check if we should auto-run cells based on config + # Auto-run changed cells if configured. watcher_on_save = self._config_manager.get_config()["runtime"][ "watcher_on_save" ] - should_autorun = watcher_on_save == "autorun" - - # Determine deleted cells - deleted = { - cell_id for cell_id in changed_cell_ids if cell_id not in cell_ids - } - - # Auto-run cells if configured - if should_autorun: - changed_cell_ids_list = list(changed_cell_ids - deleted) - cells = dict(zip(cell_ids, codes)) - - # Send names and configs to the frontend (SyncGraphCommand - # doesn't carry them) - if cell_ids: - session.notify( - UpdateCellCodesNotification( - cell_ids=cell_ids, - codes=codes, - code_is_stale=False, - names=names, - configs=configs, - ), - from_consumer_id=None, - ) - + if watcher_on_save == "autorun": + changed_not_deleted = list(changed_cell_ids - deleted) session.put_control_request( SyncGraphCommand( - cells=cells, - run_ids=changed_cell_ids_list, + cells=dict(zip(cell_ids, codes)), + run_ids=changed_not_deleted, delete_ids=list(deleted), ), from_consumer_id=None, ) - else: - # Just send deletes and code updates - for to_delete in deleted: - session.put_control_request( - DeleteCellCommand(cell_id=to_delete), - from_consumer_id=None, - ) - if cell_ids: - session.notify( - UpdateCellCodesNotification( - cell_ids=cell_ids, - codes=codes, - code_is_stale=True, - names=names, - configs=configs, - ), - from_consumer_id=None, - ) class RunModeReloadStrategy: diff --git a/marimo/_session/session.py b/marimo/_session/session.py index 432a5a047fa..88d1a701aa9 100644 --- a/marimo/_session/session.py +++ b/marimo/_session/session.py @@ -15,12 +15,12 @@ from marimo import _loggers from marimo._cli.sandbox import SandboxMode from marimo._config.manager import MarimoConfigManager, ScriptConfigManager -from marimo._messaging.notebook.document import NotebookCell, NotebookDocument from marimo._messaging.notification import ( NotificationMessage, ) from marimo._messaging.serde import serialize_kernel_message from marimo._messaging.types import KernelMessage +from marimo._notebook.document import NotebookDocument from marimo._runtime import commands from marimo._runtime.commands import ( AppMetadata, @@ -76,6 +76,10 @@ def _document_from_cell_manager(cell_manager: CellManager) -> NotebookDocument: """Build a NotebookDocument from a CellManager's current state. + Creates a ``LoroDoc`` that owns all cell text as ``LoroText`` + containers. Structural metadata (name, config, ordering) is stored + in the document's internal ``CellMeta`` list. + TODO: CellManager and NotebookDocument track overlapping state (cell ordering, code, names, configs). Once the document model is wired into all consumers, we should reconcile these — either CellManager @@ -83,17 +87,18 @@ def _document_from_cell_manager(cell_manager: CellManager) -> NotebookDocument: different composition. For now, the document is populated from the cell manager at session startup and the two coexist. """ - return NotebookDocument( - [ - NotebookCell( - id=cd.cell_id, - code=cd.code, - name=cd.name, - config=cd.config, - ) - for cd in cell_manager.cell_data() - ] - ) + from marimo._notebook._loro import create_doc + + doc = NotebookDocument(create_doc()) + for cd in cell_manager.cell_data(): + doc.add_cell( + cell_id=cd.cell_id, + code=cd.code, + name=cd.name, + config=cd.config, + ) + doc.loro_doc.commit() + return doc class SessionImpl(Session): diff --git a/tests/_code_mode/test_cells_view.py b/tests/_code_mode/test_cells_view.py index 51bee8bde5b..b0440cc5b66 100644 --- a/tests/_code_mode/test_cells_view.py +++ b/tests/_code_mode/test_cells_view.py @@ -27,11 +27,9 @@ def cmd(cell_id: str, code: str) -> ExecuteCellCommand: @contextmanager def _ctx(k: Kernel) -> Generator[AsyncCodeModeContext, None, None]: """Build an AsyncCodeModeContext with a document snapshot from the kernel.""" - doc = NotebookDocument( - [ - NotebookCell(id=cid, code=cell.code, name="", config=cell.config) - for cid, cell in k.graph.cells.items() - ] + doc = NotebookDocument.from_cells( + NotebookCell(id=cid, code=cell.code, name="", config=cell.config) + for cid, cell in k.graph.cells.items() ) with notebook_document_context(doc): yield AsyncCodeModeContext(k) diff --git a/tests/_code_mode/test_context.py b/tests/_code_mode/test_context.py index 881722de094..72d205b68bb 100644 --- a/tests/_code_mode/test_context.py +++ b/tests/_code_mode/test_context.py @@ -29,11 +29,9 @@ @contextmanager def _ctx(k: Kernel) -> Generator[AsyncCodeModeContext, None, None]: """Build an AsyncCodeModeContext with a document snapshot from the kernel.""" - doc = NotebookDocument( - [ - NotebookCell(id=cid, code=cell.code, name="", config=cell.config) - for cid, cell in k.graph.cells.items() - ] + doc = NotebookDocument.from_cells( + NotebookCell(id=cid, code=cell.code, name="", config=cell.config) + for cid, cell in k.graph.cells.items() ) with notebook_document_context(doc): yield AsyncCodeModeContext(k) diff --git a/tests/_notebook/test_document.py b/tests/_notebook/test_document.py new file mode 100644 index 00000000000..2a6e061f2dc --- /dev/null +++ b/tests/_notebook/test_document.py @@ -0,0 +1,576 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +import pytest +from inline_snapshot import snapshot +from loro import LoroDoc + +from marimo._ast.cell import CellConfig +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) +from marimo._notebook.document import NotebookDocument +from marimo._types.ids import CellId_t + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _doc(*names: str) -> NotebookDocument: + doc = NotebookDocument(LoroDoc()) + for n in names: + doc.add_cell( + cell_id=CellId_t(n), code="", name="__", config=CellConfig() + ) + doc._loro_doc.commit() + return doc + + +def _tx(*changes: DocumentChange, source: str = "test") -> Transaction: + return Transaction(changes=changes, source=source) + + +def _ids(doc: NotebookDocument) -> list[str]: + return [str(cid) for cid in doc.cell_ids] + + +# ------------------------------------------------------------------ +# CreateCell +# ------------------------------------------------------------------ + + +class TestCreateCell: + def test_append_at_end(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + assert _ids(doc) == snapshot(["a", "b", "new"]) + + def test_insert_after(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + after=CellId_t("a"), + ) + ) + ) + assert _ids(doc) == snapshot(["a", "new", "b", "c"]) + + def test_insert_before(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + before=CellId_t("b"), + ) + ) + ) + assert _ids(doc) == snapshot(["a", "new", "b", "c"]) + + def test_insert_into_empty(self) -> None: + doc = _doc() + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + assert _ids(doc) == snapshot(["new"]) + + def test_multiple(self) -> None: + doc = _doc("a") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("x"), + code="1", + name="__", + config=CellConfig(), + ), + CreateCell( + cell_id=CellId_t("y"), + code="2", + name="__", + config=CellConfig(), + ), + ) + ) + assert _ids(doc) == snapshot(["a", "x", "y"]) + + def test_after_pending(self) -> None: + """A create can reference a cell added earlier in the same tx.""" + doc = _doc("a") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("x"), + code="1", + name="__", + config=CellConfig(), + ), + CreateCell( + cell_id=CellId_t("y"), + code="2", + name="__", + config=CellConfig(), + after=CellId_t("x"), + ), + ) + ) + assert _ids(doc) == snapshot(["a", "x", "y"]) + + def test_duplicate_id_raises(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="already exists"): + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("a"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + + def test_stores_code_and_config(self) -> None: + doc = _doc() + cfg = CellConfig(hide_code=True, disabled=True) + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("a"), + code="import os", + name="imports", + config=cfg, + ) + ) + ) + cell = doc.get_cell(CellId_t("a")) + assert cell.code == "import os" + assert cell.name == "imports" + assert cell.config.hide_code is True + assert cell.config.disabled is True + + +# ------------------------------------------------------------------ +# DeleteCell +# ------------------------------------------------------------------ + + +class TestDeleteCell: + def test_delete_single(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(DeleteCell(cell_id=CellId_t("b")))) + assert _ids(doc) == snapshot(["a", "c"]) + + def test_delete_multiple(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + DeleteCell(cell_id=CellId_t("c")), + ) + ) + assert _ids(doc) == snapshot(["b"]) + + def test_delete_all(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + DeleteCell(cell_id=CellId_t("b")), + ) + ) + assert _ids(doc) == snapshot([]) + + def test_delete_not_found(self) -> None: + doc = _doc("a") + with pytest.raises(KeyError): + doc.apply(_tx(DeleteCell(cell_id=CellId_t("missing")))) + + +# ------------------------------------------------------------------ +# MoveCell +# ------------------------------------------------------------------ + + +class TestMoveCell: + def test_move_after(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(MoveCell(cell_id=CellId_t("a"), after=CellId_t("c")))) + assert _ids(doc) == snapshot(["b", "c", "a"]) + + def test_move_before(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(MoveCell(cell_id=CellId_t("c"), before=CellId_t("a")))) + assert _ids(doc) == snapshot(["c", "a", "b"]) + + def test_no_anchor_raises(self) -> None: + doc = _doc("a", "b") + with pytest.raises(ValueError, match="before.*after"): + doc.apply(_tx(MoveCell(cell_id=CellId_t("a")))) + + +# ------------------------------------------------------------------ +# SetCode +# ------------------------------------------------------------------ + + +class TestSetCode: + def test_update_code(self) -> None: + doc = _doc("a", "b") + doc.apply(_tx(SetCode(cell_id=CellId_t("b"), code="new"))) + assert doc.get_cell(CellId_t("b")).code == "new" + assert doc.get_cell(CellId_t("a")).code == "" # unchanged + + def test_not_found(self) -> None: + doc = _doc("a") + with pytest.raises(KeyError): + doc.apply(_tx(SetCode(cell_id=CellId_t("missing"), code="x"))) + + +# ------------------------------------------------------------------ +# SetName +# ------------------------------------------------------------------ + + +class TestSetName: + def test_update_name(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetName(cell_id=CellId_t("a"), name="my_cell"))) + assert doc.get_cell(CellId_t("a")).name == "my_cell" + + +# ------------------------------------------------------------------ +# SetConfig +# ------------------------------------------------------------------ + + +class TestSetConfig: + def test_partial_hide_code(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetConfig(cell_id=CellId_t("a"), hide_code=True))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.hide_code is True + assert cfg.disabled is False # unchanged default + + def test_partial_disabled(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetConfig(cell_id=CellId_t("a"), disabled=True))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.disabled is True + assert cfg.hide_code is False # unchanged default + + def test_multiple_fields(self) -> None: + doc = _doc("a") + doc.apply( + _tx( + SetConfig(cell_id=CellId_t("a"), hide_code=True, disabled=True) + ) + ) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.hide_code is True + assert cfg.disabled is True + + def test_all_none_is_noop(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetConfig(cell_id=CellId_t("a")))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg == CellConfig() + + def test_preserves_existing(self) -> None: + """Setting one field preserves other non-default fields.""" + doc = NotebookDocument(LoroDoc()) + doc.add_cell( + cell_id=CellId_t("a"), + code="", + name="__", + config=CellConfig(hide_code=True, column=2), + ) + doc._loro_doc.commit() + doc.apply(_tx(SetConfig(cell_id=CellId_t("a"), disabled=True))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.disabled is True + assert cfg.hide_code is True # preserved + assert cfg.column == 2 # preserved + + +# ------------------------------------------------------------------ +# Validation +# ------------------------------------------------------------------ + + +class TestValidation: + def test_delete_and_set_code_same_cell(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="delete.*update"): + doc.apply( + _tx( + SetCode(cell_id=CellId_t("a"), code="x"), + DeleteCell(cell_id=CellId_t("a")), + ) + ) + + def test_set_code_and_delete_same_cell(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="update.*delete"): + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + SetCode(cell_id=CellId_t("a"), code="x"), + ) + ) + + def test_delete_and_move_same_cell(self) -> None: + doc = _doc("a", "b") + with pytest.raises(ValueError, match="delete.*move"): + doc.apply( + _tx( + MoveCell(cell_id=CellId_t("a"), after=CellId_t("b")), + DeleteCell(cell_id=CellId_t("a")), + ) + ) + + def test_double_delete(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="deleted more than once"): + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + DeleteCell(cell_id=CellId_t("a")), + ) + ) + + def test_valid_mixed_changes(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ), + SetCode(cell_id=CellId_t("a"), code="y"), + DeleteCell(cell_id=CellId_t("b")), + ) + ) + assert _ids(doc) == snapshot(["a", "new"]) + + +# ------------------------------------------------------------------ +# Versioning +# ------------------------------------------------------------------ + + +class TestVersion: + def test_increments_on_apply(self) -> None: + doc = _doc("a") + assert doc.version == 0 + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.version == 1 + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="y"))) + assert doc.version == 2 + + def test_stamped_on_returned_tx(self) -> None: + doc = _doc("a") + tx = _tx(SetCode(cell_id=CellId_t("a"), code="x")) + assert tx.version is None + applied = doc.apply(tx) + assert applied.version == 1 + + def test_empty_tx_no_increment(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.version == 1 + applied = doc.apply(_tx()) + assert doc.version == 1 + assert applied.version == 1 + + +# ------------------------------------------------------------------ +# Combined changes +# ------------------------------------------------------------------ + + +class TestCombined: + def test_delete_then_create(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("b")), + CreateCell( + cell_id=CellId_t("new"), + code="d", + name="__", + config=CellConfig(), + after=CellId_t("a"), + ), + ) + ) + assert _ids(doc) == snapshot(["a", "new", "c"]) + + def test_create_then_set_code(self) -> None: + doc = _doc("a") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="tmp", + name="__", + config=CellConfig(), + ), + SetCode(cell_id=CellId_t("new"), code="final"), + ) + ) + assert _ids(doc) == snapshot(["a", "new"]) + assert doc.get_cell(CellId_t("new")).code == "final" + + def test_create_then_move(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ), + MoveCell(cell_id=CellId_t("new"), before=CellId_t("a")), + ) + ) + assert _ids(doc) == snapshot(["new", "a", "b"]) + + def test_source_preserved(self) -> None: + doc = _doc("a") + applied = doc.apply( + _tx( + SetCode(cell_id=CellId_t("a"), code="x"), + source="kernel", + ) + ) + assert applied.source == "kernel" + + +# ------------------------------------------------------------------ +# Initialization +# ------------------------------------------------------------------ + + +class TestInit: + def test_from_cells(self) -> None: + doc = _doc("a", "b", "c") + assert _ids(doc) == ["a", "b", "c"] + + def test_empty(self) -> None: + doc = NotebookDocument(LoroDoc()) + assert _ids(doc) == [] + assert doc.version == 0 + + def test_get_cell(self) -> None: + doc = _doc("a", "b") + cell = doc.get_cell(CellId_t("b")) + assert cell.id == CellId_t("b") + + def test_get_cell_not_found(self) -> None: + doc = _doc("a") + with pytest.raises(KeyError): + doc.get_cell(CellId_t("missing")) + + def test_get_returns_none(self) -> None: + doc = _doc("a") + assert doc.get(CellId_t("missing")) is None + assert doc.get(CellId_t("a")) is not None + + def test_contains(self) -> None: + doc = _doc("a", "b") + assert CellId_t("a") in doc + assert CellId_t("missing") not in doc + + def test_len(self) -> None: + assert len(_doc()) == 0 + assert len(_doc("a", "b")) == 2 + + def test_iter(self) -> None: + doc = _doc("a", "b", "c") + assert list(doc) == [CellId_t("a"), CellId_t("b"), CellId_t("c")] + + def test_repr(self) -> None: + doc = _doc("a") + assert "NotebookDocument(1 cells)" in repr(doc) + + +# ------------------------------------------------------------------ +# ReorderCells +# ------------------------------------------------------------------ + + +class TestReorderCells: + def test_reorder(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + ReorderCells( + cell_ids=(CellId_t("c"), CellId_t("a"), CellId_t("b")) + ) + ) + ) + assert _ids(doc) == snapshot(["c", "a", "b"]) + + def test_missing_ids_appended(self) -> None: + """Cells not in the reorder list are appended at the end.""" + doc = _doc("a", "b", "c") + doc.apply(_tx(ReorderCells(cell_ids=(CellId_t("c"), CellId_t("a"))))) + assert _ids(doc) == snapshot(["c", "a", "b"]) + + def test_unknown_ids_ignored(self) -> None: + """IDs not in the document are silently skipped.""" + doc = _doc("a", "b") + doc.apply( + _tx( + ReorderCells( + cell_ids=( + CellId_t("b"), + CellId_t("unknown"), + CellId_t("a"), + ) + ) + ) + ) + assert _ids(doc) == snapshot(["b", "a"]) + + def test_reorder_single(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(ReorderCells(cell_ids=(CellId_t("b"),)))) + assert _ids(doc) == snapshot(["b", "a", "c"]) diff --git a/tests/_server/api/endpoints/test_resume_session.py b/tests/_server/api/endpoints/test_resume_session.py index 27d4584b493..8c2886e7eb1 100644 --- a/tests/_server/api/endpoints/test_resume_session.py +++ b/tests/_server/api/endpoints/test_resume_session.py @@ -286,12 +286,13 @@ def test_resume_session_after_file_change(client: TestClient) -> None: assert result.handled data = websocket.receive_json() - assert data == { - "op": "update-cell-ids", - "data": {"cell_ids": ["MJUe", "Hbol"], "op": "update-cell-ids"}, - } - data = websocket.receive_json() - assert data["op"] == "update-cell-codes" + assert data["op"] == "notebook-document-transaction" + tx = data["data"]["transaction"] + # Transaction should contain the new cell and reorder. + op_types = [op["type"] for op in tx["ops"]] + assert "create-cell" in op_types + assert "reorder-cells" in op_types + assert tx["source"] == "file-watch" # Resume session with new ID (simulates refresh) with client.websocket_connect(_create_ws_url("456")) as websocket: @@ -304,7 +305,7 @@ def test_resume_session_after_file_change(client: TestClient) -> None: assert parse_raw(data["data"], KernelReadyNotification) messages: list[dict[str, Any]] = [] - # Wait for update-cell-ids message + # Wait for update-cell-ids message (session replay) while True: data = websocket.receive_json() messages.append(data) @@ -313,7 +314,7 @@ def test_resume_session_after_file_change(client: TestClient) -> None: # 2 messages: # 1. banner - # 2. update-cell-ids + # 2. update-cell-ids (from session view replay) assert len(messages) == 2 assert messages[0]["op"] == "banner" assert messages[1] == { diff --git a/tests/_server/api/endpoints/test_ws_rtc.py b/tests/_server/api/endpoints/test_ws_rtc.py index 6059a364153..b14ee8c7725 100644 --- a/tests/_server/api/endpoints/test_ws_rtc.py +++ b/tests/_server/api/endpoints/test_ws_rtc.py @@ -29,12 +29,12 @@ async def setup_loro_docs() -> AsyncGenerator[None, None]: # Clear any existing loro docs DOC_MANAGER.loro_docs.clear() DOC_MANAGER.loro_docs_clients.clear() - DOC_MANAGER.loro_docs_cleaners.clear() + DOC_MANAGER._subscriptions.clear() yield # Cleanup after test DOC_MANAGER.loro_docs.clear() DOC_MANAGER.loro_docs_clients.clear() - DOC_MANAGER.loro_docs_cleaners.clear() + DOC_MANAGER._subscriptions.clear() @contextmanager diff --git a/tests/_server/rtc/test_rtc_doc.py b/tests/_server/rtc/test_rtc_doc.py index 07be31cdf6d..afc35df2c0f 100644 --- a/tests/_server/rtc/test_rtc_doc.py +++ b/tests/_server/rtc/test_rtc_doc.py @@ -26,50 +26,68 @@ async def setup_doc_manager() -> AsyncGenerator[None, None]: # Clear any existing loro docs doc_manager.loro_docs.clear() doc_manager.loro_docs_clients.clear() - doc_manager.loro_docs_cleaners.clear() + doc_manager._subscriptions.clear() yield # Cleanup after test doc_manager.loro_docs.clear() doc_manager.loro_docs_clients.clear() - doc_manager.loro_docs_cleaners.clear() + doc_manager._subscriptions.clear() @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) -async def test_quick_reconnection(setup_doc_manager: None) -> None: - """Test that quick reconnection properly handles cleanup task cancellation""" +async def test_register_doc(setup_doc_manager: None) -> None: + """Test registering a LoroDoc with the manager.""" del setup_doc_manager - # Setup file_key = MarimoFileKey("test_file") - - # Create initial loro_doc doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc - # Setup client queue - update_queue = asyncio.Queue[bytes]() - doc_manager.loro_docs_clients[file_key] = {update_queue} + await doc_manager.register_doc(file_key, doc) - # Start cleanup task - cleanup_task = asyncio.create_task(doc_manager._clean_loro_doc(file_key)) + assert file_key in doc_manager.loro_docs + assert doc_manager.loro_docs[file_key] is doc + assert file_key in doc_manager._subscriptions - # Simulate quick reconnection by creating a new client before cleanup finishes - new_queue = asyncio.Queue[bytes]() - doc_manager.loro_docs_clients[file_key].add(new_queue) - # Cancel cleanup task - cleanup_task.cancel() - try: - await cleanup_task - except asyncio.CancelledError: - pass +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_register_doc_idempotent(setup_doc_manager: None) -> None: + """Registering the same file_key twice keeps the first doc.""" + del setup_doc_manager + file_key = MarimoFileKey("test_file") + doc1 = LoroDoc() + doc2 = LoroDoc() - # Verify state - assert len(doc_manager.loro_docs) == 1 - assert ( - len(doc_manager.loro_docs_clients[file_key]) == 2 - ) # Original client + reconnected client + await doc_manager.register_doc(file_key, doc1) + await doc_manager.register_doc(file_key, doc2) + + assert doc_manager.loro_docs[file_key] is doc1 + + +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_get_doc(setup_doc_manager: None) -> None: + """Test retrieving a registered doc.""" + del setup_doc_manager + file_key = MarimoFileKey("test_file") + doc = LoroDoc() + await doc_manager.register_doc(file_key, doc) + + result = await doc_manager.get_doc(file_key) + assert result is doc + + +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_get_doc_missing(setup_doc_manager: None) -> None: + """Getting an unregistered doc raises KeyError.""" + del setup_doc_manager + with pytest.raises(KeyError): + await doc_manager.get_doc(MarimoFileKey("missing")) @pytest.mark.skipif( @@ -81,14 +99,15 @@ async def test_two_users_sync(setup_doc_manager: None) -> None: file_key = MarimoFileKey("test_file") cell_id = str(CellId_t("test_cell")) # Convert CellId to string for loro - # First user connects + # Register the doc doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Setup client queues for both users queue1 = asyncio.Queue[bytes]() queue2 = asyncio.Queue[bytes]() - doc_manager.loro_docs_clients[file_key] = {queue1, queue2} + doc_manager.add_client_to_doc(file_key, queue1) + doc_manager.add_client_to_doc(file_key, queue2) # Get maps from doc doc_codes = doc.get_map("codes") @@ -123,21 +142,17 @@ async def test_two_users_sync(setup_doc_manager: None) -> None: @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) -async def test_concurrent_doc_creation(setup_doc_manager: None) -> None: - """Test concurrent doc creation doesn't cause issues""" +async def test_concurrent_registration(setup_doc_manager: None) -> None: + """Test concurrent doc registration doesn't cause issues""" del setup_doc_manager file_key = MarimoFileKey("test_file") - cell_ids = (CellId_t("cell1"), CellId_t("cell2")) - codes = ("print('hello')", "print('world')") + doc = LoroDoc() - # Create multiple tasks that try to create the same doc - tasks = [ - doc_manager.create_doc(file_key, cell_ids, codes) for _ in range(5) - ] - docs = await asyncio.gather(*tasks) + # Create multiple tasks that try to register the same doc + tasks = [doc_manager.register_doc(file_key, doc) for _ in range(5)] + await asyncio.gather(*tasks) - # All tasks should return the same doc instance - assert all(doc is docs[0] for doc in docs) + # Only one doc should be registered assert len(doc_manager.loro_docs) == 1 @@ -151,7 +166,7 @@ async def test_concurrent_client_operations( del setup_doc_manager file_key = MarimoFileKey("test_file") doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Create multiple client queues queues = [asyncio.Queue[bytes]() for _ in range(5)] @@ -170,40 +185,6 @@ async def client_operation(queue: asyncio.Queue[bytes]) -> None: assert len(doc_manager.loro_docs_clients[file_key]) == 0 -@pytest.mark.skipif( - "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" -) -async def test_cleanup_task_management(setup_doc_manager: None) -> None: - """Test cleanup task management and cancellation""" - del setup_doc_manager - file_key = MarimoFileKey("test_file") - doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc - - # Add and remove a client to trigger cleanup - queue = asyncio.Queue[bytes]() - doc_manager.add_client_to_doc(file_key, queue) - await doc_manager.remove_client(file_key, queue) - - # Verify cleanup task was created - assert file_key in doc_manager.loro_docs_cleaners - assert doc_manager.loro_docs_cleaners[file_key] is not None - - # Add a new client before cleanup finishes - new_queue = asyncio.Queue[bytes]() - doc_manager.add_client_to_doc(file_key, new_queue) - - # Wait for the task to be cancelled - await asyncio.sleep(0.1) - - # Verify cleanup task was cancelled and removed - # TODO: not sure why this is still here. - # assert doc_manager.loro_docs_cleaners[file_key] is None - - # Clean up - await doc_manager.remove_client(file_key, new_queue) - - @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) @@ -212,7 +193,7 @@ async def test_broadcast_update(setup_doc_manager: None) -> None: del setup_doc_manager file_key = MarimoFileKey("test_file") doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Create multiple client queues queues = [asyncio.Queue[bytes]() for _ in range(3)] @@ -232,6 +213,29 @@ async def test_broadcast_update(setup_doc_manager: None) -> None: assert await queue.get() == message +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_local_update_broadcast(setup_doc_manager: None) -> None: + """Server-side Loro mutations are broadcast to RTC clients.""" + del setup_doc_manager + file_key = MarimoFileKey("test_file") + doc = LoroDoc() + await doc_manager.register_doc(file_key, doc) + + queue: asyncio.Queue[bytes] = asyncio.Queue() + doc_manager.add_client_to_doc(file_key, queue) + + # Mutate the doc server-side (simulates SetCode via NotebookDocument) + codes = doc.get_map("codes") + text = codes.get_or_create_container("cell1", LoroText()) + text.insert(0, "x = 1") + doc.commit() + + # The subscription should have enqueued the update + assert not queue.empty() + + @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) @@ -242,7 +246,7 @@ async def test_remove_nonexistent_doc(setup_doc_manager: None) -> None: await doc_manager.remove_doc(file_key) assert file_key not in doc_manager.loro_docs assert file_key not in doc_manager.loro_docs_clients - assert file_key not in doc_manager.loro_docs_cleaners + assert file_key not in doc_manager._subscriptions @pytest.mark.skipif( @@ -265,7 +269,7 @@ async def test_concurrent_doc_removal(setup_doc_manager: None) -> None: del setup_doc_manager file_key = MarimoFileKey("test_file") doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Create multiple tasks that try to remove the same doc tasks = [doc_manager.remove_doc(file_key) for _ in range(5)] @@ -274,85 +278,4 @@ async def test_concurrent_doc_removal(setup_doc_manager: None) -> None: # Verify doc was removed assert file_key not in doc_manager.loro_docs assert file_key not in doc_manager.loro_docs_clients - assert file_key not in doc_manager.loro_docs_cleaners - - -@pytest.mark.skipif( - "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" -) -async def test_prevent_lock_deadlock(setup_doc_manager: None) -> None: - """Test that our deadlock prevention measures work correctly. - - This test simulates the scenario that could cause a deadlock: - 1. A client disconnects, starting the cleanup process - 2. Another operation acquires the lock before cleanup timer finishes - 3. Cleanup timer expires and tries to acquire the lock - - The fixed implementation should handle this without deadlocking. - """ - del setup_doc_manager - file_key = MarimoFileKey("test_file") - - # Create a doc and add a client - doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc - queue = asyncio.Queue[bytes]() - doc_manager.add_client_to_doc(file_key, queue) - - # Set a very short cleanup timeout for testing - original_timeout = 60.0 - cleanup_timeout = 0.1 # 100ms - - # Create a barrier to coordinate tasks - barrier = asyncio.Barrier(2) - long_operation_done = asyncio.Event() - - # Task 1: Remove client, which will schedule cleanup with short timeout - async def remove_client_task() -> None: - await doc_manager.remove_client(file_key, queue) - # Wait at barrier to synchronize with the long operation - await barrier.wait() - # Wait for long operation to complete - await long_operation_done.wait() - - # Task 2: Simulate a long operation that holds the lock - async def long_lock_operation() -> None: - # Wait for remove_client to schedule the cleanup - await barrier.wait() - - # Acquire the lock and hold it for longer than the cleanup timeout - async with doc_manager.loro_docs_lock: - # Sleep while holding the lock (longer than cleanup timeout) - await asyncio.sleep(cleanup_timeout * 2) - - # Signal that we're done holding the lock - long_operation_done.set() - - # Modified test version of _clean_loro_doc with shorter timeout - original_clean_loro_doc = doc_manager._clean_loro_doc - - async def test_clean_loro_doc( - file_key: MarimoFileKey, timeout: float = original_timeout - ) -> None: - del timeout - # Override timeout with our test value - await original_clean_loro_doc(file_key, cleanup_timeout) - - # Override the method for this test - doc_manager._clean_loro_doc = test_clean_loro_doc - - try: - # Run both tasks simultaneously - task1 = asyncio.create_task(remove_client_task()) - task2 = asyncio.create_task(long_lock_operation()) - - # This should complete without deadlocking - await asyncio.gather(task1, task2) - - # Verify the doc was properly cleaned up - assert file_key not in doc_manager.loro_docs - assert file_key not in doc_manager.loro_docs_clients - assert file_key not in doc_manager.loro_docs_cleaners - finally: - # Restore the original method - doc_manager._clean_loro_doc = original_clean_loro_doc + assert file_key not in doc_manager._subscriptions