From 54f3fc99ef0495f22c570704a17d14ba0dbbb12d Mon Sep 17 00:00:00 2001 From: Arunabha Date: Sat, 28 Mar 2026 18:01:01 +0530 Subject: [PATCH 1/5] feat: add sqlite persistence backend --- main.py | 10 +- minichain/persistence.py | 224 +++++++++++++++++++++++--------------- tests/test_persistence.py | 117 +++++++++----------- 3 files changed, 196 insertions(+), 155 deletions(-) diff --git a/main.py b/main.py index e1edc51..4ff3d9f 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,6 @@ import argparse import asyncio import logging -import os import re import sys @@ -294,11 +293,12 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data # Load existing chain from disk, or start fresh chain = None - if datadir and os.path.exists(os.path.join(datadir, "data.json")): + if datadir: try: - from minichain.persistence import load - chain = load(datadir) - logger.info("Restored chain from '%s'", datadir) + from minichain.persistence import load, persistence_exists + if persistence_exists(datadir): + chain = load(datadir) + logger.info("Restored chain from '%s'", datadir) except FileNotFoundError as e: logger.warning("Could not load saved chain: %s — starting fresh", e) except ValueError as e: diff --git a/minichain/persistence.py b/minichain/persistence.py index b49f307..8a02aa5 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -1,58 +1,56 @@ """ -Chain persistence: save and load the blockchain and state to/from JSON. +Chain persistence: save and load the blockchain and state to/from SQLite. Design: - - blockchain.json holds the full list of serialised blocks - - state.json holds the accounts dict (includes off-chain credits) + - data.db holds the full chain snapshot, account state, and small metadata. + - legacy data.json snapshots can still be loaded for backward compatibility. -Both files are written atomically (temp → rename) to prevent corruption -on crash. On load, chain integrity is verified before the data is trusted. - -Usage: +The public API intentionally stays the same: from minichain.persistence import save, load save(blockchain, path="data/") blockchain = load(path="data/") """ +from __future__ import annotations + import json -import os -import tempfile import logging -import copy +import os +import sqlite3 +from typing import Any from .block import Block from .chain import Blockchain, validate_block_link_and_hash logger = logging.getLogger(__name__) -_DATA_FILE = "data.json" +_DB_FILE = "data.db" +_LEGACY_DATA_FILE = "data.json" + + +def persistence_exists(path: str = ".") -> bool: + """Return True if a SQLite or legacy JSON snapshot exists inside *path*.""" + return os.path.exists(os.path.join(path, _DB_FILE)) or os.path.exists( + os.path.join(path, _LEGACY_DATA_FILE) + ) # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def save(blockchain: Blockchain, path: str = ".") -> None: - """ - Persist the blockchain and account state to a JSON file inside *path*. - Uses atomic write (write-to-temp → rename) with fsync so a crash mid-save - never corrupts the existing file. Chain and state are saved together to - prevent torn snapshots. - """ +def save(blockchain: Blockchain, path: str = ".") -> None: + """Persist the blockchain and account state to SQLite inside *path*.""" os.makedirs(path, exist_ok=True) + db_path = os.path.join(path, _DB_FILE) - with blockchain._lock: # Thread-safe: hold lock while serialising + with blockchain._lock: chain_data = [block.to_dict() for block in blockchain.chain] - state_data = copy.deepcopy(blockchain.state.accounts) - - snapshot = { - "chain": chain_data, - "state": state_data - } + state_data = json.loads(json.dumps(blockchain.state.accounts)) - _atomic_write_json(os.path.join(path, _DATA_FILE), snapshot) + _save_snapshot_to_sqlite(db_path, {"chain": chain_data, "state": state_data}) logger.info( "Saved %d blocks and %d accounts to '%s'", @@ -63,42 +61,33 @@ def save(blockchain: Blockchain, path: str = ".") -> None: def load(path: str = ".") -> Blockchain: - """ - Restore a Blockchain from the JSON file inside *path*. + """Restore a Blockchain from SQLite inside *path* (with legacy JSON fallback).""" + db_path = os.path.join(path, _DB_FILE) + legacy_path = os.path.join(path, _LEGACY_DATA_FILE) - Steps: - 1. Load and deserialise blocks from data.json - 2. Verify chain integrity (genesis, linkage, hashes) - 3. Load account state - - Raises: - FileNotFoundError: if data.json is missing. - ValueError: if data is invalid or integrity checks fail. - """ - data_path = os.path.join(path, _DATA_FILE) - snapshot = _read_json(data_path) + if os.path.exists(db_path): + snapshot = _load_snapshot_from_sqlite(db_path) + elif os.path.exists(legacy_path): + snapshot = _read_legacy_json(legacy_path) + else: + raise FileNotFoundError(f"Persistence file not found in '{path}'") if not isinstance(snapshot, dict): - raise ValueError(f"Invalid snapshot data in '{data_path}'") + raise ValueError(f"Invalid snapshot data in '{path}'") raw_blocks = snapshot.get("chain") raw_accounts = snapshot.get("state") if not isinstance(raw_blocks, list) or not raw_blocks: - raise ValueError(f"Invalid or empty chain data in '{data_path}'") + raise ValueError(f"Invalid or empty chain data in '{path}'") if not isinstance(raw_accounts, dict): - raise ValueError(f"Invalid accounts data in '{data_path}'") + raise ValueError(f"Invalid accounts data in '{path}'") blocks = [_deserialize_block(b) for b in raw_blocks] - - # --- Integrity verification --- _verify_chain_integrity(blocks) - # --- Rebuild blockchain properly (no __new__ hack) --- - blockchain = Blockchain() # creates genesis + fresh state - blockchain.chain = blocks # replace with loaded chain - - # Restore state + blockchain = Blockchain() + blockchain.chain = blocks blockchain.state.accounts = raw_accounts logger.info( @@ -114,14 +103,13 @@ def load(path: str = ".") -> Blockchain: # Integrity verification # --------------------------------------------------------------------------- -def _verify_chain_integrity(blocks: list) -> None: + +def _verify_chain_integrity(blocks: list[Block]) -> None: """Verify genesis, hash linkage, and block hashes.""" - # Check genesis genesis = blocks[0] if genesis.index != 0 or genesis.hash != "0" * 64: raise ValueError("Invalid genesis block") - # Check linkage and hashes for every subsequent block for i in range(1, len(blocks)): block = blocks[i] prev = blocks[i - 1] @@ -132,46 +120,112 @@ def _verify_chain_integrity(blocks: list) -> None: # --------------------------------------------------------------------------- -# Helpers +# SQLite helpers # --------------------------------------------------------------------------- -def _atomic_write_json(filepath: str, data) -> None: - """Write JSON atomically with fsync for durability.""" - dir_name = os.path.dirname(filepath) or "." - fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp") + +def _connect(db_path: str) -> sqlite3.Connection: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + return conn + + +def _initialize_schema(conn: sqlite3.Connection) -> None: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS blocks ( + height INTEGER PRIMARY KEY, + block_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS accounts ( + address TEXT PRIMARY KEY, + account_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """ + ) + + +def _save_snapshot_to_sqlite(db_path: str, snapshot: dict[str, Any]) -> None: + conn = _connect(db_path) try: - with os.fdopen(fd, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) - f.flush() - os.fsync(f.fileno()) # Ensure data is on disk - os.replace(tmp_path, filepath) # Atomic rename - - # Attempt to fsync the directory so the rename is durable - if hasattr(os, "O_DIRECTORY"): - try: - dir_fd = os.open(dir_name, os.O_RDONLY | os.O_DIRECTORY) - try: - os.fsync(dir_fd) - finally: - os.close(dir_fd) - except OSError: - pass # Directory fsync not supported on all platforms - - except BaseException: - try: - os.unlink(tmp_path) - except OSError: - pass - raise + _initialize_schema(conn) + with conn: + conn.execute("DELETE FROM blocks") + conn.execute("DELETE FROM accounts") + conn.execute("DELETE FROM metadata") + + for block in snapshot["chain"]: + conn.execute( + "INSERT INTO blocks (height, block_json) VALUES (?, ?)", + (int(block["index"]), json.dumps(block, sort_keys=True)), + ) + + for address, account in sorted(snapshot["state"].items()): + conn.execute( + "INSERT INTO accounts (address, account_json) VALUES (?, ?)", + (address, json.dumps(account, sort_keys=True)), + ) + + conn.execute( + "INSERT INTO metadata (key, value) VALUES (?, ?)", + ("chain_length", str(len(snapshot["chain"]))), + ) + finally: + conn.close() + + +def _load_snapshot_from_sqlite(db_path: str) -> dict[str, Any]: + try: + conn = _connect(db_path) + except sqlite3.DatabaseError as exc: + raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc + + try: + _initialize_schema(conn) + block_rows = conn.execute( + "SELECT block_json FROM blocks ORDER BY height ASC" + ).fetchall() + account_rows = conn.execute( + "SELECT address, account_json FROM accounts ORDER BY address ASC" + ).fetchall() + except sqlite3.DatabaseError as exc: + raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc + finally: + conn.close() + + try: + chain = [json.loads(row["block_json"]) for row in block_rows] + state = { + row["address"]: json.loads(row["account_json"]) + for row in account_rows + } + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid persisted JSON payload in '{db_path}'") from exc + return {"chain": chain, "state": state} -def _read_json(filepath: str): - if not os.path.exists(filepath): - raise FileNotFoundError(f"Persistence file not found: '{filepath}'") + +# --------------------------------------------------------------------------- +# Legacy JSON helpers +# --------------------------------------------------------------------------- + + +def _read_legacy_json(filepath: str) -> dict[str, Any]: with open(filepath, "r", encoding="utf-8") as f: return json.load(f) -def _deserialize_block(data: dict) -> Block: - """Reconstruct a Block (including its transactions) from a plain dict.""" +# --------------------------------------------------------------------------- +# Block deserialisation +# --------------------------------------------------------------------------- + + +def _deserialize_block(data: dict[str, Any]) -> Block: return Block.from_dict(data) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index e758227..e62dd73 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -1,18 +1,21 @@ -""" -Tests for chain persistence (save / load round-trip). -""" +"""Tests for chain persistence (save / load round-trip).""" import json import os import shutil +import sqlite3 import tempfile import unittest -from nacl.signing import SigningKey from nacl.encoding import HexEncoder +from nacl.signing import SigningKey + +from minichain import Block, Blockchain, Transaction, mine_block +from minichain.persistence import load, persistence_exists, save + -from minichain import Blockchain, Transaction, Block, mine_block -from minichain.persistence import save, load +DB_FILE = "data.db" +LEGACY_FILE = "data.json" def _make_keypair(): @@ -22,17 +25,13 @@ def _make_keypair(): class TestPersistence(unittest.TestCase): - def setUp(self): self.tmpdir = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.tmpdir, ignore_errors=True) - # Helpers - def _chain_with_tx(self): - """Return a Blockchain that has one mined block with a transfer.""" bc = Blockchain() alice_sk, alice_pk = _make_keypair() _, bob_pk = _make_keypair() @@ -52,24 +51,21 @@ def _chain_with_tx(self): bc.add_block(block) return bc, alice_pk, bob_pk - # --- Basic save/load --- - - def test_save_creates_file(self): + def test_save_creates_sqlite_file(self): bc = Blockchain() save(bc, path=self.tmpdir) - self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "data.json"))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, DB_FILE))) + self.assertTrue(persistence_exists(self.tmpdir)) def test_chain_length_preserved(self): bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) self.assertEqual(len(restored.chain), len(bc.chain)) def test_block_hashes_preserved(self): bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) for original, loaded in zip(bc.chain, restored.chain): self.assertEqual(original.hash, loaded.hash) @@ -79,11 +75,9 @@ def test_block_hashes_preserved(self): def test_transaction_data_preserved(self): bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) original_tx = bc.chain[1].transactions[0] loaded_tx = restored.chain[1].transactions[0] - self.assertEqual(original_tx.sender, loaded_tx.sender) self.assertEqual(original_tx.receiver, loaded_tx.receiver) self.assertEqual(original_tx.amount, loaded_tx.amount) @@ -94,91 +88,71 @@ def test_genesis_only_chain(self): bc = Blockchain() save(bc, path=self.tmpdir) restored = load(path=self.tmpdir) - self.assertEqual(len(restored.chain), 1) self.assertEqual(restored.chain[0].hash, "0" * 64) - # --- State recomputation --- - - def test_state_recomputed_from_blocks(self): - """Balances must be recomputed by replaying blocks, not from a file.""" + def test_state_snapshot_preserved(self): bc, alice_pk, bob_pk = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) - # Alice started with 100, sent 30 → 70 self.assertEqual( restored.state.get_account(alice_pk)["balance"], bc.state.get_account(alice_pk)["balance"], ) - # Bob received 30 self.assertEqual( restored.state.get_account(bob_pk)["balance"], bc.state.get_account(bob_pk)["balance"], ) - # --- Integrity verification --- - def test_tampered_hash_rejected(self): - """Loading a chain with a tampered block hash must raise ValueError.""" bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - - # Tamper with block hash - chain_path = os.path.join(self.tmpdir, "data.json") - with open(chain_path, "r") as f: - data = json.load(f) - data["chain"][1]["hash"] = "deadbeef" * 8 - with open(chain_path, "w") as f: - json.dump(data, f) - + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + row = conn.execute("SELECT block_json FROM blocks WHERE height = 1").fetchone() + payload = json.loads(row[0]) + payload["hash"] = "deadbeef" * 8 + conn.execute( + "UPDATE blocks SET block_json = ? WHERE height = 1", + (json.dumps(payload),), + ) with self.assertRaises(ValueError): load(path=self.tmpdir) def test_broken_linkage_rejected(self): - """Loading a chain with broken previous_hash linkage must raise.""" bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - - chain_path = os.path.join(self.tmpdir, "data.json") - with open(chain_path, "r") as f: - data = json.load(f) - data["chain"][1]["previous_hash"] = "0" * 64 + "ff" - with open(chain_path, "w") as f: - json.dump(data, f) - + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + row = conn.execute("SELECT block_json FROM blocks WHERE height = 1").fetchone() + payload = json.loads(row[0]) + payload["previous_hash"] = "0" * 64 + "ff" + conn.execute( + "UPDATE blocks SET block_json = ? WHERE height = 1", + (json.dumps(payload),), + ) with self.assertRaises(ValueError): load(path=self.tmpdir) - # --- Crash safety --- - - def test_corrupted_json_raises(self): - """Half-written JSON must raise an error, not silently corrupt.""" + def test_corrupted_sqlite_payload_raises(self): bc = Blockchain() save(bc, path=self.tmpdir) - - # Corrupt the file - chain_path = os.path.join(self.tmpdir, "data.json") - with open(chain_path, "w") as f: - f.write('{"truncated": ') # invalid JSON - - with self.assertRaises(json.JSONDecodeError): + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + conn.execute("UPDATE blocks SET block_json = ? WHERE height = 0", ("{bad-json",)) + with self.assertRaises(ValueError): load(path=self.tmpdir) def test_missing_file_raises(self): with self.assertRaises(FileNotFoundError): - load(path=self.tmpdir) # nothing saved yet - - # --- Chain continuity after load --- + load(path=self.tmpdir) + self.assertFalse(persistence_exists(self.tmpdir)) def test_loaded_chain_can_add_new_block(self): - """Restored chain must still accept new valid blocks.""" - bc, alice_pk, bob_pk = self._chain_with_tx() + bc, _, bob_pk = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) - # Build a second transfer using a new key new_sk, new_pk = _make_keypair() restored.state.credit_mining_reward(new_pk, 50) @@ -196,6 +170,19 @@ def test_loaded_chain_can_add_new_block(self): self.assertTrue(restored.add_block(block2)) self.assertEqual(len(restored.chain), len(bc.chain) + 1) + def test_legacy_json_load_still_supported(self): + bc = Blockchain() + snapshot = { + "chain": [block.to_dict() for block in bc.chain], + "state": bc.state.accounts, + } + with open(os.path.join(self.tmpdir, LEGACY_FILE), "w", encoding="utf-8") as f: + json.dump(snapshot, f) + + restored = load(path=self.tmpdir) + self.assertEqual(len(restored.chain), 1) + self.assertTrue(persistence_exists(self.tmpdir)) + if __name__ == "__main__": unittest.main() From 1b4e96ce2937ab16a689e67bed017d166ccee7df Mon Sep 17 00:00:00 2001 From: Arunabha Date: Sat, 28 Mar 2026 18:14:39 +0530 Subject: [PATCH 2/5] test: cover sqlite persistence node lifecycle --- tests/test_persistence_runtime.py | 120 ++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/test_persistence_runtime.py diff --git a/tests/test_persistence_runtime.py b/tests/test_persistence_runtime.py new file mode 100644 index 0000000..21d437b --- /dev/null +++ b/tests/test_persistence_runtime.py @@ -0,0 +1,120 @@ +import tempfile +import shutil +import unittest +from unittest.mock import patch + +from nacl.encoding import HexEncoder +from nacl.signing import SigningKey + +import main as main_module +from minichain import Blockchain, Block, Transaction, mine_block +from minichain.persistence import load, save + + +class FakeNetwork: + def __init__(self): + self.handler = None + self.peer_count = 0 + self._on_peer_connected = None + + def register_handler(self, handler): + self.handler = handler + + def register_on_peer_connected(self, callback): + self._on_peer_connected = callback + + async def start(self, port=9000, host="127.0.0.1"): + self.port = port + self.host = host + + async def stop(self): + return None + + async def connect_to_peer(self, host, port): + self.peer_count += 1 + return True + + async def broadcast_transaction(self, tx): + return None + + async def broadcast_block(self, block, miner=None): + return None + + +def _make_keypair(): + sk = SigningKey.generate() + pk = sk.verify_key.encode(encoder=HexEncoder).decode() + return sk, pk + + +class TestPersistenceRuntime(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def _chain_with_tx(self): + bc = Blockchain() + alice_sk, alice_pk = _make_keypair() + _, bob_pk = _make_keypair() + + bc.state.credit_mining_reward(alice_pk, 100) + tx = Transaction(alice_pk, bob_pk, 30, 0) + tx.sign(alice_sk) + + block = Block( + index=1, + previous_hash=bc.last_block.hash, + transactions=[tx], + difficulty=1, + ) + mine_block(block, difficulty=1) + bc.add_block(block) + return bc + + async def test_run_node_loads_existing_sqlite_snapshot(self): + chain = self._chain_with_tx() + save(chain, self.tmpdir) + + async def fake_cli_loop(sk, pk, loaded_chain, mempool, network): + self.assertEqual(len(loaded_chain.chain), len(chain.chain)) + self.assertEqual(loaded_chain.last_block.hash, chain.last_block.hash) + self.assertEqual(loaded_chain.state.accounts, chain.state.accounts) + + with patch.object(main_module, "P2PNetwork", FakeNetwork), patch.object( + main_module, "cli_loop", fake_cli_loop + ): + await main_module.run_node( + port=9400, + host="127.0.0.1", + connect_to=None, + fund=0, + datadir=self.tmpdir, + ) + + async def test_run_node_saves_sqlite_snapshot_on_shutdown(self): + fixed_sk, fixed_pk = _make_keypair() + + async def fake_cli_loop(sk, pk, chain, mempool, network): + self.assertEqual(pk, fixed_pk) + self.assertEqual(chain.state.get_account(pk)["balance"], 25) + + with patch.object(main_module, "P2PNetwork", FakeNetwork), patch.object( + main_module, "cli_loop", fake_cli_loop + ), patch.object(main_module, "create_wallet", return_value=(fixed_sk, fixed_pk)): + await main_module.run_node( + port=9401, + host="127.0.0.1", + connect_to=None, + fund=25, + datadir=self.tmpdir, + ) + + restored = load(self.tmpdir) + self.assertEqual(restored.state.get_account(fixed_pk)["balance"], 25) + self.assertEqual(len(restored.chain), 1) + + +if __name__ == "__main__": + unittest.main() From 5db569e53d8cb3c2b1916b81192245cae0e8878f Mon Sep 17 00:00:00 2001 From: Arunabha Date: Sat, 28 Mar 2026 20:00:29 +0530 Subject: [PATCH 3/5] fix: validate persisted snapshot rows on load --- minichain/persistence.py | 18 ++++++++++++++++-- tests/test_persistence.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/minichain/persistence.py b/minichain/persistence.py index 8a02aa5..b3893fe 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -83,12 +83,26 @@ def load(path: str = ".") -> Blockchain: if not isinstance(raw_accounts, dict): raise ValueError(f"Invalid accounts data in '{path}'") - blocks = [_deserialize_block(b) for b in raw_blocks] + blocks = [] + for raw_block in raw_blocks: + if not isinstance(raw_block, dict): + raise ValueError(f"Invalid chain data in '{path}'") + try: + blocks.append(_deserialize_block(raw_block)) + except (KeyError, TypeError, ValueError) as exc: + raise ValueError(f"Invalid chain data in '{path}'") from exc + + normalized_accounts = {} + for address, account in raw_accounts.items(): + if not isinstance(address, str) or not isinstance(account, dict): + raise ValueError(f"Invalid accounts data in '{path}'") + normalized_accounts[address] = account + _verify_chain_integrity(blocks) blockchain = Blockchain() blockchain.chain = blocks - blockchain.state.accounts = raw_accounts + blockchain.state.accounts = normalized_accounts logger.info( "Loaded %d blocks and %d accounts from '%s'", diff --git a/tests/test_persistence.py b/tests/test_persistence.py index e62dd73..0313636 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -143,6 +143,45 @@ def test_corrupted_sqlite_payload_raises(self): with self.assertRaises(ValueError): load(path=self.tmpdir) + def test_malformed_block_row_raises_value_error(self): + bc = Blockchain() + save(bc, path=self.tmpdir) + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + conn.execute( + "UPDATE blocks SET block_json = ? WHERE height = 0", + (json.dumps(["not-a-block-dict"]),), + ) + with self.assertRaises(ValueError): + load(path=self.tmpdir) + + def test_block_missing_required_field_raises_value_error(self): + bc = Blockchain() + save(bc, path=self.tmpdir) + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + row = conn.execute("SELECT block_json FROM blocks WHERE height = 0").fetchone() + payload = json.loads(row[0]) + payload.pop("hash", None) + conn.execute( + "UPDATE blocks SET block_json = ? WHERE height = 0", + (json.dumps(payload),), + ) + with self.assertRaises(ValueError): + load(path=self.tmpdir) + + def test_malformed_account_row_raises_value_error(self): + bc, _, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + conn.execute( + "UPDATE accounts SET account_json = ? WHERE address = ?", + (json.dumps(["not-an-account-dict"]), next(iter(bc.state.accounts))), + ) + with self.assertRaises(ValueError): + load(path=self.tmpdir) + def test_missing_file_raises(self): with self.assertRaises(FileNotFoundError): load(path=self.tmpdir) From 28ce0a67f27ddf7dc5f2eec59204c028797e5d3e Mon Sep 17 00:00:00 2001 From: Arunabha Date: Sat, 28 Mar 2026 22:00:14 +0530 Subject: [PATCH 4/5] fix: reject sqlite schema corruption on load --- minichain/persistence.py | 12 +++++++++++- tests/test_persistence.py | 9 +++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/minichain/persistence.py b/minichain/persistence.py index b3893fe..f2d54d4 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -166,6 +166,16 @@ def _initialize_schema(conn: sqlite3.Connection) -> None: ) +def _require_schema(conn: sqlite3.Connection) -> None: + required = {"blocks", "accounts", "metadata"} + rows = conn.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ).fetchall() + existing = {row["name"] for row in rows} + if not required.issubset(existing): + raise ValueError("Missing persistence tables") + + def _save_snapshot_to_sqlite(db_path: str, snapshot: dict[str, Any]) -> None: conn = _connect(db_path) try: @@ -202,7 +212,7 @@ def _load_snapshot_from_sqlite(db_path: str) -> dict[str, Any]: raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc try: - _initialize_schema(conn) + _require_schema(conn) block_rows = conn.execute( "SELECT block_json FROM blocks ORDER BY height ASC" ).fetchall() diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 0313636..33701c7 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -143,6 +143,15 @@ def test_corrupted_sqlite_payload_raises(self): with self.assertRaises(ValueError): load(path=self.tmpdir) + def test_missing_required_sqlite_table_raises(self): + bc = Blockchain() + save(bc, path=self.tmpdir) + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + conn.execute("DROP TABLE accounts") + with self.assertRaises(ValueError): + load(path=self.tmpdir) + def test_malformed_block_row_raises_value_error(self): bc = Blockchain() save(bc, path=self.tmpdir) From 3992bbe614f4034ad01d2b13410873c4ada187ae Mon Sep 17 00:00:00 2001 From: Arunabha Date: Sat, 28 Mar 2026 22:35:28 +0530 Subject: [PATCH 5/5] fix: validate sqlite chain length metadata --- minichain/persistence.py | 13 +++++++++++++ tests/test_persistence.py | 9 +++++++++ 2 files changed, 22 insertions(+) diff --git a/minichain/persistence.py b/minichain/persistence.py index f2d54d4..626210d 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -219,11 +219,24 @@ def _load_snapshot_from_sqlite(db_path: str) -> dict[str, Any]: account_rows = conn.execute( "SELECT address, account_json FROM accounts ORDER BY address ASC" ).fetchall() + chain_length_row = conn.execute( + "SELECT value FROM metadata WHERE key = ?", + ("chain_length",), + ).fetchone() except sqlite3.DatabaseError as exc: raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc finally: conn.close() + if chain_length_row is None: + raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") + try: + expected_chain_length = int(chain_length_row["value"]) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc + if expected_chain_length != len(block_rows): + raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") + try: chain = [json.loads(row["block_json"]) for row in block_rows] state = { diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 33701c7..c8325d7 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -152,6 +152,15 @@ def test_missing_required_sqlite_table_raises(self): with self.assertRaises(ValueError): load(path=self.tmpdir) + def test_truncated_chain_rows_raises_value_error(self): + bc, _, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + db_path = os.path.join(self.tmpdir, DB_FILE) + with sqlite3.connect(db_path) as conn: + conn.execute("DELETE FROM blocks WHERE height = 1") + with self.assertRaises(ValueError): + load(path=self.tmpdir) + def test_malformed_block_row_raises_value_error(self): bc = Blockchain() save(bc, path=self.tmpdir)