From c82654c2e040209b8c58f2ff0be3fef6809fd5b0 Mon Sep 17 00:00:00 2001 From: Trevor Manz Date: Mon, 23 Mar 2026 21:20:20 -0400 Subject: [PATCH 1/6] Emit document transactions from --watch file reload MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The file watcher previously sent separate UpdateCellIdsNotification and UpdateCellCodesNotification when the notebook file changed on disk. Now it diffs `session.document` against the reloaded cell_manager and emits a single `NotebookDocumentTransactionNotification` with typed ops. The session intercepts it, applies it to the canonical document, stamps the version, and forwards it to the frontend. This is the same path used by code_mode and the frontend transaction endpoint. SyncGraphCommand for autorun stays separate since execution is not a structural concern. The non-autorun path no longer needs to send DeleteCellCommand individually — deletes are part of the transaction. --- marimo/_session/file_change_handler.py | 130 ++++++++++-------- .../api/endpoints/test_resume_session.py | 17 +-- 2 files changed, 80 insertions(+), 67 deletions(-) diff --git a/marimo/_session/file_change_handler.py b/marimo/_session/file_change_handler.py index 1e2310de425..472d2963d90 100644 --- a/marimo/_session/file_change_handler.py +++ b/marimo/_session/file_change_handler.py @@ -12,12 +12,21 @@ from marimo import _loggers from marimo._config.manager import MarimoConfigManager +from marimo._messaging.notebook import DocumentChange +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) from marimo._messaging.notification import ( + NotebookDocumentTransactionNotification, ReloadNotification, - UpdateCellCodesNotification, - UpdateCellIdsNotification, ) -from marimo._runtime.commands import DeleteCellCommand, SyncGraphCommand +from marimo._runtime.commands import SyncGraphCommand from marimo._session.model import SessionMode 
from marimo._types.ids import CellId_t from marimo._utils import async_path @@ -68,12 +77,9 @@ def handle_reload( self, session: Session, *, changed_cell_ids: set[CellId_t] ) -> None: """Handle reload in edit mode with optional auto-run.""" - # Get the latest codes, cell IDs, names, and configs cell_manager = session.app_file_manager.app.cell_manager - codes = list(cell_manager.codes()) cell_ids = list(cell_manager.cell_ids()) - names = list(cell_manager.names()) - configs = list(cell_manager.configs()) + codes = list(cell_manager.codes()) LOGGER.info( f"File changed: {session.app_file_manager.path}. " @@ -81,68 +87,74 @@ def handle_reload( f"changed_cell_ids: {changed_cell_ids}" ) - # Send the updated cell IDs to the frontend - session.notify( - UpdateCellIdsNotification(cell_ids=cell_ids), - from_consumer_id=None, - ) + # Build a transaction by diffing session.document vs new cell_manager. + doc = session.document + doc_ids = set(doc) + new_ids = set(cell_ids) + deleted = doc_ids - new_ids + + changes: list[DocumentChange] = [] + + # Deletes + for cid in deleted: + changes.append(DeleteCell(cell_id=cid)) + + # Creates and updates + for cd in cell_manager.cell_data(): + if cd.cell_id not in doc_ids: + changes.append( + CreateCell( + cell_id=cd.cell_id, + code=cd.code, + name=cd.name, + config=cd.config, + ) + ) + else: + doc_cell = doc.get_cell(cd.cell_id) + if cd.code != doc_cell.code: + changes.append(SetCode(cell_id=cd.cell_id, code=cd.code)) + if cd.name != doc_cell.name: + changes.append(SetName(cell_id=cd.cell_id, name=cd.name)) + if cd.config != doc_cell.config: + changes.append( + SetConfig( + cell_id=cd.cell_id, + column=cd.config.column, + disabled=cd.config.disabled, + hide_code=cd.config.hide_code, + ) + ) + + # Reorder + changes.append(ReorderCells(cell_ids=tuple(cell_ids))) + + # Broadcast transaction — session.notify() applies to + # session.document and stamps the version before forwarding. 
+ if changes: + session.notify( + NotebookDocumentTransactionNotification( + transaction=Transaction( + changes=tuple(changes), source="file-watch" + ) + ), + from_consumer_id=None, + ) - # Check if we should auto-run cells based on config + # Auto-run changed cells if configured. watcher_on_save = self._config_manager.get_config()["runtime"][ "watcher_on_save" ] - should_autorun = watcher_on_save == "autorun" - - # Determine deleted cells - deleted = { - cell_id for cell_id in changed_cell_ids if cell_id not in cell_ids - } - - # Auto-run cells if configured - if should_autorun: - changed_cell_ids_list = list(changed_cell_ids - deleted) - cells = dict(zip(cell_ids, codes)) - - # Send names and configs to the frontend (SyncGraphCommand - # doesn't carry them) - if cell_ids: - session.notify( - UpdateCellCodesNotification( - cell_ids=cell_ids, - codes=codes, - code_is_stale=False, - names=names, - configs=configs, - ), - from_consumer_id=None, - ) - + if watcher_on_save == "autorun": + changed_not_deleted = list(changed_cell_ids - deleted) session.put_control_request( SyncGraphCommand( - cells=cells, - run_ids=changed_cell_ids_list, + cells=dict(zip(cell_ids, codes)), + run_ids=changed_not_deleted, delete_ids=list(deleted), ), from_consumer_id=None, ) - else: - # Just send deletes and code updates - for to_delete in deleted: - session.put_control_request( - DeleteCellCommand(cell_id=to_delete), - from_consumer_id=None, - ) - if cell_ids: - session.notify( - UpdateCellCodesNotification( - cell_ids=cell_ids, - codes=codes, - code_is_stale=True, - names=names, - configs=configs, - ), - from_consumer_id=None, - ) class RunModeReloadStrategy: diff --git a/tests/_server/api/endpoints/test_resume_session.py b/tests/_server/api/endpoints/test_resume_session.py index 27d4584b493..8c2886e7eb1 100644 --- a/tests/_server/api/endpoints/test_resume_session.py +++ b/tests/_server/api/endpoints/test_resume_session.py @@ -286,12 +286,13 @@ def 
test_resume_session_after_file_change(client: TestClient) -> None: assert result.handled data = websocket.receive_json() - assert data == { - "op": "update-cell-ids", - "data": {"cell_ids": ["MJUe", "Hbol"], "op": "update-cell-ids"}, - } - data = websocket.receive_json() - assert data["op"] == "update-cell-codes" + assert data["op"] == "notebook-document-transaction" + tx = data["data"]["transaction"] + # Transaction should contain the new cell and reorder. + op_types = [op["type"] for op in tx["ops"]] + assert "create-cell" in op_types + assert "reorder-cells" in op_types + assert tx["source"] == "file-watch" # Resume session with new ID (simulates refresh) with client.websocket_connect(_create_ws_url("456")) as websocket: @@ -304,7 +305,7 @@ def test_resume_session_after_file_change(client: TestClient) -> None: assert parse_raw(data["data"], KernelReadyNotification) messages: list[dict[str, Any]] = [] - # Wait for update-cell-ids message + # Wait for update-cell-ids message (session replay) while True: data = websocket.receive_json() messages.append(data) @@ -313,7 +314,7 @@ def test_resume_session_after_file_change(client: TestClient) -> None: # 2 messages: # 1. banner - # 2. update-cell-ids + # 2. update-cell-ids (from session view replay) assert len(messages) == 2 assert messages[0]["op"] == "banner" assert messages[1] == { From 69d18cc53da768d21dba300fc7f71d96e81e6bef Mon Sep 17 00:00:00 2001 From: Trevor Manz Date: Mon, 23 Mar 2026 22:58:33 -0400 Subject: [PATCH 2/6] Back NotebookDocument cell text with Loro CRDT NotebookDocument previously stored cell code as plain strings on mutable NotebookCell structs. This moves cell text ownership to a `LoroDoc`, making the CRDT the single source of truth for code content while the document continues to own structural metadata (cell ordering, names, configs) in a new lightweight CellMeta class. `NotebookCell` becomes a frozen, read-only snapshot materialized on access from CellMeta + LoroText.to_string(). 
It is never stored internally by the document. `SetCode` ops now perform a full delete-then-insert on the `LoroText` container, which is the correct semantic for non-interactive writes (kernel, file-watch, code_mode). Character-level edits from the frontend continue to flow through the Loro RTC WebSocket unchanged. The new internal layout: ```py doc = NotebookDocument(create_doc()) doc.add_cell(cell_id, code="x = 1", name="__", config=CellConfig()) doc.loro_doc.commit() doc.get_cell(cell_id).code # materializes snapshot from LoroText ``` --- marimo/_notebook/__init__.py | 30 ++ marimo/_notebook/_loro.py | 19 + marimo/_notebook/document.py | 417 ++++++++++++++++++++ marimo/_runtime/runtime.py | 2 +- marimo/_session/session.py | 29 +- tests/_code_mode/test_cells_view.py | 8 +- tests/_code_mode/test_context.py | 8 +- tests/_notebook/test_document.py | 576 ++++++++++++++++++++++++++++ 8 files changed, 1066 insertions(+), 23 deletions(-) create mode 100644 marimo/_notebook/__init__.py create mode 100644 marimo/_notebook/_loro.py create mode 100644 marimo/_notebook/document.py create mode 100644 tests/_notebook/test_document.py diff --git a/marimo/_notebook/__init__.py b/marimo/_notebook/__init__.py new file mode 100644 index 00000000000..153a7bd09da --- /dev/null +++ b/marimo/_notebook/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2026 Marimo. All rights reserved. 
+"""Notebook document model — canonical representation of notebook structure.""" + +from marimo._notebook.document import CellMeta, NotebookCell, NotebookDocument +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) + +__all__ = [ + "CellMeta", + "CreateCell", + "DeleteCell", + "DocumentChange", + "MoveCell", + "NotebookCell", + "NotebookDocument", + "ReorderCells", + "SetCode", + "SetConfig", + "SetName", + "Transaction", +] diff --git a/marimo/_notebook/_loro.py b/marimo/_notebook/_loro.py new file mode 100644 index 00000000000..646cb37e82d --- /dev/null +++ b/marimo/_notebook/_loro.py @@ -0,0 +1,19 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Typed wrappers for ``loro`` constructors. + +The ``loro`` stubs omit return types on ``__new__``, which triggers +mypy ``no-untyped-call``. These helpers provide correctly-typed +construction so the rest of the codebase stays clean. +""" + +from __future__ import annotations + +from loro import LoroDoc, LoroText + + +def create_doc() -> LoroDoc: + return LoroDoc() # type: ignore[no-untyped-call] + + +def create_text() -> LoroText: + return LoroText() # type: ignore[no-untyped-call] diff --git a/marimo/_notebook/document.py b/marimo/_notebook/document.py new file mode 100644 index 00000000000..07b959638fe --- /dev/null +++ b/marimo/_notebook/document.py @@ -0,0 +1,417 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Canonical notebook document model. + +``NotebookDocument`` maintains an ordered list of ``CellMeta`` entries and +a ``LoroDoc`` that owns all cell source text as ``LoroText`` containers. +``NotebookCell`` is a read-only snapshot materialized on access. 
+""" + +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import TYPE_CHECKING, Optional + +from marimo._utils.assert_never import assert_never + +if TYPE_CHECKING: + from collections.abc import Generator, Iterable, Iterator + +from loro import LoroDoc, LoroText +from msgspec.structs import replace as structs_replace + +from marimo._notebook._loro import create_doc, create_text + +import msgspec + +from marimo._ast.cell import CellConfig +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) +from marimo._types.ids import CellId_t + + +class NotebookCell(msgspec.Struct, frozen=True): + """Read-only snapshot of a cell, materialized from CellMeta + LoroText. + + This is never stored by the document — ``get_cell()`` and ``.cells`` + construct fresh instances each time. + """ + + id: CellId_t + code: str + name: str + config: CellConfig + + +class CellMeta: + """Mutable metadata for a cell. Owned by the document internally. + + Does *not* hold code — that lives in the ``LoroDoc``. + """ + + __slots__ = ("id", "name", "config") + + def __init__( + self, id: CellId_t, name: str, config: CellConfig + ) -> None: + self.id = id + self.name = name + self.config = config + + +class NotebookDocument: + """Ordered collection of cells with transactional updates. + + Cell text is owned by a ``LoroDoc`` (one ``LoroText`` per cell under + ``LoroMap("codes")``). Structural metadata (name, config, ordering) + is stored in ``_cell_metas``. 
+ + Usage:: + + from loro import LoroDoc + doc = NotebookDocument(LoroDoc()) + doc.add_cell(CellId_t("a"), code="x = 1", name="__", config=CellConfig()) + tx = Transaction( + changes=(SetCode(CellId_t("a"), "x = 2"),), source="kernel" + ) + applied = doc.apply(tx) + assert applied.version == 1 + assert doc.get_cell(CellId_t("a")).code == "x = 2" + """ + + def __init__(self, loro_doc: LoroDoc) -> None: + self._loro_doc = loro_doc + self._codes_map = loro_doc.get_map("codes") + self._cell_metas: list[CellMeta] = [] + self._version: int = 0 + + @classmethod + def from_cells(cls, cells: Iterable[NotebookCell]) -> NotebookDocument: + """Build a document from ``NotebookCell`` snapshots. + + Creates a fresh ``LoroDoc`` populated from the snapshot data. + Used at the kernel-process boundary where cells arrive as + serialized structs and need to be reconstructed into a live + document. + """ + doc = cls(create_doc()) + for c in cells: + doc.add_cell( + cell_id=c.id, code=c.code, name=c.name, config=c.config + ) + doc._loro_doc.commit() + return doc + + @property + def loro_doc(self) -> LoroDoc: + """The underlying Loro document owning cell text.""" + return self._loro_doc + + # ------------------------------------------------------------------ + # Bootstrap — populate the document from an external source + # ------------------------------------------------------------------ + + def add_cell( + self, + cell_id: CellId_t, + code: str, + name: str, + config: CellConfig, + ) -> None: + """Append a cell during initial document construction. + + This is *not* a transaction — it is used at session init to + populate the document from a ``CellManager``. 
+ """ + text = create_text() + text.insert(0, code) + self._codes_map.insert_container(cell_id, text) + self._cell_metas.append(CellMeta(id=cell_id, name=name, config=config)) + + # ------------------------------------------------------------------ + # Read-only accessors + # ------------------------------------------------------------------ + + @property + def cells(self) -> list[NotebookCell]: + """Materialize and return a snapshot list of all cells.""" + return [self._snapshot(m) for m in self._cell_metas] + + @property + def cell_ids(self) -> list[CellId_t]: + """Cell IDs in document order.""" + return [m.id for m in self._cell_metas] + + @property + def version(self) -> int: + return self._version + + def get_cell(self, cell_id: CellId_t) -> NotebookCell: + """Lookup by ID. Raises ``KeyError`` if not found.""" + return self._snapshot(self._find_meta(cell_id)) + + def get(self, cell_id: CellId_t) -> NotebookCell | None: + """Lookup by ID, returning ``None`` if not found.""" + for m in self._cell_metas: + if m.id == cell_id: + return self._snapshot(m) + return None + + def __contains__(self, cell_id: object) -> bool: + return any(m.id == cell_id for m in self._cell_metas) + + def __len__(self) -> int: + return len(self._cell_metas) + + def __iter__(self) -> Iterator[CellId_t]: + return (m.id for m in self._cell_metas) + + # ------------------------------------------------------------------ + # Transaction application + # ------------------------------------------------------------------ + + def apply(self, tx: Transaction) -> Transaction: + """Validate and apply *tx*, return it with ``version`` assigned. + + Raises ``ValueError`` for validation failures and ``KeyError`` + when a change references a non-existent cell. 
+ """ + if not tx.changes: + return structs_replace(tx, version=self._version) + + _validate(tx.changes, self._cell_metas) + + for change in tx.changes: + self._apply_change(change) + + self._version += 1 + return structs_replace(tx, version=self._version) + + def _apply_change(self, change: DocumentChange) -> None: + # TODO: refactor to use match/case (min Python is 3.10) once + # ruff target-version is bumped from py39. + if isinstance(change, CreateCell): + # Create LoroText in the shared doc + text = create_text() + text.insert(0, change.code) + self._codes_map.insert_container(change.cell_id, text) + + meta = CellMeta(id=change.cell_id, name=change.name, config=change.config) + if change.after is not None: + idx = self._find_index(change.after) + self._cell_metas.insert(idx + 1, meta) + elif change.before is not None: + idx = self._find_index(change.before) + self._cell_metas.insert(idx, meta) + else: + self._cell_metas.append(meta) + + elif isinstance(change, DeleteCell): + idx = self._find_index(change.cell_id) + del self._cell_metas[idx] + self._codes_map.delete(change.cell_id) + + elif isinstance(change, MoveCell): + idx = self._find_index(change.cell_id) + meta = self._cell_metas.pop(idx) + if change.after is not None: + target = self._find_index(change.after) + self._cell_metas.insert(target + 1, meta) + elif change.before is not None: + target = self._find_index(change.before) + self._cell_metas.insert(target, meta) + else: + raise ValueError("MoveCell requires 'before' or 'after'") + + elif isinstance(change, ReorderCells): + by_id = {m.id: m for m in self._cell_metas} + seen: set[CellId_t] = set() + reordered: list[CellMeta] = [] + for cid in change.cell_ids: + if cid in by_id and cid not in seen: + reordered.append(by_id[cid]) + seen.add(cid) + for m in self._cell_metas: + if m.id not in seen: + reordered.append(m) + self._cell_metas = reordered + + elif isinstance(change, SetCode): + # Verify cell exists + self._find_meta(change.cell_id) + # Full 
replace in Loro + text = self._get_loro_text(change.cell_id) + if text.len_unicode > 0: + text.delete(0, text.len_unicode) + if change.code: + text.insert(0, change.code) + + elif isinstance(change, SetName): + self._find_meta(change.cell_id).name = change.name + + elif isinstance(change, SetConfig): + meta = self._find_meta(change.cell_id) + meta.config = CellConfig( + column=change.column + if change.column is not None + else meta.config.column, + disabled=change.disabled + if change.disabled is not None + else meta.config.disabled, + hide_code=change.hide_code + if change.hide_code is not None + else meta.config.hide_code, + ) + else: + assert_never(change) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _snapshot(self, meta: CellMeta) -> NotebookCell: + """Build a read-only ``NotebookCell`` from metadata + Loro text.""" + code = self._get_loro_text(meta.id).to_string() + return NotebookCell( + id=meta.id, code=code, name=meta.name, config=meta.config + ) + + def _get_loro_text(self, cell_id: CellId_t) -> LoroText: + """Return the ``LoroText`` for *cell_id*.""" + val = self._codes_map.get(cell_id) + if val is None: + raise KeyError(f"No LoroText for cell {cell_id!r}") + container = val.container # type: ignore[union-attr,attr-defined] + assert isinstance(container, LoroText) + return container + + def _find_index(self, cell_id: CellId_t) -> int: + for i, m in enumerate(self._cell_metas): + if m.id == cell_id: + return i + raise KeyError(f"Cell {cell_id!r} not found in document") + + def _find_meta(self, cell_id: CellId_t) -> CellMeta: + for m in self._cell_metas: + if m.id == cell_id: + return m + raise KeyError(f"Cell {cell_id!r} not found in document") + + def __repr__(self) -> str: + lines = [f"NotebookDocument({len(self._cell_metas)} cells):"] + for i, m in enumerate(self._cell_metas): + code_preview = 
self._get_loro_text(m.id).to_string()[:40].replace( + "\n", "\\n" + ) + lines.append(f" {i}: {m.id} {code_preview!r}") + return "\n".join(lines) + + +# ------------------------------------------------------------------ +# Context variable +# ------------------------------------------------------------------ + +#: Document snapshot for the current scratchpad execution. Set by the +#: kernel before running code_mode so ``AsyncCodeModeContext`` can read +#: cell ordering, code, names, and configs without the kernel carrying +#: mutable document state. +_current_document: ContextVar[NotebookDocument | None] = ContextVar( + "_current_document", default=None +) + + +def get_current_document() -> NotebookDocument | None: + """Return the document for the current execution, if any.""" + return _current_document.get() + + +@contextmanager +def notebook_document_context( + doc: NotebookDocument | None, +) -> Generator[None, None, None]: + """Context manager for setting and resetting the current document.""" + token = _current_document.set(doc) + try: + yield + finally: + _current_document.reset(token) + + +# ------------------------------------------------------------------ +# Validation +# ------------------------------------------------------------------ + + +def _validate( + changes: tuple[DocumentChange, ...], metas: list[CellMeta] +) -> None: + """Check for conflicting changes. 
Raises ``ValueError``.""" + existing_ids = {m.id for m in metas} + created: set[CellId_t] = set() + deleted: set[CellId_t] = set() + updated: set[CellId_t] = set() + moved: set[CellId_t] = set() + + for change in changes: + if isinstance(change, CreateCell): + if change.cell_id in existing_ids or change.cell_id in created: + raise ValueError(f"Cell {change.cell_id!r} already exists") + if change.before is not None and change.after is not None: + raise ValueError( + "CreateCell cannot specify both 'before' and 'after'" + ) + created.add(change.cell_id) + + elif isinstance(change, DeleteCell): + if change.cell_id in deleted: + raise ValueError( + f"Cell {change.cell_id!r} is deleted more than once" + ) + if change.cell_id in updated: + raise ValueError( + f"Cannot delete cell {change.cell_id!r} that is also " + f"updated in the same transaction" + ) + if change.cell_id in moved: + raise ValueError( + f"Cannot delete cell {change.cell_id!r} that is also " + f"moved in the same transaction" + ) + deleted.add(change.cell_id) + + elif isinstance(change, MoveCell): + if change.cell_id in deleted: + raise ValueError( + f"Cannot move cell {change.cell_id!r} that is also " + f"deleted in the same transaction" + ) + if change.before is not None and change.after is not None: + raise ValueError( + "MoveCell cannot specify both 'before' and 'after'" + ) + if change.before is None and change.after is None: + raise ValueError("MoveCell requires 'before' or 'after'") + moved.add(change.cell_id) + + elif isinstance(change, ReorderCells): + pass # No conflicts — replaces full ordering + + elif isinstance(change, (SetCode, SetName, SetConfig)): + if change.cell_id in deleted: + raise ValueError( + f"Cannot update cell {change.cell_id!r} that is also " + f"deleted in the same transaction" + ) + updated.add(change.cell_id) + + else: + raise TypeError(f"Unknown change type: {type(change)!r}") diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 
4ea3efce220..44a917b14c4 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -2296,7 +2296,7 @@ async def handle_execute_scratchpad( request: ExecuteScratchpadCommand, ) -> None: doc = ( - NotebookDocument(list(request.notebook_cells)) + NotebookDocument.from_cells(request.notebook_cells) if request.notebook_cells is not None else None ) diff --git a/marimo/_session/session.py b/marimo/_session/session.py index 432a5a047fa..88d1a701aa9 100644 --- a/marimo/_session/session.py +++ b/marimo/_session/session.py @@ -15,12 +15,12 @@ from marimo import _loggers from marimo._cli.sandbox import SandboxMode from marimo._config.manager import MarimoConfigManager, ScriptConfigManager -from marimo._messaging.notebook.document import NotebookCell, NotebookDocument from marimo._messaging.notification import ( NotificationMessage, ) from marimo._messaging.serde import serialize_kernel_message from marimo._messaging.types import KernelMessage +from marimo._notebook.document import NotebookDocument from marimo._runtime import commands from marimo._runtime.commands import ( AppMetadata, @@ -76,6 +76,10 @@ def _document_from_cell_manager(cell_manager: CellManager) -> NotebookDocument: """Build a NotebookDocument from a CellManager's current state. + Creates a ``LoroDoc`` that owns all cell text as ``LoroText`` + containers. Structural metadata (name, config, ordering) is stored + in the document's internal ``CellMeta`` list. + TODO: CellManager and NotebookDocument track overlapping state (cell ordering, code, names, configs). Once the document model is wired into all consumers, we should reconcile these — either CellManager @@ -83,17 +87,18 @@ def _document_from_cell_manager(cell_manager: CellManager) -> NotebookDocument: different composition. For now, the document is populated from the cell manager at session startup and the two coexist. 
""" - return NotebookDocument( - [ - NotebookCell( - id=cd.cell_id, - code=cd.code, - name=cd.name, - config=cd.config, - ) - for cd in cell_manager.cell_data() - ] - ) + from marimo._notebook._loro import create_doc + + doc = NotebookDocument(create_doc()) + for cd in cell_manager.cell_data(): + doc.add_cell( + cell_id=cd.cell_id, + code=cd.code, + name=cd.name, + config=cd.config, + ) + doc.loro_doc.commit() + return doc class SessionImpl(Session): diff --git a/tests/_code_mode/test_cells_view.py b/tests/_code_mode/test_cells_view.py index 51bee8bde5b..b0440cc5b66 100644 --- a/tests/_code_mode/test_cells_view.py +++ b/tests/_code_mode/test_cells_view.py @@ -27,11 +27,9 @@ def cmd(cell_id: str, code: str) -> ExecuteCellCommand: @contextmanager def _ctx(k: Kernel) -> Generator[AsyncCodeModeContext, None, None]: """Build an AsyncCodeModeContext with a document snapshot from the kernel.""" - doc = NotebookDocument( - [ - NotebookCell(id=cid, code=cell.code, name="", config=cell.config) - for cid, cell in k.graph.cells.items() - ] + doc = NotebookDocument.from_cells( + NotebookCell(id=cid, code=cell.code, name="", config=cell.config) + for cid, cell in k.graph.cells.items() ) with notebook_document_context(doc): yield AsyncCodeModeContext(k) diff --git a/tests/_code_mode/test_context.py b/tests/_code_mode/test_context.py index 881722de094..72d205b68bb 100644 --- a/tests/_code_mode/test_context.py +++ b/tests/_code_mode/test_context.py @@ -29,11 +29,9 @@ @contextmanager def _ctx(k: Kernel) -> Generator[AsyncCodeModeContext, None, None]: """Build an AsyncCodeModeContext with a document snapshot from the kernel.""" - doc = NotebookDocument( - [ - NotebookCell(id=cid, code=cell.code, name="", config=cell.config) - for cid, cell in k.graph.cells.items() - ] + doc = NotebookDocument.from_cells( + NotebookCell(id=cid, code=cell.code, name="", config=cell.config) + for cid, cell in k.graph.cells.items() ) with notebook_document_context(doc): yield AsyncCodeModeContext(k) diff 
--git a/tests/_notebook/test_document.py b/tests/_notebook/test_document.py new file mode 100644 index 00000000000..a956df8f156 --- /dev/null +++ b/tests/_notebook/test_document.py @@ -0,0 +1,576 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +import pytest +from inline_snapshot import snapshot +from loro import LoroDoc + +from marimo._ast.cell import CellConfig +from marimo._notebook.document import NotebookDocument +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + ReorderCells, + SetCode, + SetConfig, + SetName, + Transaction, +) +from marimo._types.ids import CellId_t + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _doc(*names: str) -> NotebookDocument: + doc = NotebookDocument(LoroDoc()) + for n in names: + doc.add_cell( + cell_id=CellId_t(n), code="", name="__", config=CellConfig() + ) + doc._loro_doc.commit() + return doc + + +def _tx(*changes: DocumentChange, source: str = "test") -> Transaction: + return Transaction(changes=changes, source=source) + + +def _ids(doc: NotebookDocument) -> list[str]: + return [str(cid) for cid in doc.cell_ids] + + +# ------------------------------------------------------------------ +# CreateCell +# ------------------------------------------------------------------ + + +class TestCreateCell: + def test_append_at_end(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + assert _ids(doc) == snapshot(["a", "b", "new"]) + + def test_insert_after(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + after=CellId_t("a"), + ) + ) + ) + assert _ids(doc) == snapshot(["a", "new", "b", "c"]) + + def test_insert_before(self) 
-> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + before=CellId_t("b"), + ) + ) + ) + assert _ids(doc) == snapshot(["a", "new", "b", "c"]) + + def test_insert_into_empty(self) -> None: + doc = _doc() + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + assert _ids(doc) == snapshot(["new"]) + + def test_multiple(self) -> None: + doc = _doc("a") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("x"), + code="1", + name="__", + config=CellConfig(), + ), + CreateCell( + cell_id=CellId_t("y"), + code="2", + name="__", + config=CellConfig(), + ), + ) + ) + assert _ids(doc) == snapshot(["a", "x", "y"]) + + def test_after_pending(self) -> None: + """A create can reference a cell added earlier in the same tx.""" + doc = _doc("a") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("x"), + code="1", + name="__", + config=CellConfig(), + ), + CreateCell( + cell_id=CellId_t("y"), + code="2", + name="__", + config=CellConfig(), + after=CellId_t("x"), + ), + ) + ) + assert _ids(doc) == snapshot(["a", "x", "y"]) + + def test_duplicate_id_raises(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="already exists"): + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("a"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + + def test_stores_code_and_config(self) -> None: + doc = _doc() + cfg = CellConfig(hide_code=True, disabled=True) + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("a"), + code="import os", + name="imports", + config=cfg, + ) + ) + ) + cell = doc.get_cell(CellId_t("a")) + assert cell.code == "import os" + assert cell.name == "imports" + assert cell.config.hide_code is True + assert cell.config.disabled is True + + +# ------------------------------------------------------------------ +# DeleteCell +# 
------------------------------------------------------------------ + + +class TestDeleteCell: + def test_delete_single(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(DeleteCell(cell_id=CellId_t("b")))) + assert _ids(doc) == snapshot(["a", "c"]) + + def test_delete_multiple(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + DeleteCell(cell_id=CellId_t("c")), + ) + ) + assert _ids(doc) == snapshot(["b"]) + + def test_delete_all(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + DeleteCell(cell_id=CellId_t("b")), + ) + ) + assert _ids(doc) == snapshot([]) + + def test_delete_not_found(self) -> None: + doc = _doc("a") + with pytest.raises(KeyError): + doc.apply(_tx(DeleteCell(cell_id=CellId_t("missing")))) + + +# ------------------------------------------------------------------ +# MoveCell +# ------------------------------------------------------------------ + + +class TestMoveCell: + def test_move_after(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(MoveCell(cell_id=CellId_t("a"), after=CellId_t("c")))) + assert _ids(doc) == snapshot(["b", "c", "a"]) + + def test_move_before(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(MoveCell(cell_id=CellId_t("c"), before=CellId_t("a")))) + assert _ids(doc) == snapshot(["c", "a", "b"]) + + def test_no_anchor_raises(self) -> None: + doc = _doc("a", "b") + with pytest.raises(ValueError, match="before.*after"): + doc.apply(_tx(MoveCell(cell_id=CellId_t("a")))) + + +# ------------------------------------------------------------------ +# SetCode +# ------------------------------------------------------------------ + + +class TestSetCode: + def test_update_code(self) -> None: + doc = _doc("a", "b") + doc.apply(_tx(SetCode(cell_id=CellId_t("b"), code="new"))) + assert doc.get_cell(CellId_t("b")).code == "new" + assert doc.get_cell(CellId_t("a")).code == "" # unchanged + + def test_not_found(self) -> None: 
+ doc = _doc("a") + with pytest.raises(KeyError): + doc.apply(_tx(SetCode(cell_id=CellId_t("missing"), code="x"))) + + +# ------------------------------------------------------------------ +# SetName +# ------------------------------------------------------------------ + + +class TestSetName: + def test_update_name(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetName(cell_id=CellId_t("a"), name="my_cell"))) + assert doc.get_cell(CellId_t("a")).name == "my_cell" + + +# ------------------------------------------------------------------ +# SetConfig +# ------------------------------------------------------------------ + + +class TestSetConfig: + def test_partial_hide_code(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetConfig(cell_id=CellId_t("a"), hide_code=True))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.hide_code is True + assert cfg.disabled is False # unchanged default + + def test_partial_disabled(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetConfig(cell_id=CellId_t("a"), disabled=True))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.disabled is True + assert cfg.hide_code is False # unchanged default + + def test_multiple_fields(self) -> None: + doc = _doc("a") + doc.apply( + _tx( + SetConfig(cell_id=CellId_t("a"), hide_code=True, disabled=True) + ) + ) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg.hide_code is True + assert cfg.disabled is True + + def test_all_none_is_noop(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetConfig(cell_id=CellId_t("a")))) + cfg = doc.get_cell(CellId_t("a")).config + assert cfg == CellConfig() + + def test_preserves_existing(self) -> None: + """Setting one field preserves other non-default fields.""" + doc = NotebookDocument(LoroDoc()) + doc.add_cell( + cell_id=CellId_t("a"), + code="", + name="__", + config=CellConfig(hide_code=True, column=2), + ) + doc._loro_doc.commit() + doc.apply(_tx(SetConfig(cell_id=CellId_t("a"), disabled=True))) + cfg = 
doc.get_cell(CellId_t("a")).config + assert cfg.disabled is True + assert cfg.hide_code is True # preserved + assert cfg.column == 2 # preserved + + +# ------------------------------------------------------------------ +# Validation +# ------------------------------------------------------------------ + + +class TestValidation: + def test_delete_and_set_code_same_cell(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="delete.*update"): + doc.apply( + _tx( + SetCode(cell_id=CellId_t("a"), code="x"), + DeleteCell(cell_id=CellId_t("a")), + ) + ) + + def test_set_code_and_delete_same_cell(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="update.*delete"): + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + SetCode(cell_id=CellId_t("a"), code="x"), + ) + ) + + def test_delete_and_move_same_cell(self) -> None: + doc = _doc("a", "b") + with pytest.raises(ValueError, match="delete.*move"): + doc.apply( + _tx( + MoveCell(cell_id=CellId_t("a"), after=CellId_t("b")), + DeleteCell(cell_id=CellId_t("a")), + ) + ) + + def test_double_delete(self) -> None: + doc = _doc("a") + with pytest.raises(ValueError, match="deleted more than once"): + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("a")), + DeleteCell(cell_id=CellId_t("a")), + ) + ) + + def test_valid_mixed_changes(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ), + SetCode(cell_id=CellId_t("a"), code="y"), + DeleteCell(cell_id=CellId_t("b")), + ) + ) + assert _ids(doc) == snapshot(["a", "new"]) + + +# ------------------------------------------------------------------ +# Versioning +# ------------------------------------------------------------------ + + +class TestVersion: + def test_increments_on_apply(self) -> None: + doc = _doc("a") + assert doc.version == 0 + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.version == 1 + 
doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="y"))) + assert doc.version == 2 + + def test_stamped_on_returned_tx(self) -> None: + doc = _doc("a") + tx = _tx(SetCode(cell_id=CellId_t("a"), code="x")) + assert tx.version is None + applied = doc.apply(tx) + assert applied.version == 1 + + def test_empty_tx_no_increment(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.version == 1 + applied = doc.apply(_tx()) + assert doc.version == 1 + assert applied.version == 1 + + +# ------------------------------------------------------------------ +# Combined changes +# ------------------------------------------------------------------ + + +class TestCombined: + def test_delete_then_create(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + DeleteCell(cell_id=CellId_t("b")), + CreateCell( + cell_id=CellId_t("new"), + code="d", + name="__", + config=CellConfig(), + after=CellId_t("a"), + ), + ) + ) + assert _ids(doc) == snapshot(["a", "new", "c"]) + + def test_create_then_set_code(self) -> None: + doc = _doc("a") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="tmp", + name="__", + config=CellConfig(), + ), + SetCode(cell_id=CellId_t("new"), code="final"), + ) + ) + assert _ids(doc) == snapshot(["a", "new"]) + assert doc.get_cell(CellId_t("new")).code == "final" + + def test_create_then_move(self) -> None: + doc = _doc("a", "b") + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("new"), + code="x", + name="__", + config=CellConfig(), + ), + MoveCell(cell_id=CellId_t("new"), before=CellId_t("a")), + ) + ) + assert _ids(doc) == snapshot(["new", "a", "b"]) + + def test_source_preserved(self) -> None: + doc = _doc("a") + applied = doc.apply( + _tx( + SetCode(cell_id=CellId_t("a"), code="x"), + source="kernel", + ) + ) + assert applied.source == "kernel" + + +# ------------------------------------------------------------------ +# Initialization +# 
------------------------------------------------------------------ + + +class TestInit: + def test_from_cells(self) -> None: + doc = _doc("a", "b", "c") + assert _ids(doc) == ["a", "b", "c"] + + def test_empty(self) -> None: + doc = NotebookDocument(LoroDoc()) + assert _ids(doc) == [] + assert doc.version == 0 + + def test_get_cell(self) -> None: + doc = _doc("a", "b") + cell = doc.get_cell(CellId_t("b")) + assert cell.id == CellId_t("b") + + def test_get_cell_not_found(self) -> None: + doc = _doc("a") + with pytest.raises(KeyError): + doc.get_cell(CellId_t("missing")) + + def test_get_returns_none(self) -> None: + doc = _doc("a") + assert doc.get(CellId_t("missing")) is None + assert doc.get(CellId_t("a")) is not None + + def test_contains(self) -> None: + doc = _doc("a", "b") + assert CellId_t("a") in doc + assert CellId_t("missing") not in doc + + def test_len(self) -> None: + assert len(_doc()) == 0 + assert len(_doc("a", "b")) == 2 + + def test_iter(self) -> None: + doc = _doc("a", "b", "c") + assert list(doc) == [CellId_t("a"), CellId_t("b"), CellId_t("c")] + + def test_repr(self) -> None: + doc = _doc("a") + assert "NotebookDocument(1 cells)" in repr(doc) + + +# ------------------------------------------------------------------ +# ReorderCells +# ------------------------------------------------------------------ + + +class TestReorderCells: + def test_reorder(self) -> None: + doc = _doc("a", "b", "c") + doc.apply( + _tx( + ReorderCells( + cell_ids=(CellId_t("c"), CellId_t("a"), CellId_t("b")) + ) + ) + ) + assert _ids(doc) == snapshot(["c", "a", "b"]) + + def test_missing_ids_appended(self) -> None: + """Cells not in the reorder list are appended at the end.""" + doc = _doc("a", "b", "c") + doc.apply(_tx(ReorderCells(cell_ids=(CellId_t("c"), CellId_t("a"))))) + assert _ids(doc) == snapshot(["c", "a", "b"]) + + def test_unknown_ids_ignored(self) -> None: + """IDs not in the document are silently skipped.""" + doc = _doc("a", "b") + doc.apply( + _tx( + 
ReorderCells( + cell_ids=( + CellId_t("b"), + CellId_t("unknown"), + CellId_t("a"), + ) + ) + ) + ) + assert _ids(doc) == snapshot(["b", "a"]) + + def test_reorder_single(self) -> None: + doc = _doc("a", "b", "c") + doc.apply(_tx(ReorderCells(cell_ids=(CellId_t("b"),)))) + assert _ids(doc) == snapshot(["b", "a", "c"]) From 7c0c578bf1f3102d8542960717192b4a0edea2c9 Mon Sep 17 00:00:00 2001 From: Trevor Manz Date: Mon, 23 Mar 2026 23:17:33 -0400 Subject: [PATCH 3/6] Unify LoroDocManager with session's LoroDoc `LoroDocManager` previously created its own `LoroDoc` with duplicate cell data at RTC init time. Now that `NotebookDocument` owns the `LoroDoc`, the manager just registers the session's existing doc via `register_doc`, giving RTC clients and the document model a single shared instance. The cleanup timer is removed since the doc's lifetime is tied to the session. A `subscribe_local_update` hook broadcasts server-originated Loro mutations (from SetCode, file-watch) to connected RTC clients, and `apply()` batches all ops into one `doc.commit()` so clients receive one update per transaction. 
--- marimo/_notebook/document.py | 37 +++-- .../api/endpoints/ws/ws_kernel_ready.py | 20 ++- .../api/endpoints/ws/ws_rtc_handler.py | 2 +- marimo/_server/rtc/doc.py | 154 ++++++------------ 4 files changed, 86 insertions(+), 127 deletions(-) diff --git a/marimo/_notebook/document.py b/marimo/_notebook/document.py index 07b959638fe..1c2258dfc96 100644 --- a/marimo/_notebook/document.py +++ b/marimo/_notebook/document.py @@ -10,21 +10,19 @@ from contextlib import contextmanager from contextvars import ContextVar -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from marimo._utils.assert_never import assert_never if TYPE_CHECKING: from collections.abc import Generator, Iterable, Iterator +import msgspec from loro import LoroDoc, LoroText from msgspec.structs import replace as structs_replace -from marimo._notebook._loro import create_doc, create_text - -import msgspec - from marimo._ast.cell import CellConfig +from marimo._notebook._loro import create_doc, create_text from marimo._messaging.notebook.changes import ( CreateCell, DeleteCell, @@ -52,20 +50,15 @@ class NotebookCell(msgspec.Struct, frozen=True): config: CellConfig -class CellMeta: - """Mutable metadata for a cell. Owned by the document internally. +class CellMeta(msgspec.Struct): + """Mutable metadata for a cell. Owned by the document internally. Does *not* hold code — that lives in the ``LoroDoc``. 
""" - __slots__ = ("id", "name", "config") - - def __init__( - self, id: CellId_t, name: str, config: CellConfig - ) -> None: - self.id = id - self.name = name - self.config = config + id: CellId_t + name: str + config: CellConfig class NotebookDocument: @@ -78,8 +71,11 @@ class NotebookDocument: Usage:: from loro import LoroDoc + doc = NotebookDocument(LoroDoc()) - doc.add_cell(CellId_t("a"), code="x = 1", name="__", config=CellConfig()) + doc.add_cell( + CellId_t("a"), code="x = 1", name="__", config=CellConfig() + ) tx = Transaction( changes=(SetCode(CellId_t("a"), "x = 2"),), source="kernel" ) @@ -193,6 +189,11 @@ def apply(self, tx: Transaction) -> Transaction: for change in tx.changes: self._apply_change(change) + # Commit all Loro mutations from this transaction as a single + # batch. This triggers one ``subscribe_local_update`` callback + # so RTC clients receive one update per transaction. + self._loro_doc.commit() + self._version += 1 return structs_replace(tx, version=self._version) @@ -309,8 +310,8 @@ def _find_meta(self, cell_id: CellId_t) -> CellMeta: def __repr__(self) -> str: lines = [f"NotebookDocument({len(self._cell_metas)} cells):"] for i, m in enumerate(self._cell_metas): - code_preview = self._get_loro_text(m.id).to_string()[:40].replace( - "\n", "\\n" + code_preview = ( + self._get_loro_text(m.id).to_string()[:40].replace("\n", "\\n") ) lines.append(f" {i}: {m.id} {code_preview!r}") return "\n".join(lines) diff --git a/marimo/_server/api/endpoints/ws/ws_kernel_ready.py b/marimo/_server/api/endpoints/ws/ws_kernel_ready.py index 409c1c6af14..595925fb1eb 100644 --- a/marimo/_server/api/endpoints/ws/ws_kernel_ready.py +++ b/marimo/_server/api/endpoints/ws/ws_kernel_ready.py @@ -64,9 +64,9 @@ def build_kernel_ready( """ codes, names, configs, cell_ids = _extract_cell_data(session, manager) - # Initialize RTC if needed + # Register session's LoroDoc for RTC if needed if _should_init_rtc(rtc_enabled, mode): - _try_init_rtc_doc(cell_ids, codes, 
file_key, doc_manager) + _try_init_rtc_doc(session, file_key, doc_manager) return KernelReadyNotification( codes=codes, @@ -158,18 +158,20 @@ def _should_init_rtc(rtc_enabled: bool, mode: SessionMode) -> bool: def _try_init_rtc_doc( - cell_ids: tuple[CellId_t, ...], - codes: tuple[str, ...], + session: Session, file_key: MarimoFileKey, doc_manager: LoroDocManager, ) -> None: - """Try to initialize RTC document with cell data. + """Register the session's LoroDoc with the RTC doc manager. + + The session's ``NotebookDocument`` already owns a ``LoroDoc`` that + holds all cell text. This makes that same doc available for RTC + client connections so there is a single source of truth. Logs a warning if Loro is not available but does not fail. Args: - cell_ids: Cell IDs to initialize - codes: Cell codes to initialize + session: Current session (owns the LoroDoc via its document) file_key: File key for the document doc_manager: LoroDoc manager """ @@ -180,4 +182,6 @@ def _try_init_rtc_doc( "RTC: Loro is not installed, disabling real-time collaboration" ) else: - asyncio.create_task(doc_manager.create_doc(file_key, cell_ids, codes)) + asyncio.create_task( + doc_manager.register_doc(file_key, session.document.loro_doc) + ) diff --git a/marimo/_server/api/endpoints/ws/ws_rtc_handler.py b/marimo/_server/api/endpoints/ws/ws_rtc_handler.py index 62e242b8465..7b6dfd36720 100644 --- a/marimo/_server/api/endpoints/ws/ws_rtc_handler.py +++ b/marimo/_server/api/endpoints/ws/ws_rtc_handler.py @@ -49,7 +49,7 @@ async def handle(self) -> None: # Get or create the LoroDoc and add the client to it LOGGER.debug("RTC: getting document") update_queue: asyncio.Queue[bytes] = asyncio.Queue() - doc = await self.doc_manager.get_or_create_doc(self.file_key) + doc = await self.doc_manager.get_doc(self.file_key) self.doc_manager.add_client_to_doc(self.file_key, update_queue) # Send initial sync to client diff --git a/marimo/_server/rtc/doc.py b/marimo/_server/rtc/doc.py index 
7b42ba29335..0fd956068d0 100644 --- a/marimo/_server/rtc/doc.py +++ b/marimo/_server/rtc/doc.py @@ -6,7 +6,6 @@ from marimo import _loggers from marimo._server.file_router import MarimoFileKey -from marimo._types.ids import CellId_t if TYPE_CHECKING: from loro import LoroDoc @@ -21,86 +20,57 @@ def __init__(self) -> None: self.loro_docs_clients: dict[ MarimoFileKey, set[asyncio.Queue[bytes]] ] = {} - self.loro_docs_cleaners: dict[ - MarimoFileKey, Optional[asyncio.Task[None]] - ] = {} + # Hold subscription references to prevent garbage collection + self._subscriptions: dict[MarimoFileKey, object] = {} - async def _clean_loro_doc( - self, file_key: MarimoFileKey, timeout: float = 60 - ) -> None: - """Clean up a loro doc if no clients are connected.""" - try: - await asyncio.sleep(timeout) - async with self.loro_docs_lock: - if ( - file_key in self.loro_docs_clients - and len(self.loro_docs_clients[file_key]) == 0 - ): - LOGGER.debug( - f"RTC: Removing loro doc for file {file_key} as it has no clients" - ) - # Clean up the document - await self._do_remove_doc(file_key) - except asyncio.CancelledError: - # Task was cancelled due to client reconnection - LOGGER.debug( - f"RTC: clean_loro_doc task cancelled for file {file_key} - likely due to reconnection" - ) - pass - - async def create_doc( + async def register_doc( self, file_key: MarimoFileKey, - cell_ids: tuple[CellId_t, ...], - codes: tuple[str, ...], - ) -> LoroDoc: - """Create a new loro doc.""" - from loro import LoroDoc, LoroText - - assert len(cell_ids) == len(codes), ( - "cell_ids and codes must be the same length" - ) - + doc: LoroDoc, + ) -> None: + """Register an existing LoroDoc (owned by a session's NotebookDocument). + + The session creates the LoroDoc at init time and the document + model owns it for the lifetime of the session. This method + makes the same doc available for RTC client connections and + subscribes to local updates so that server-originated changes + (e.g. 
SetCode from kernel or file-watch) are broadcast to all + connected RTC clients. + """ async with self.loro_docs_lock: if file_key in self.loro_docs: - return self.loro_docs[file_key] - - LOGGER.debug(f"RTC: Initializing LoroDoc for file {file_key}") - doc = LoroDoc() # type: ignore[no-untyped-call] - self.loro_docs[file_key] = doc - - # Add all cell code to the doc - doc_codes = doc.get_map("codes") + LOGGER.debug( + f"RTC: LoroDoc already registered for file {file_key}" + ) + return + # Ensure the languages map exists for the frontend doc.get_map("languages") - for cell_id, code in zip(cell_ids, codes): - cell_text = LoroText() # type: ignore[no-untyped-call] - cell_text.insert(0, code) - doc_codes.insert_container(cell_id, cell_text) + LOGGER.debug(f"RTC: Registered LoroDoc for file {file_key}") + self.loro_docs[file_key] = doc - # We don't set the language here because it will be set - # when the client connects for the first time. - return doc + # Broadcast server-side Loro mutations to RTC clients. + # The callback fires synchronously on doc.commit() — we + # enqueue directly into client queues (non-blocking). + def _on_local_update(update: bytes) -> bool: + clients = self.loro_docs_clients.get(file_key, set()) + for client in clients: + client.put_nowait(update) + return True # keep subscription alive + + self._subscriptions[file_key] = doc.subscribe_local_update( + _on_local_update + ) - async def get_or_create_doc(self, file_key: MarimoFileKey) -> LoroDoc: - """Get or create a loro doc for a file key.""" - from loro import LoroDoc + async def get_doc(self, file_key: MarimoFileKey) -> LoroDoc: + """Get the LoroDoc registered for *file_key*. + Raises ``KeyError`` if no doc has been registered via + ``register_doc``. 
+ """ async with self.loro_docs_lock: - if file_key in self.loro_docs: - doc = self.loro_docs[file_key] - # Cancel existing cleaner task if it exists - cleaner = self.loro_docs_cleaners.get(file_key, None) - if cleaner is not None: - LOGGER.debug( - f"RTC: Cancelling existing cleaner for file {file_key}" - ) - cleaner.cancel() - self.loro_docs_cleaners[file_key] = None - else: - LOGGER.warning(f"RTC: Expected loro doc for file {file_key}") - doc = LoroDoc() # type: ignore[no-untyped-call] - self.loro_docs[file_key] = doc - return doc + if file_key not in self.loro_docs: + raise KeyError(f"No LoroDoc registered for file {file_key!r}") + return self.loro_docs[file_key] def add_client_to_doc( self, file_key: MarimoFileKey, update_queue: asyncio.Queue[bytes] @@ -129,39 +99,23 @@ async def remove_client( file_key: MarimoFileKey, update_queue: asyncio.Queue[bytes], ) -> None: - """Clean up a loro client and potentially the doc if no clients remain.""" - should_create_cleaner = False + """Remove an RTC client queue. + The LoroDoc itself is *not* cleaned up when clients disconnect — + it is owned by the session's ``NotebookDocument`` and lives for + the session's lifetime. Only the client tracking set is updated. 
+ """ async with self.loro_docs_lock: if file_key not in self.loro_docs_clients: return - - self.loro_docs_clients[file_key].remove(update_queue) - # If no clients are connected, set up a cleaner task - if len(self.loro_docs_clients[file_key]) == 0: - # Remove any existing cleaner - cleaner = self.loro_docs_cleaners.get(file_key, None) - if cleaner is not None: - cleaner.cancel() - self.loro_docs_cleaners[file_key] = None - should_create_cleaner = True - - # Create the cleaner task outside the lock to avoid deadlocks - if should_create_cleaner: - self.loro_docs_cleaners[file_key] = asyncio.create_task( - self._clean_loro_doc(file_key, 60.0) - ) - - async def _do_remove_doc(self, file_key: MarimoFileKey) -> None: - """Actual implementation of removing a doc, separate from remove_doc to avoid deadlocks.""" - if file_key in self.loro_docs: - del self.loro_docs[file_key] - if file_key in self.loro_docs_clients: - del self.loro_docs_clients[file_key] - if file_key in self.loro_docs_cleaners: - del self.loro_docs_cleaners[file_key] + self.loro_docs_clients[file_key].discard(update_queue) async def remove_doc(self, file_key: MarimoFileKey) -> None: - """Remove a loro doc and all associated clients""" + """Unregister a LoroDoc and all associated client queues. + + Called when the session closes. + """ async with self.loro_docs_lock: - await self._do_remove_doc(file_key) + self.loro_docs.pop(file_key, None) + self.loro_docs_clients.pop(file_key, None) + self._subscriptions.pop(file_key, None) From 5670c97df3cec8986532bedbbe998920cb98df85 Mon Sep 17 00:00:00 2001 From: Trevor Manz Date: Mon, 23 Mar 2026 23:25:13 -0400 Subject: [PATCH 4/6] Route cell text through Loro on the frontend Now that the backend LoroDoc owns cell text, the frontend no longer needs to send `set-code` ops through the document transaction API or apply them from server notifications. Code changes flow exclusively through the Loro WebSocket sync. 
`cellCodeEditing` skips Loro-originated CodeMirror changes to prevent them from round-tripping back through the reducer and middleware. --- frontend/src/core/codemirror/cells/extensions.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/frontend/src/core/codemirror/cells/extensions.ts b/frontend/src/core/codemirror/cells/extensions.ts index 159a4505ee4..ddbeccc60e5 100644 --- a/frontend/src/core/codemirror/cells/extensions.ts +++ b/frontend/src/core/codemirror/cells/extensions.ts @@ -317,6 +317,16 @@ function cellKeymaps({ function cellCodeEditing(hotkeys: HotkeyProvider): Extension[] { const onChangePlugin = EditorView.updateListener.of((update) => { if (update.docChanged) { + // Skip changes that came from Loro sync (RTC) — these are + // already reflected in the shared LoroDoc and don't need to + // round-trip through the transaction middleware. + const isLoroSync = update.transactions.some( + (tr) => tr.annotation(loroSyncAnnotation) != null, + ); + if (isLoroSync) { + return; + } + // Check if the doc update was a formatting change // e.g. changing from python to markdown const isFormattingChange = update.transactions.some((tr) => From f45ac4eb971f41ee83052bf6ea6003fc168027b6 Mon Sep 17 00:00:00 2001 From: Trevor Manz Date: Tue, 24 Mar 2026 00:03:37 -0400 Subject: [PATCH 5/6] Clean up Loro integration and fix stale RTC tests Centralizes loro type-stub workarounds in `_loro.py` by adding `unwrap_text` alongside the existing constructor wrappers, removing the inline `type: ignore` from `document.py`. Guards the RTC broadcast callback against `asyncio.QueueFull` for slow consumers. Updates the RTC test suite to match the new `register_doc`/`get_doc` API and removes tests for the deleted cleanup timer (the deadlock scenario they guarded against no longer exists since the LoroDoc now lives for the session's lifetime with no background cleanup task).
--- marimo/_notebook/_loro.py | 18 +- marimo/_notebook/document.py | 14 +- marimo/_server/rtc/doc.py | 11 +- tests/_server/api/endpoints/test_ws_rtc.py | 4 +- tests/_server/rtc/test_rtc_doc.py | 245 +++++++-------------- 5 files changed, 114 insertions(+), 178 deletions(-) diff --git a/marimo/_notebook/_loro.py b/marimo/_notebook/_loro.py index 646cb37e82d..698b1dc79f3 100644 --- a/marimo/_notebook/_loro.py +++ b/marimo/_notebook/_loro.py @@ -1,14 +1,15 @@ # Copyright 2026 Marimo. All rights reserved. -"""Typed wrappers for ``loro`` constructors. +"""Typed wrappers for ``loro`` APIs with incomplete stubs. -The ``loro`` stubs omit return types on ``__new__``, which triggers -mypy ``no-untyped-call``. These helpers provide correctly-typed -construction so the rest of the codebase stays clean. +The ``loro`` stubs omit return types on ``__new__`` and the +``ValueOrContainer`` union lacks a typed ``.container`` accessor. +These helpers isolate the ``type: ignore`` comments so the rest of +the codebase stays clean. 
""" from __future__ import annotations -from loro import LoroDoc, LoroText +from loro import LoroDoc, LoroText, ValueOrContainer def create_doc() -> LoroDoc: @@ -17,3 +18,10 @@ def create_doc() -> LoroDoc: def create_text() -> LoroText: return LoroText() # type: ignore[no-untyped-call] + + +def unwrap_text(val: ValueOrContainer) -> LoroText: + """Extract a ``LoroText`` from a ``ValueOrContainer``.""" + container = val.container # type: ignore[union-attr,attr-defined] + assert isinstance(container, LoroText) + return container diff --git a/marimo/_notebook/document.py b/marimo/_notebook/document.py index 1c2258dfc96..b87552e0e48 100644 --- a/marimo/_notebook/document.py +++ b/marimo/_notebook/document.py @@ -18,11 +18,10 @@ from collections.abc import Generator, Iterable, Iterator import msgspec -from loro import LoroDoc, LoroText from msgspec.structs import replace as structs_replace from marimo._ast.cell import CellConfig -from marimo._notebook._loro import create_doc, create_text +from marimo._notebook._loro import create_doc, create_text, unwrap_text from marimo._messaging.notebook.changes import ( CreateCell, DeleteCell, @@ -36,6 +35,11 @@ ) from marimo._types.ids import CellId_t +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + from loro import LoroDoc, LoroText + class NotebookCell(msgspec.Struct, frozen=True): """Read-only snapshot of a cell, materialized from CellMeta + LoroText. 
@@ -247,9 +251,7 @@ def _apply_change(self, change: DocumentChange) -> None: self._cell_metas = reordered elif isinstance(change, SetCode): - # Verify cell exists self._find_meta(change.cell_id) - # Full replace in Loro text = self._get_loro_text(change.cell_id) if text.len_unicode > 0: text.delete(0, text.len_unicode) @@ -291,9 +293,7 @@ def _get_loro_text(self, cell_id: CellId_t) -> LoroText: val = self._codes_map.get(cell_id) if val is None: raise KeyError(f"No LoroText for cell {cell_id!r}") - container = val.container # type: ignore[union-attr,attr-defined] - assert isinstance(container, LoroText) - return container + return unwrap_text(val) def _find_index(self, cell_id: CellId_t) -> int: for i, m in enumerate(self._cell_metas): diff --git a/marimo/_server/rtc/doc.py b/marimo/_server/rtc/doc.py index 0fd956068d0..d0ff65c4820 100644 --- a/marimo/_server/rtc/doc.py +++ b/marimo/_server/rtc/doc.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from marimo import _loggers from marimo._server.file_router import MarimoFileKey @@ -54,7 +54,12 @@ async def register_doc( def _on_local_update(update: bytes) -> bool: clients = self.loro_docs_clients.get(file_key, set()) for client in clients: - client.put_nowait(update) + try: + client.put_nowait(update) + except asyncio.QueueFull: + LOGGER.warning( + "RTC: client queue full, dropping update" + ) return True # keep subscription alive self._subscriptions[file_key] = doc.subscribe_local_update( @@ -85,7 +90,7 @@ async def broadcast_update( self, file_key: MarimoFileKey, message: bytes, - exclude_queue: Optional[asyncio.Queue[bytes]] = None, + exclude_queue: asyncio.Queue[bytes] | None = None, ) -> None: """Broadcast an update to all clients except the excluded queue.""" clients = self.loro_docs_clients[file_key] diff --git a/tests/_server/api/endpoints/test_ws_rtc.py b/tests/_server/api/endpoints/test_ws_rtc.py index 
6059a364153..b14ee8c7725 100644 --- a/tests/_server/api/endpoints/test_ws_rtc.py +++ b/tests/_server/api/endpoints/test_ws_rtc.py @@ -29,12 +29,12 @@ async def setup_loro_docs() -> AsyncGenerator[None, None]: # Clear any existing loro docs DOC_MANAGER.loro_docs.clear() DOC_MANAGER.loro_docs_clients.clear() - DOC_MANAGER.loro_docs_cleaners.clear() + DOC_MANAGER._subscriptions.clear() yield # Cleanup after test DOC_MANAGER.loro_docs.clear() DOC_MANAGER.loro_docs_clients.clear() - DOC_MANAGER.loro_docs_cleaners.clear() + DOC_MANAGER._subscriptions.clear() @contextmanager diff --git a/tests/_server/rtc/test_rtc_doc.py b/tests/_server/rtc/test_rtc_doc.py index 07be31cdf6d..afc35df2c0f 100644 --- a/tests/_server/rtc/test_rtc_doc.py +++ b/tests/_server/rtc/test_rtc_doc.py @@ -26,50 +26,68 @@ async def setup_doc_manager() -> AsyncGenerator[None, None]: # Clear any existing loro docs doc_manager.loro_docs.clear() doc_manager.loro_docs_clients.clear() - doc_manager.loro_docs_cleaners.clear() + doc_manager._subscriptions.clear() yield # Cleanup after test doc_manager.loro_docs.clear() doc_manager.loro_docs_clients.clear() - doc_manager.loro_docs_cleaners.clear() + doc_manager._subscriptions.clear() @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) -async def test_quick_reconnection(setup_doc_manager: None) -> None: - """Test that quick reconnection properly handles cleanup task cancellation""" +async def test_register_doc(setup_doc_manager: None) -> None: + """Test registering a LoroDoc with the manager.""" del setup_doc_manager - # Setup file_key = MarimoFileKey("test_file") - - # Create initial loro_doc doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc - # Setup client queue - update_queue = asyncio.Queue[bytes]() - doc_manager.loro_docs_clients[file_key] = {update_queue} + await doc_manager.register_doc(file_key, doc) - # Start cleanup task - cleanup_task = asyncio.create_task(doc_manager._clean_loro_doc(file_key)) + assert file_key 
in doc_manager.loro_docs + assert doc_manager.loro_docs[file_key] is doc + assert file_key in doc_manager._subscriptions - # Simulate quick reconnection by creating a new client before cleanup finishes - new_queue = asyncio.Queue[bytes]() - doc_manager.loro_docs_clients[file_key].add(new_queue) - # Cancel cleanup task - cleanup_task.cancel() - try: - await cleanup_task - except asyncio.CancelledError: - pass +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_register_doc_idempotent(setup_doc_manager: None) -> None: + """Registering the same file_key twice keeps the first doc.""" + del setup_doc_manager + file_key = MarimoFileKey("test_file") + doc1 = LoroDoc() + doc2 = LoroDoc() - # Verify state - assert len(doc_manager.loro_docs) == 1 - assert ( - len(doc_manager.loro_docs_clients[file_key]) == 2 - ) # Original client + reconnected client + await doc_manager.register_doc(file_key, doc1) + await doc_manager.register_doc(file_key, doc2) + + assert doc_manager.loro_docs[file_key] is doc1 + + +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_get_doc(setup_doc_manager: None) -> None: + """Test retrieving a registered doc.""" + del setup_doc_manager + file_key = MarimoFileKey("test_file") + doc = LoroDoc() + await doc_manager.register_doc(file_key, doc) + + result = await doc_manager.get_doc(file_key) + assert result is doc + + +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_get_doc_missing(setup_doc_manager: None) -> None: + """Getting an unregistered doc raises KeyError.""" + del setup_doc_manager + with pytest.raises(KeyError): + await doc_manager.get_doc(MarimoFileKey("missing")) @pytest.mark.skipif( @@ -81,14 +99,15 @@ async def test_two_users_sync(setup_doc_manager: None) -> None: file_key = MarimoFileKey("test_file") cell_id = str(CellId_t("test_cell")) # Convert CellId to string for loro - # First user 
connects + # Register the doc doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Setup client queues for both users queue1 = asyncio.Queue[bytes]() queue2 = asyncio.Queue[bytes]() - doc_manager.loro_docs_clients[file_key] = {queue1, queue2} + doc_manager.add_client_to_doc(file_key, queue1) + doc_manager.add_client_to_doc(file_key, queue2) # Get maps from doc doc_codes = doc.get_map("codes") @@ -123,21 +142,17 @@ async def test_two_users_sync(setup_doc_manager: None) -> None: @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) -async def test_concurrent_doc_creation(setup_doc_manager: None) -> None: - """Test concurrent doc creation doesn't cause issues""" +async def test_concurrent_registration(setup_doc_manager: None) -> None: + """Test concurrent doc registration doesn't cause issues""" del setup_doc_manager file_key = MarimoFileKey("test_file") - cell_ids = (CellId_t("cell1"), CellId_t("cell2")) - codes = ("print('hello')", "print('world')") + doc = LoroDoc() - # Create multiple tasks that try to create the same doc - tasks = [ - doc_manager.create_doc(file_key, cell_ids, codes) for _ in range(5) - ] - docs = await asyncio.gather(*tasks) + # Create multiple tasks that try to register the same doc + tasks = [doc_manager.register_doc(file_key, doc) for _ in range(5)] + await asyncio.gather(*tasks) - # All tasks should return the same doc instance - assert all(doc is docs[0] for doc in docs) + # Only one doc should be registered assert len(doc_manager.loro_docs) == 1 @@ -151,7 +166,7 @@ async def test_concurrent_client_operations( del setup_doc_manager file_key = MarimoFileKey("test_file") doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Create multiple client queues queues = [asyncio.Queue[bytes]() for _ in range(5)] @@ -170,40 +185,6 @@ async def client_operation(queue: asyncio.Queue[bytes]) -> None: assert 
len(doc_manager.loro_docs_clients[file_key]) == 0 -@pytest.mark.skipif( - "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" -) -async def test_cleanup_task_management(setup_doc_manager: None) -> None: - """Test cleanup task management and cancellation""" - del setup_doc_manager - file_key = MarimoFileKey("test_file") - doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc - - # Add and remove a client to trigger cleanup - queue = asyncio.Queue[bytes]() - doc_manager.add_client_to_doc(file_key, queue) - await doc_manager.remove_client(file_key, queue) - - # Verify cleanup task was created - assert file_key in doc_manager.loro_docs_cleaners - assert doc_manager.loro_docs_cleaners[file_key] is not None - - # Add a new client before cleanup finishes - new_queue = asyncio.Queue[bytes]() - doc_manager.add_client_to_doc(file_key, new_queue) - - # Wait for the task to be cancelled - await asyncio.sleep(0.1) - - # Verify cleanup task was cancelled and removed - # TODO: not sure why this is still here. 
- # assert doc_manager.loro_docs_cleaners[file_key] is None - - # Clean up - await doc_manager.remove_client(file_key, new_queue) - - @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) @@ -212,7 +193,7 @@ async def test_broadcast_update(setup_doc_manager: None) -> None: del setup_doc_manager file_key = MarimoFileKey("test_file") doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Create multiple client queues queues = [asyncio.Queue[bytes]() for _ in range(3)] @@ -232,6 +213,29 @@ async def test_broadcast_update(setup_doc_manager: None) -> None: assert await queue.get() == message +@pytest.mark.skipif( + "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" +) +async def test_local_update_broadcast(setup_doc_manager: None) -> None: + """Server-side Loro mutations are broadcast to RTC clients.""" + del setup_doc_manager + file_key = MarimoFileKey("test_file") + doc = LoroDoc() + await doc_manager.register_doc(file_key, doc) + + queue: asyncio.Queue[bytes] = asyncio.Queue() + doc_manager.add_client_to_doc(file_key, queue) + + # Mutate the doc server-side (simulates SetCode via NotebookDocument) + codes = doc.get_map("codes") + text = codes.get_or_create_container("cell1", LoroText()) + text.insert(0, "x = 1") + doc.commit() + + # The subscription should have enqueued the update + assert not queue.empty() + + @pytest.mark.skipif( "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" ) @@ -242,7 +246,7 @@ async def test_remove_nonexistent_doc(setup_doc_manager: None) -> None: await doc_manager.remove_doc(file_key) assert file_key not in doc_manager.loro_docs assert file_key not in doc_manager.loro_docs_clients - assert file_key not in doc_manager.loro_docs_cleaners + assert file_key not in doc_manager._subscriptions @pytest.mark.skipif( @@ -265,7 +269,7 @@ async def test_concurrent_doc_removal(setup_doc_manager: None) -> None: del setup_doc_manager file_key = 
MarimoFileKey("test_file") doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc + await doc_manager.register_doc(file_key, doc) # Create multiple tasks that try to remove the same doc tasks = [doc_manager.remove_doc(file_key) for _ in range(5)] @@ -274,85 +278,4 @@ async def test_concurrent_doc_removal(setup_doc_manager: None) -> None: # Verify doc was removed assert file_key not in doc_manager.loro_docs assert file_key not in doc_manager.loro_docs_clients - assert file_key not in doc_manager.loro_docs_cleaners - - -@pytest.mark.skipif( - "sys.version_info < (3, 11) or sys.version_info >= (3, 14)" -) -async def test_prevent_lock_deadlock(setup_doc_manager: None) -> None: - """Test that our deadlock prevention measures work correctly. - - This test simulates the scenario that could cause a deadlock: - 1. A client disconnects, starting the cleanup process - 2. Another operation acquires the lock before cleanup timer finishes - 3. Cleanup timer expires and tries to acquire the lock - - The fixed implementation should handle this without deadlocking. 
- """ - del setup_doc_manager - file_key = MarimoFileKey("test_file") - - # Create a doc and add a client - doc = LoroDoc() - doc_manager.loro_docs[file_key] = doc - queue = asyncio.Queue[bytes]() - doc_manager.add_client_to_doc(file_key, queue) - - # Set a very short cleanup timeout for testing - original_timeout = 60.0 - cleanup_timeout = 0.1 # 100ms - - # Create a barrier to coordinate tasks - barrier = asyncio.Barrier(2) - long_operation_done = asyncio.Event() - - # Task 1: Remove client, which will schedule cleanup with short timeout - async def remove_client_task() -> None: - await doc_manager.remove_client(file_key, queue) - # Wait at barrier to synchronize with the long operation - await barrier.wait() - # Wait for long operation to complete - await long_operation_done.wait() - - # Task 2: Simulate a long operation that holds the lock - async def long_lock_operation() -> None: - # Wait for remove_client to schedule the cleanup - await barrier.wait() - - # Acquire the lock and hold it for longer than the cleanup timeout - async with doc_manager.loro_docs_lock: - # Sleep while holding the lock (longer than cleanup timeout) - await asyncio.sleep(cleanup_timeout * 2) - - # Signal that we're done holding the lock - long_operation_done.set() - - # Modified test version of _clean_loro_doc with shorter timeout - original_clean_loro_doc = doc_manager._clean_loro_doc - - async def test_clean_loro_doc( - file_key: MarimoFileKey, timeout: float = original_timeout - ) -> None: - del timeout - # Override timeout with our test value - await original_clean_loro_doc(file_key, cleanup_timeout) - - # Override the method for this test - doc_manager._clean_loro_doc = test_clean_loro_doc - - try: - # Run both tasks simultaneously - task1 = asyncio.create_task(remove_client_task()) - task2 = asyncio.create_task(long_lock_operation()) - - # This should complete without deadlocking - await asyncio.gather(task1, task2) - - # Verify the doc was properly cleaned up - assert file_key 
not in doc_manager.loro_docs - assert file_key not in doc_manager.loro_docs_clients - assert file_key not in doc_manager.loro_docs_cleaners - finally: - # Restore the original method - doc_manager._clean_loro_doc = original_clean_loro_doc + assert file_key not in doc_manager._subscriptions From 138999bb48969ae83b1eb7f98d0be2cdaadb4a16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Mar 2026 17:15:33 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- marimo/_notebook/__init__.py | 2 +- marimo/_notebook/document.py | 6 ++++-- tests/_notebook/test_document.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/marimo/_notebook/__init__.py b/marimo/_notebook/__init__.py index 153a7bd09da..de4330833b5 100644 --- a/marimo/_notebook/__init__.py +++ b/marimo/_notebook/__init__.py @@ -1,7 +1,6 @@ # Copyright 2026 Marimo. All rights reserved. 
"""Notebook document model — canonical representation of notebook structure.""" -from marimo._notebook.document import CellMeta, NotebookCell, NotebookDocument from marimo._messaging.notebook.changes import ( CreateCell, DeleteCell, @@ -13,6 +12,7 @@ SetName, Transaction, ) +from marimo._notebook.document import CellMeta, NotebookCell, NotebookDocument __all__ = [ "CellMeta", diff --git a/marimo/_notebook/document.py b/marimo/_notebook/document.py index b87552e0e48..d5430c76c2b 100644 --- a/marimo/_notebook/document.py +++ b/marimo/_notebook/document.py @@ -21,7 +21,6 @@ from msgspec.structs import replace as structs_replace from marimo._ast.cell import CellConfig -from marimo._notebook._loro import create_doc, create_text, unwrap_text from marimo._messaging.notebook.changes import ( CreateCell, DeleteCell, @@ -33,6 +32,7 @@ SetName, Transaction, ) +from marimo._notebook._loro import create_doc, create_text, unwrap_text from marimo._types.ids import CellId_t if TYPE_CHECKING: @@ -210,7 +210,9 @@ def _apply_change(self, change: DocumentChange) -> None: text.insert(0, change.code) self._codes_map.insert_container(change.cell_id, text) - meta = CellMeta(id=change.cell_id, name=change.name, config=change.config) + meta = CellMeta( + id=change.cell_id, name=change.name, config=change.config + ) if change.after is not None: idx = self._find_index(change.after) self._cell_metas.insert(idx + 1, meta) diff --git a/tests/_notebook/test_document.py b/tests/_notebook/test_document.py index a956df8f156..2a6e061f2dc 100644 --- a/tests/_notebook/test_document.py +++ b/tests/_notebook/test_document.py @@ -6,7 +6,6 @@ from loro import LoroDoc from marimo._ast.cell import CellConfig -from marimo._notebook.document import NotebookDocument from marimo._messaging.notebook.changes import ( CreateCell, DeleteCell, @@ -18,6 +17,7 @@ SetName, Transaction, ) +from marimo._notebook.document import NotebookDocument from marimo._types.ids import CellId_t # 
------------------------------------------------------------------