Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import argparse
import asyncio
import logging
import os
import re
import sys

Expand Down Expand Up @@ -294,11 +293,12 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data

# Load existing chain from disk, or start fresh
chain = None
if datadir and os.path.exists(os.path.join(datadir, "data.json")):
if datadir:
try:
from minichain.persistence import load
chain = load(datadir)
logger.info("Restored chain from '%s'", datadir)
from minichain.persistence import load, persistence_exists
if persistence_exists(datadir):
chain = load(datadir)
logger.info("Restored chain from '%s'", datadir)
except FileNotFoundError as e:
logger.warning("Could not load saved chain: %s — starting fresh", e)
except ValueError as e:
Expand Down
263 changes: 177 additions & 86 deletions minichain/persistence.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,56 @@
"""
Chain persistence: save and load the blockchain and state to/from JSON.
Chain persistence: save and load the blockchain and state to/from SQLite.

Design:
- blockchain.json holds the full list of serialised blocks
- state.json holds the accounts dict (includes off-chain credits)
- data.db holds the full chain snapshot, account state, and small metadata.
- legacy data.json snapshots can still be loaded for backward compatibility.

Both files are written atomically (temp → rename) to prevent corruption
on crash. On load, chain integrity is verified before the data is trusted.

Usage:
The public API intentionally stays the same:
from minichain.persistence import save, load

save(blockchain, path="data/")
blockchain = load(path="data/")
"""

from __future__ import annotations

import json
import os
import tempfile
import logging
import copy
import os
import sqlite3
from typing import Any

from .block import Block
from .chain import Blockchain, validate_block_link_and_hash

logger = logging.getLogger(__name__)

_DATA_FILE = "data.json"
_DB_FILE = "data.db"
_LEGACY_DATA_FILE = "data.json"


def persistence_exists(path: str = ".") -> bool:
    """Return True if a SQLite or legacy JSON snapshot exists inside *path*."""
    # Either the current SQLite snapshot or a pre-migration JSON file counts.
    snapshot_names = (_DB_FILE, _LEGACY_DATA_FILE)
    return any(
        os.path.exists(os.path.join(path, name)) for name in snapshot_names
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def save(blockchain: Blockchain, path: str = ".") -> None:
"""
Persist the blockchain and account state to a JSON file inside *path*.

Uses atomic write (write-to-temp → rename) with fsync so a crash mid-save
never corrupts the existing file. Chain and state are saved together to
prevent torn snapshots.
"""
def save(blockchain: Blockchain, path: str = ".") -> None:
"""Persist the blockchain and account state to SQLite inside *path*."""
os.makedirs(path, exist_ok=True)
db_path = os.path.join(path, _DB_FILE)

with blockchain._lock: # Thread-safe: hold lock while serialising
with blockchain._lock:
chain_data = [block.to_dict() for block in blockchain.chain]
state_data = copy.deepcopy(blockchain.state.accounts)
state_data = json.loads(json.dumps(blockchain.state.accounts))

snapshot = {
"chain": chain_data,
"state": state_data
}

_atomic_write_json(os.path.join(path, _DATA_FILE), snapshot)
_save_snapshot_to_sqlite(db_path, {"chain": chain_data, "state": state_data})
Comment on lines +49 to +53
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Consider using copy.deepcopy() instead of JSON round-trip.

The json.loads(json.dumps(...)) pattern works but is less efficient than copy.deepcopy() for in-memory deep copying. However, if the intent is to validate JSON serializability at save time, this is acceptable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 49 - 53, The current snapshot uses
json.loads(json.dumps(blockchain.state.accounts)) to deep-copy state which is
less efficient; replace the JSON round-trip with copy.deepcopy to produce an
in-memory deep copy of blockchain.state.accounts while holding blockchain._lock
before calling _save_snapshot_to_sqlite(db_path, {"chain": chain_data, "state":
state_data}); update the import to include copy and use
copy.deepcopy(blockchain.state.accounts) and leave chain_data creation via
[block.to_dict() for block in blockchain.chain] as-is (unless you also want to
validate JSON serializability, in which case keep the round-trip).


logger.info(
"Saved %d blocks and %d accounts to '%s'",
Expand All @@ -63,43 +61,48 @@ def save(blockchain: Blockchain, path: str = ".") -> None:


def load(path: str = ".") -> Blockchain:
"""
Restore a Blockchain from the JSON file inside *path*.

Steps:
1. Load and deserialise blocks from data.json
2. Verify chain integrity (genesis, linkage, hashes)
3. Load account state
"""Restore a Blockchain from SQLite inside *path* (with legacy JSON fallback)."""
db_path = os.path.join(path, _DB_FILE)
legacy_path = os.path.join(path, _LEGACY_DATA_FILE)

Raises:
FileNotFoundError: if data.json is missing.
ValueError: if data is invalid or integrity checks fail.
"""
data_path = os.path.join(path, _DATA_FILE)
snapshot = _read_json(data_path)
if os.path.exists(db_path):
snapshot = _load_snapshot_from_sqlite(db_path)
elif os.path.exists(legacy_path):
snapshot = _read_legacy_json(legacy_path)
else:
raise FileNotFoundError(f"Persistence file not found in '{path}'")

if not isinstance(snapshot, dict):
raise ValueError(f"Invalid snapshot data in '{data_path}'")
raise ValueError(f"Invalid snapshot data in '{path}'")

raw_blocks = snapshot.get("chain")
raw_accounts = snapshot.get("state")

if not isinstance(raw_blocks, list) or not raw_blocks:
raise ValueError(f"Invalid or empty chain data in '{data_path}'")
raise ValueError(f"Invalid or empty chain data in '{path}'")
if not isinstance(raw_accounts, dict):
raise ValueError(f"Invalid accounts data in '{data_path}'")
raise ValueError(f"Invalid accounts data in '{path}'")

blocks = [_deserialize_block(b) for b in raw_blocks]
blocks = []
for raw_block in raw_blocks:
if not isinstance(raw_block, dict):
raise ValueError(f"Invalid chain data in '{path}'")
try:
blocks.append(_deserialize_block(raw_block))
except (KeyError, TypeError, ValueError) as exc:
raise ValueError(f"Invalid chain data in '{path}'") from exc

# --- Integrity verification ---
_verify_chain_integrity(blocks)
normalized_accounts = {}
for address, account in raw_accounts.items():
if not isinstance(address, str) or not isinstance(account, dict):
raise ValueError(f"Invalid accounts data in '{path}'")
normalized_accounts[address] = account

Comment on lines +95 to 100
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate the full account schema before accepting it.

{} or {"storage": []} passes the current dict check, but State assumes balance/nonce are integers, code is str | None, and storage is a dict. That lets a corrupted snapshot boot and fail later on the first account read.

🛡️ Suggested validation
     normalized_accounts = {}
     for address, account in raw_accounts.items():
         if not isinstance(address, str) or not isinstance(account, dict):
             raise ValueError(f"Invalid accounts data in '{path}'")
+        missing = {"balance", "nonce", "code", "storage"} - account.keys()
+        if missing:
+            raise ValueError(f"Invalid accounts data in '{path}'")
+        if (
+            type(account["balance"]) is not int
+            or type(account["nonce"]) is not int
+            or (account["code"] is not None and not isinstance(account["code"], str))
+            or not isinstance(account["storage"], dict)
+        ):
+            raise ValueError(f"Invalid accounts data in '{path}'")
         normalized_accounts[address] = account
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 98-98: Prefer TypeError exception for invalid type

(TRY004)


[warning] 98-98: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 95 - 100, The loop that builds
normalized_accounts from raw_accounts must validate the full account schema
before accepting entries: in the block that iterates raw_accounts (where
normalized_accounts and path are used), ensure each account dict contains
required keys "balance" (int), "nonce" (int), "code" (str or None), and
"storage" (dict), and that no extra malformed types are present; implement or
call a small helper like validate_account_schema(account, path, address) that
raises ValueError including path/address on any mismatch, and only then assign
normalized_accounts[address] = account so corrupted snapshots are rejected early
and State (which expects balance/nonce/code/storage types) never receives
invalid data.

# --- Rebuild blockchain properly (no __new__ hack) ---
blockchain = Blockchain() # creates genesis + fresh state
blockchain.chain = blocks # replace with loaded chain
_verify_chain_integrity(blocks)

# Restore state
blockchain.state.accounts = raw_accounts
blockchain = Blockchain()
blockchain.chain = blocks
blockchain.state.accounts = normalized_accounts

logger.info(
"Loaded %d blocks and %d accounts from '%s'",
Expand All @@ -114,14 +117,13 @@ def load(path: str = ".") -> Blockchain:
# Integrity verification
# ---------------------------------------------------------------------------

def _verify_chain_integrity(blocks: list) -> None:

def _verify_chain_integrity(blocks: list[Block]) -> None:
"""Verify genesis, hash linkage, and block hashes."""
# Check genesis
genesis = blocks[0]
if genesis.index != 0 or genesis.hash != "0" * 64:
raise ValueError("Invalid genesis block")

# Check linkage and hashes for every subsequent block
for i in range(1, len(blocks)):
block = blocks[i]
prev = blocks[i - 1]
Expand All @@ -132,46 +134,135 @@ def _verify_chain_integrity(blocks: list) -> None:


# ---------------------------------------------------------------------------
# Helpers
# SQLite helpers
# ---------------------------------------------------------------------------

def _atomic_write_json(filepath: str, data) -> None:
"""Write JSON atomically with fsync for durability."""
dir_name = os.path.dirname(filepath) or "."
fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")

def _connect(db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
return conn


def _initialize_schema(conn: sqlite3.Connection) -> None:
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS blocks (
height INTEGER PRIMARY KEY,
block_json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS accounts (
address TEXT PRIMARY KEY,
account_json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS metadata (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
)


def _require_schema(conn: sqlite3.Connection) -> None:
required = {"blocks", "accounts", "metadata"}
rows = conn.execute(
"SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
existing = {row["name"] for row in rows}
if not required.issubset(existing):
raise ValueError("Missing persistence tables")


def _save_snapshot_to_sqlite(db_path: str, snapshot: dict[str, Any]) -> None:
    """Write the full chain/state snapshot into the SQLite file at *db_path*.

    The three tables are rewritten inside a single transaction so a crash
    mid-save never leaves a torn snapshot on disk.
    """
    conn = _connect(db_path)
    try:
        _initialize_schema(conn)

        # Serialise everything up front; sort_keys keeps payloads canonical.
        block_rows = [
            (int(block["index"]), json.dumps(block, sort_keys=True))
            for block in snapshot["chain"]
        ]
        account_rows = [
            (address, json.dumps(account, sort_keys=True))
            for address, account in sorted(snapshot["state"].items())
        ]

        # `with conn` wraps the delete+insert sequence in one transaction.
        with conn:
            conn.execute("DELETE FROM blocks")
            conn.execute("DELETE FROM accounts")
            conn.execute("DELETE FROM metadata")
            conn.executemany(
                "INSERT INTO blocks (height, block_json) VALUES (?, ?)",
                block_rows,
            )
            conn.executemany(
                "INSERT INTO accounts (address, account_json) VALUES (?, ?)",
                account_rows,
            )
            conn.execute(
                "INSERT INTO metadata (key, value) VALUES (?, ?)",
                ("chain_length", str(len(snapshot["chain"]))),
            )
    finally:
        conn.close()


def _load_snapshot_from_sqlite(db_path: str) -> dict[str, Any]:
    """Read the chain/state snapshot stored in the SQLite file at *db_path*.

    Raises ValueError on any sign of corruption: unreadable database,
    missing tables, absent/garbled chain-length metadata, a block count
    that disagrees with the metadata, or undecodable JSON payloads.
    """
    corrupt_msg = f"Invalid SQLite persistence data in '{db_path}'"

    try:
        conn = _connect(db_path)
    except sqlite3.DatabaseError as exc:
        raise ValueError(corrupt_msg) from exc

    try:
        _require_schema(conn)
        block_rows = conn.execute(
            "SELECT block_json FROM blocks ORDER BY height ASC"
        ).fetchall()
        account_rows = conn.execute(
            "SELECT address, account_json FROM accounts ORDER BY address ASC"
        ).fetchall()
        length_row = conn.execute(
            "SELECT value FROM metadata WHERE key = ?",
            ("chain_length",),
        ).fetchone()
    except sqlite3.DatabaseError as exc:
        raise ValueError(corrupt_msg) from exc
    finally:
        conn.close()

    # The stored chain length must exist, parse as int, and match the rows.
    if length_row is None:
        raise ValueError(corrupt_msg)
    try:
        declared_length = int(length_row["value"])
    except (TypeError, ValueError) as exc:
        raise ValueError(corrupt_msg) from exc
    if declared_length != len(block_rows):
        raise ValueError(corrupt_msg)

    try:
        chain = [json.loads(row["block_json"]) for row in block_rows]
        state = {
            row["address"]: json.loads(row["account_json"])
            for row in account_rows
        }
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid persisted JSON payload in '{db_path}'") from exc

    return {"chain": chain, "state": state}


# ---------------------------------------------------------------------------
# Legacy JSON helpers
# ---------------------------------------------------------------------------


def _read_json(filepath: str):
if not os.path.exists(filepath):
raise FileNotFoundError(f"Persistence file not found: '{filepath}'")
def _read_legacy_json(filepath: str) -> dict[str, Any]:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
Comment on lines +257 to 259
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Minor: Remove redundant mode argument.

The "r" mode is the default for open(). This is a nitpick from static analysis.

♻️ Suggested fix
 def _read_legacy_json(filepath: str) -> dict[str, Any]:
-    with open(filepath, "r", encoding="utf-8") as f:
+    with open(filepath, encoding="utf-8") as f:
         return json.load(f)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _read_legacy_json(filepath: str) -> dict[str, Any]:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
def _read_legacy_json(filepath: str) -> dict[str, Any]:
with open(filepath, encoding="utf-8") as f:
return json.load(f)
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 258-258: Unnecessary mode argument

Remove mode argument

(UP015)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 257 - 259, The open call in
_read_legacy_json unnecessarily passes the default mode "r"; remove the explicit
"r" argument so the function uses open(filepath, encoding="utf-8") when reading
JSON, keeping json.load(f) unchanged.



def _deserialize_block(data: dict) -> Block:
"""Reconstruct a Block (including its transactions) from a plain dict."""
# ---------------------------------------------------------------------------
# Block deserialisation
# ---------------------------------------------------------------------------


def _deserialize_block(data: dict[str, Any]) -> Block:
    """Reconstruct a Block from a plain dict via ``Block.from_dict``."""
    return Block.from_dict(data)
Loading
Loading