From e63f0ea43065f87ce798b9c01852e63606879526 Mon Sep 17 00:00:00 2001
From: NTFSvolume <172021377+NTFSvolume@users.noreply.github.com>
Date: Sun, 15 Mar 2026 20:21:21 -0500
Subject: [PATCH] refactor: rework database

---
 cyberdrop_dl/__main__.py           |  14 +
 cyberdrop_dl/database2/__init__.py | 522 +++++++++++++++++++++++++++++
 cyberdrop_dl/database2/query.py    |  75 +++++
 cyberdrop_dl/database2/tables.py   | 162 +++++++++
 cyberdrop_dl/database2/transfer.py | 202 +++++++++++
 cyberdrop_dl/progress/scraping.py  |   4 +-
 6 files changed, 977 insertions(+), 2 deletions(-)
 create mode 100644 cyberdrop_dl/database2/__init__.py
 create mode 100644 cyberdrop_dl/database2/query.py
 create mode 100644 cyberdrop_dl/database2/tables.py
 create mode 100644 cyberdrop_dl/database2/transfer.py

diff --git a/cyberdrop_dl/__main__.py b/cyberdrop_dl/__main__.py
index 49cf44310..891227814 100644
--- a/cyberdrop_dl/__main__.py
+++ b/cyberdrop_dl/__main__.py
@@ -19,6 +19,7 @@
 from cyberdrop_dl.models import format_validation_error
 from cyberdrop_dl.models.types import HttpURL
 from cyberdrop_dl.notifications import send_notifications
+from cyberdrop_dl.progress.scraping import StatusMessage
 from cyberdrop_dl.sorting import Sorter
 from cyberdrop_dl.updates import check_latest_pypi
 from cyberdrop_dl.utils import check_partials_and_empty_folders
@@ -199,6 +200,19 @@ async def load_items() -> AsyncGenerator[ScrapeItem]:
     aio.run(scrape(manager, load_items))
 
 
+@app.command()
+def transfer(db_path: Path, output_path: Path, /) -> None:
+    """Transfer an old database file (v8.10) to the newer v10 format"""
+    with setup_logging(Path("cdl_transfer.log"), level=logging.DEBUG):
+        logger.warning(
+            "Make sure the old database is from version 8.10. Otherwise the migration may fail."
+        )
+        from cyberdrop_dl.database2.transfer import migrate
+
+        with StatusMessage("Database transfer in progress...").activity:
+            migrate(db_path, output_path)
+
+
 def main() -> None:
     app()
 
diff --git a/cyberdrop_dl/database2/__init__.py b/cyberdrop_dl/database2/__init__.py
new file mode 100644
index 000000000..c8307c849
--- /dev/null
+++ b/cyberdrop_dl/database2/__init__.py
@@ -0,0 +1,522 @@
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import dataclasses
+import logging
+from contextvars import ContextVar
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias, cast
+
+import aiosqlite
+from packaging.version import Version
+
+from cyberdrop_dl.database2 import query
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+    import datetime
+    from collections.abc import AsyncGenerator, Iterable
+    from sqlite3 import Row
+
+    from cyberdrop_dl.data_structures.url_objects import AbsoluteHttpURL, MediaItem
+    from cyberdrop_dl.database2.tables import Table
+
+
+Properties: TypeAlias = tuple[str, ...]
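+
+# The active Database is stored in a ContextVar (set by the `connect` context
+# manager below) so the module-level helpers can reach it without threading a
+# connection handle through every call.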
+
+
+_current_db: ContextVar[Database] = ContextVar("_db")
+_FETCH_MANY_SIZE: int = 1000
+CURRENT_APP_SCHEMA_VERSION = "8.10.0"
+MIN_REQUIRED_VERSION = "8.10.0"
+
+create_hash_index = """
+CREATE INDEX IF NOT EXISTS idx_hash_type_hash ON hash (hash_type, hash);
+"""
+
+
+@dataclasses.dataclass(slots=True)
+class Database:
+    db_path: Path
+    ignore_history: bool
+
+    conn: aiosqlite.Connection = dataclasses.field(init=False)
+
+    async def connect(self) -> None:
+        exists = self.db_path.exists()
+        self.conn = await aiosqlite.connect(self.db_path, timeout=20)
+        self.conn.row_factory = aiosqlite.Row
+
+        if exists:
+            await self._check()
+
+        await self._pre_allocate()
+        for table in self.tables:
+            await self._create(table)
+
+        await self._update()
+
+    async def fetchone(self, query: str, parameters: Iterable[Any] | None = None) -> Row | None:
+        cursor = await self.conn.execute(query, parameters)
+        return await cursor.fetchone()
+
+    async def fetchall(self, query: str, parameters: Iterable[Any] | None = None) -> list[Row]:
+        return await self.conn.execute_fetchall(query, parameters)  # pyright: ignore[reportReturnType]
+
+    async def commit(self, query: str, parameters: Iterable[Any] | None = None) -> None:
+        _ = await self.conn.execute(query, parameters)
+        await self.conn.commit()
+
+    async def close(self) -> None:
+        await self.conn.close()
+
+    async def __aenter__(self) -> Self:
+        await self.connect()
+        return self
+
+    async def __aexit__(self, *_) -> None:
+        await self.close()
+
+    async def _pre_allocate(self) -> None:
+        """Pre-allocates 100MB of space in the database file so SQLite has room to keep working if the user runs out of disk space."""
+
+        free_space = await self.fetchone("PRAGMA freelist_count;")
+        assert free_space is not None
+
+        # more than 1024 free pages (~4MB at the default 4KiB page size) means space is already reserved
+        if free_space[0] > 1024:
+            return
+
+        pre_allocate_script = (
+            "CREATE TABLE IF NOT EXISTS t(x);"
+            "INSERT INTO t VALUES(zeroblob(100*1024*1024));"  # 100 MB
+            "DROP TABLE t;"
+        )
+        _ = await self.conn.executescript(pre_allocate_script)
+        await self.conn.commit()
+
+    async def _get_version(self) -> Version | None:
+        # ORDER BY has to come before LIMIT, so the full suffix is appended by hand
+        sql, _ = query.select(self.tables.schema, "version")
+        if result := await self.fetchone(sql + " ORDER BY ROWID DESC LIMIT 1"):
+            return Version(result["version"])
+
+    async def _check(self) -> None:
+        logger.info(f"Expected database schema version: {CURRENT_APP_SCHEMA_VERSION}")
+        version = await self._get_version()
+        logger.info(f"Database reports installed version: {version}")
+        if version is None or version < Version(MIN_REQUIRED_VERSION):
+            raise RuntimeError(f"Unsupported database version: {version} (minimum required: {MIN_REQUIRED_VERSION})")
+
+    async def _update(self) -> None:
+        version = await self._get_version()
+        if version is not None and version >= Version(CURRENT_APP_SCHEMA_VERSION):
+            return
+
+        # TODO: on v9, raise SystemExit if db version is None or older than 8.0.0
+        logger.info(f"Updating database version to {CURRENT_APP_SCHEMA_VERSION}")
+        sql, params = query.insert(self.tables.schema, version=CURRENT_APP_SCHEMA_VERSION)
+        await self.commit(sql, params)
+
+    async def _create(self, table: Table) -> None:
+        await self.commit(table.to_sql_schema())
+
+
+@contextlib.asynccontextmanager
+async def connect(db_path: Path, ignore_history: bool) -> AsyncGenerator[Database]:
+    async with Database(db_path, ignore_history) as db:
+        token = _current_db.set(db)
+        try:
+            yield db
+        finally:
+            _current_db.reset(token)
+
+
+async def check_complete(domain: str, url: AbsoluteHttpURL, referer: AbsoluteHttpURL, db_path: str) -> bool:
+    """Checks whether an individual file has completed given its domain and url path."""
+    db = _current_db.get()
+    if db.ignore_history:
+        return False
+
+    async def get_referer_and_completed() -> tuple[str, bool]:
+        sql, params = query.select(db.tables.history, "referer", "completed", domain=domain, url_path=db_path)
+        if row := await db.fetchone(sql, params):
+            return row["referer"], bool(row["completed"])
+        return "", False
+
+    current_referer, completed = await get_referer_and_completed()
+    if completed and url != referer and str(referer) != current_referer:
+        # Update the referer if it has changed so that check_complete_by_referer can work
+        logger.info(f"Updating referer of {url} from {current_referer} to {referer}")
+        sql, params = query.update(db.tables.history, referer=referer, domain=domain, url_path=db_path)
+        await db.commit(sql, params)
+
+    return completed
+
+
+async def check_album(domain: str, album_id: str) -> dict[str, bool]:
+    """Checks whether an album has completed given its domain and album id."""
+    db = _current_db.get()
+    if db.ignore_history:
+        return {}
+
+    sql, params = query.select(db.tables.history, "url_path", "completed", domain=domain, album_id=album_id)
+    rows = await db.conn.execute_fetchall(sql, params)
+    return {row["url_path"]: bool(row["completed"]) for row in rows}
+
+
+async def set_album_id(domain: str, media_item: MediaItem) -> None:
+    """Sets an album_id in the database."""
+    db = _current_db.get()
+    sql, params = query.update(
+        db.tables.history,
+        album_id=media_item.album_id,
+        domain=domain,
+        url_path=media_item.db_path,
+    )
+    await db.commit(sql, params)
+
+
+async def check_complete_by_referer(domain: str | None, referer: AbsoluteHttpURL) -> bool:
+    """Checks whether any file with the given referer has completed, optionally scoped to a single domain."""
+    db = _current_db.get()
+    if db.ignore_history:
+        return False
+
+    if domain is None:
+        sql, params = query.exists(db.tables.history, completed=1, referer=referer)
+
+    else:
+        sql, params = query.exists(db.tables.history, completed=1, referer=referer, domain=domain)
+
+    return bool(await db.fetchone(sql, params))
+
+
+async def insert_incompleted(domain: str, media_item: MediaItem) -> None:
+    """Inserts an uncompleted file into the database."""
+
+    db = _current_db.get()
+    download_filename = media_item.download_filename or ""
+    sql, params = query.insert_or_ignore(
+        db.tables.history,
+        domain=domain,
+        url_path=media_item.db_path,
+        referer=media_item.referer,
+        album_id=media_item.album_id,
+        download_path=media_item.download_folder,
+        download_filename=download_filename,
+        original_filename=media_item.original_filename,
+    )
+
+    await db.commit(sql, params)
+
+
+async def mark_complete(domain: str, media_item: MediaItem) -> None:
+    """Mark a download as completed in the database."""
+    db = _current_db.get()
+    sql, params = query.update(
+        db.tables.history,
+        completed=1,
+        completed_at="CURRENT_TIMESTAMP",
+        domain=domain,
+        url_path=media_item.db_path,
+    )
+    await db.commit(sql, params)
+
+
+async def add_filesize(domain: str, media_item: MediaItem) -> None:
+    """Adds the file size to the db."""
+    db = _current_db.get()
+
+    sql, params = query.update(
+        db.tables.history,
+        file_size=await asyncio.to_thread(lambda: media_item.complete_file.stat().st_size),
+        domain=domain,
+        url_path=media_item.db_path,
+    )
+    await db.commit(sql, params)
+
+
+async def add_duration(domain: str, media_item: MediaItem) -> None:
+    """Adds the duration to the db."""
+    db = _current_db.get()
+    sql, params = query.update(
+        db.tables.history,
+        duration=media_item.duration,
+        domain=domain,
+        url_path=media_item.db_path,
+    )
+    await db.commit(sql, params)
+
+
+async def get_duration(domain: str, media_item: MediaItem) -> float | None:
+    """Returns the duration from the database."""
+    if media_item.is_segment:
+        return
+
+    db = _current_db.get()
+    sql, params = query.select(
+        db.tables.history,
+        "duration",
+        domain=domain,
+        url_path=media_item.db_path,
+        limit=1,
+    )
+    if row := await db.fetchone(sql, params):
+        return row["duration"]
+
+
+async def add_download_filename(domain: str, media_item: MediaItem) -> None:
+    """Add the download_filename to the db."""
+    db = _current_db.get()
+    url_path = media_item.db_path
+    sql = "UPDATE media SET download_filename=? WHERE domain = ? and url_path = ? and download_filename = ''"
+    await db.commit(sql, (media_item.download_filename, domain, url_path))
+
+
+async def check_filename_exists(filename: str) -> bool:
+    """Checks whether a downloaded filename exists in the database."""
+    db = _current_db.get()
+    sql, params = query.exists(db.tables.history, download_filename=filename)
+    return bool(await db.fetchone(sql, params))
+
+
+async def get_downloaded_filename(domain: str, media_item: MediaItem) -> str | None:
+    """Returns the downloaded filename from the database."""
+
+    if media_item.is_segment:
+        return media_item.filename
+
+    db = _current_db.get()
+    sql, params = query.select(
+        db.tables.history,
+        "download_filename",
+        domain=domain,
+        url_path=media_item.db_path,
+        limit=1,
+    )
+    if row := await db.fetchone(sql, params):
+        return row["download_filename"]
+
+
+async def get_failed_items() -> AsyncGenerator[list[Row]]:
+    """Yields batches of failed (not completed) items."""
+    db = _current_db.get()
+    sql, params = query.select(db.tables.history, "referer", "download_path", "completed_at", "created_at", completed=0)
+    cursor = await db.conn.execute(sql, params)
+    while rows := await cursor.fetchmany(_FETCH_MANY_SIZE):
+        yield cast("list[Row]", rows)
+
+
+async def get_all_items(after: datetime.date, before: datetime.date) -> AsyncGenerator[list[Row]]:
+    """Yields batches of all items whose completion date falls between the given dates."""
+    query_ = """
+    SELECT referer,download_path,completed_at,created_at
+    FROM media WHERE COALESCE(completed_at, '1970-01-01') BETWEEN ? AND ?
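+    -- never-completed rows have completed_at NULL; COALESCE maps them to the
+    -- epoch so the BETWEEN filter can still match them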
+    ORDER BY completed_at DESC;
+    """
+    db = _current_db.get()
+    cursor = await db.conn.execute(query_, (after.isoformat(), before.isoformat()))
+    while rows := await cursor.fetchmany(_FETCH_MANY_SIZE):
+        yield cast("list[Row]", rows)
+
+
+async def get_all_bunkr_failed() -> AsyncGenerator[list[Row]]:
+    async for rows in get_all_bunkr_failed_via_hash():
+        yield rows
+    async for rows in get_all_bunkr_failed_via_size():
+        yield rows
+
+
+async def get_all_bunkr_failed_via_size() -> AsyncGenerator[list[Row]]:
+    db = _current_db.get()
+    sql, params = query.select(
+        db.tables.history,
+        "referer",
+        "download_path",
+        "completed_at",
+        "created_at",
+        file_size=322_509,
+        domain="bunkr",
+    )
+
+    cursor = await db.conn.execute(sql, params)
+    while rows := await cursor.fetchmany(_FETCH_MANY_SIZE):
+        yield cast("list[Row]", rows)
+
+
+async def get_all_bunkr_failed_via_hash() -> AsyncGenerator[list[Row]]:
+    sql = """
+    SELECT m.referer,download_path,completed_at,created_at
+    FROM hash h INNER JOIN media m ON h.download_filename= m.download_filename
+    WHERE h.hash = 'eb669b6362e031fa2b0f1215480c4e30';
+    """
+
+    db = _current_db.get()
+    cursor = await db.conn.execute(sql)
+    while rows := await cursor.fetchmany(_FETCH_MANY_SIZE):
+        yield cast("list[Row]", rows)
+
+
+async def get_file_hash_exists(path: Path | str, hash_type: str) -> str | None:
+    sql = "SELECT hash FROM hash WHERE folder=? AND download_filename=? AND hash_type=? AND hash IS NOT NULL"
+    db = _current_db.get()
+
+    path = Path(path)
+    if not path.is_absolute():
+        path = path.absolute()
+    folder = str(path.parent)
+    filename = path.name
+
+    # Check if a hash exists for the matching folder, filename, and hash type
+    if row := await db.fetchone(sql, (folder, filename, hash_type)):
+        return row[0]
+
+
+async def get_files_with_hash_matches(hash_value: str, size: int, hash_type: str | None = None) -> list[aiosqlite.Row]:
+    """Retrieves the files matching a given hash and size.
+
+    Args:
+        hash_value: The hash value to search for.
+        size: The file size, in bytes.
+        hash_type: The hash algorithm to filter by [optional].
+
+    Returns:
+        A list of matching rows, or an empty list if no matches are found.
+    """
+    db = _current_db.get()
+    if hash_type:
+        sql = """
+        SELECT files.folder, files.download_filename,files.date
+        FROM hash JOIN files ON hash.folder = files.folder AND hash.download_filename = files.download_filename
+        WHERE hash.hash = ? AND files.file_size = ? AND hash.hash_type = ?;
+        """
+        params = (hash_value, size, hash_type)
+
+    else:
+        sql = """
+        SELECT files.folder, files.download_filename FROM hash JOIN files
+        ON hash.folder = files.folder AND hash.download_filename = files.download_filename
+        WHERE hash.hash = ? AND files.file_size = ?;
+        """
+        params = (hash_value, size)
+
+    return await db.fetchall(sql, params)
+
+
+async def check_hash_exists(hash_type: str, hash_value: str) -> bool:
+    db = _current_db.get()
+    if db.ignore_history:
+        return False
+
+    sql = "SELECT 1 FROM hash WHERE hash.hash_type = ? AND hash.hash = ? LIMIT 1"
+    return bool(await db.fetchone(sql, (hash_type, hash_value)))
+
+
+async def insert_or_update_hash_db(
+    hash_value: str,
+    hash_type: Literal["md5", "sha256"],
+    file: Path | str,
+    original_filename: str | None,
+    referer: AbsoluteHttpURL | None,
+) -> bool:
+    """Inserts or updates a record in the specified SQLite database.
+
+    Args:
+        hash_value: The calculated hash of the file.
+        file: The file path
+        original_filename: The original name of the file.
+        referer: referer URL
+        hash_type: The hash type (e.g., md5, sha256)
+
+    Returns:
+        True if both the hash and file records were inserted or updated successfully, False otherwise.
+    """
+
+    hash_ok = await insert_or_update_hashes(hash_value, hash_type, file)
+    file_ok = await insert_or_update_file(original_filename, referer, file)
+    return file_ok and hash_ok
+
+
+async def insert_or_update_hashes(hash_value: str, hash_type: str, file: Path | str) -> bool:
+    sql = """
+    INSERT INTO hash (hash, hash_type, folder, download_filename)
+    VALUES (?, ?, ?, ?) ON CONFLICT(download_filename, folder, hash_type) DO UPDATE SET hash = ?;
+    """
+    db = _current_db.get()
+    try:
+        full_path = Path(file)
+        if not full_path.is_absolute():
+            full_path = full_path.absolute()
+        download_filename = full_path.name
+        folder = str(full_path.parent)
+        await db.commit(sql, (hash_value, hash_type, folder, download_filename, hash_value))
+
+    except Exception as e:
+        logger.exception(f"Error inserting/updating record: {e}")
+        return False
+    else:
+        return True
+
+
+async def insert_or_update_file(
+    original_filename: str | None, referer: AbsoluteHttpURL | str | None, file: Path | str
+) -> bool:
+    sql = """
+    INSERT INTO files (folder, original_filename, download_filename, file_size, referer, date)
+    VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(download_filename, folder)
+    DO UPDATE SET original_filename = ?, file_size = ?, referer = ?, date = ?;
+    """
+    referer_ = str(referer) if referer else None
+    db = _current_db.get()
+    try:
+        full_path = Path(file)
+        if not full_path.is_absolute():
+            full_path = full_path.absolute()
+        download_filename = full_path.name
+        folder = str(full_path.parent)
+        stat = full_path.stat()
+        file_size = stat.st_size
+        file_date = int(stat.st_mtime)
+        await db.commit(
+            sql,
+            (
+                folder,
+                original_filename,
+                download_filename,
+                file_size,
+                referer_,
+                file_date,
+                original_filename,
+                file_size,
+                referer_,
+                file_date,
+            ),
+        )
+    except Exception as e:
+        logger.exception(f"Error inserting/updating record: {e}")
+        return False
+    return True
+
+
+async def get_all_unique_hashes(hash_type: str | None = None) -> list[str]:
+    """Retrieves a list of all unique hashes in the database.
+
+    Args:
+        hash_type: The type of hash [optional]
+
+    Returns:
+        A list of unique hash values, or an empty list if none are found.
+    """
+    db = _current_db.get()
+    if hash_type:
+        sql, params = "SELECT DISTINCT hash FROM hash WHERE hash_type =?", (hash_type,)
+
+    else:
+        sql, params = "SELECT DISTINCT hash FROM hash", ()
+    try:
+        rows = await db.fetchall(sql, params)
+        return [row[0] for row in rows]
+    except Exception as e:
+        logger.exception(f"Error retrieving unique hashes: {e}")
+        return []
diff --git a/cyberdrop_dl/database2/query.py b/cyberdrop_dl/database2/query.py
new file mode 100644
index 000000000..fa2a88245
--- /dev/null
+++ b/cyberdrop_dl/database2/query.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, NewType
+
+if TYPE_CHECKING:
+    from cyberdrop_dl.database2.tables import Table
+
+Command = NewType("Command", str)
+
+
+def exists(table: Table, **where: Any) -> tuple[Command, tuple[Any, ...]]:
+    table.check_columns(where.keys())
+    conditions = " AND ".join(f"{key}=?" for key in where.keys())
+    command = f"SELECT EXISTS(SELECT 1 FROM {table.__table_name__} WHERE {conditions})"
+    return Command(command), tuple(where.values())
+
+
+def insert(table: Table, **values: Any) -> tuple[Command, tuple[Any, ...]]:
+    return _insert(table, "INSERT", **values)
+
+
+def insert_or_ignore(table: Table, **values: Any) -> tuple[Command, tuple[Any, ...]]:
+    return _insert(table, "INSERT OR IGNORE", **values)
+
+
+def _insert(table: Table, exc: str = "INSERT", **values: Any) -> tuple[Command, tuple[Any, ...]]:
+    # check_columns guarantees every name is a real column; partial inserts are
+    # allowed so column defaults (id, created_at, ...) can apply
+    table.check_columns(values.keys())
+    columns = ", ".join(values.keys())
+    placeholders = ", ".join("?" for _ in values)
+    command = f"{exc} INTO {table.__table_name__} ({columns}) VALUES ({placeholders})"
+    return Command(command), tuple(values.values())
+
+
+def select(table: Table, *columns: str, limit: int | None = None, **where: Any) -> tuple[Command, tuple[Any, ...]]:
+    assert columns
+    table.check_columns(columns)
+    wanted = ", ".join(columns)
+    command = f"SELECT {wanted} FROM {table.__table_name__}"
+    if where:
+        table.check_columns(where.keys())
+        conditions = " AND ".join(f"{key}=?" for key in where.keys())
+        command += f" WHERE {conditions}"
+    if limit:
+        command += f" LIMIT {limit}"
+
+    return Command(command), tuple(where.values())
+
+
+def update(table: Table, **row: Any) -> tuple[Command, tuple[Any, ...]]:
+    table.check_columns(row.keys())
+
+    # Columns in the table's UNIQUE constraint become the WHERE clause;
+    # everything else becomes the SET clause
+    p_keys: dict[str, Any] = {}
+    other_keys: dict[str, Any] = {}
+    for key, value in row.items():
+        if key in table.UNIQUE:
+            p_keys[key] = value
+        else:
+            other_keys[key] = value
+
+    assert p_keys
+    assert other_keys
+
+    new = ", ".join(f"{key}={_placeholder(v)}" for key, v in other_keys.items())
+    conditions = " AND ".join(f"{key}=?" for key in p_keys)
+    command = f"UPDATE {table.__table_name__} SET {new} WHERE {conditions}"
+    values = *(v for v in other_keys.values() if v != "CURRENT_TIMESTAMP"), *p_keys.values()
+
+    return Command(command), values
+
+
+def _placeholder(v: Any) -> Literal["CURRENT_TIMESTAMP", "?"]:
+    if v == "CURRENT_TIMESTAMP":
+        return v
+    return "?"
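+
+
+# Example usage (a sketch; `history` is a hypothetical Table instance whose
+# UNIQUE constraint is ("domain", "url_path")):
+#
+#   sql, params = update(history, completed=1, completed_at="CURRENT_TIMESTAMP", domain="bunkr", url_path="/f/xyz")
+#   # sql    == "UPDATE history SET completed=?, completed_at=CURRENT_TIMESTAMP WHERE domain=? AND url_path=?"
+#   # params == (1, "bunkr", "/f/xyz")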
diff --git a/cyberdrop_dl/database2/tables.py b/cyberdrop_dl/database2/tables.py
new file mode 100644
index 000000000..0c62cba5e
--- /dev/null
+++ b/cyberdrop_dl/database2/tables.py
@@ -0,0 +1,162 @@
+import dataclasses
+import datetime
+import logging
+from collections.abc import Generator, Iterable
+from typing import Any, ClassVar, Self, get_args
+
+import aiosqlite
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass(slots=True)
+class Reference:
+    table: str
+    column: str
+    on_delete: str
+
+    def __str__(self) -> str:
+        return f"REFERENCES {self.table}({self.column}) ON DELETE {self.on_delete}"
+
+
+PK = {"PK": True}
+AUTOINCREMENT = {"AUTOINCREMENT": True}
+
+
+def REFERENCE(table: str, column: str, on_delete: str = "CASCADE") -> dict[str, Reference]:  # noqa: N802
+    return {"REFERENCE": Reference(table, column, on_delete)}
+
+
+_type_map: dict[type[Any], str] = {
+    int: "INTEGER",
+    float: "FLOAT",
+    str: "TEXT",
+    datetime.datetime: "DATETIME",
+}
+
+
+def _now() -> datetime.datetime:
+    return datetime.datetime.now(datetime.UTC)
+
+
+@dataclasses.dataclass(slots=True)
+class Table:
+    __table_name__: ClassVar[str]
+    COLUMNS: ClassVar[set[str]]
+    UNIQUE: ClassVar[tuple[str, ...]] = ()
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}(name={self.__table_name__!r}, columns={self.COLUMNS!r})"
+
+    def __init_subclass__(cls) -> None:
+        # With slots=True the dataclass decorator re-creates the class, so this
+        # hook runs a second time once __dataclass_fields__ is populated and
+        # COLUMNS ends up with the real field names
+        cls.__table_name__ = getattr(cls, "__table_name__", None) or cls.__name__.lower()
+        cls.COLUMNS = {f.name for f in dataclasses.fields(cls)}
+
+    def check_columns(self, other: Iterable[str]) -> None:
+        assert self.COLUMNS.issuperset(other), f"Invalid keys for table {self.__table_name__}. {tuple(other)}"
+
+    @classmethod
+    def from_row(cls, row: aiosqlite.Row) -> Self:
+        return cls(**{name: row[name] for name in cls.COLUMNS})
+
+    @classmethod
+    def to_sql_schema(cls) -> str:
+        joined_columns = ",\n".join(cls._parse_columns())
+        sql = f"CREATE TABLE IF NOT EXISTS {cls.__table_name__} (\n{joined_columns}"
+        if cls.UNIQUE:
+            sql += f",\nUNIQUE({', '.join(cls.UNIQUE)})"
+        return sql + "\n);"
+
+    @classmethod
+    def _parse_columns(cls) -> Generator[str]:
+        for field in dataclasses.fields(cls):
+            # This only works because this module does not use
+            # `from __future__ import annotations`, so field.type is a real type
+            # instead of a string
+            if isinstance(field.type, type):
+                python_type = field.type
+            else:
+                # e.g. `str | None` -> take the first arg of the union
+                python_type, *_ = get_args(field.type)
+
+            sql_type = _type_map[python_type]
+            column = f"{field.name} {sql_type}"
+
+            if field.metadata.get("PK"):
+                column += " PRIMARY KEY"
+
+            elif field.default is not None:
+                # no default (MISSING) or a non-None default -> NOT NULL;
+                # a default of None marks the column as nullable
+                column += " NOT NULL"
+
+            if field.metadata.get("AUTOINCREMENT"):
+                column += " AUTOINCREMENT"
+
+            if reference := field.metadata.get("REFERENCE"):
+                column += f" {reference}"
+
+            if field.default_factory is _now:
+                column += " DEFAULT (datetime('now'))"
+
+            yield column
+
+
+@dataclasses.dataclass(slots=True)
+class Media(Table):
+    id: int = dataclasses.field(metadata=PK | AUTOINCREMENT)
+    domain: str
+    url_path: str
+    referer: str
+    name: str
+    album_id: str | None = None
+    size: int | None = None
+    duration: float | None = None
+    created_at: datetime.datetime = dataclasses.field(default_factory=_now)
+
+    UNIQUE: ClassVar[tuple[str, ...]] = "domain", "url_path"
+
+
+@dataclasses.dataclass(slots=True)
+class Downloads(Table):
+    id: int = dataclasses.field(metadata=PK | AUTOINCREMENT)
+    media_id: int = dataclasses.field(metadata=REFERENCE("media", "id", "CASCADE"))
+    folder: str
+    file_name: str
+    original_file_name: str
+    created_at: datetime.datetime = dataclasses.field(default_factory=_now)
+    completed_at: datetime.datetime | None = None
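+
+
+# For reference, Media.to_sql_schema() should render to roughly the following
+# (a sketch of the generated DDL, not captured output):
+#
+#   CREATE TABLE IF NOT EXISTS media (
+#   id INTEGER PRIMARY KEY AUTOINCREMENT,
+#   domain TEXT NOT NULL,
+#   url_path TEXT NOT NULL,
+#   referer TEXT NOT NULL,
+#   name TEXT NOT NULL,
+#   album_id TEXT,
+#   size INTEGER,
+#   duration FLOAT,
+#   created_at DATETIME NOT NULL DEFAULT (datetime('now')),
+#   UNIQUE(domain, url_path)
+#   );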
+
+
+@dataclasses.dataclass(slots=True)
+class Files(Table):
+    """Table of files that exist on disk"""
+
+    id: int = dataclasses.field(metadata=PK | AUTOINCREMENT)
+    folder: str
+    name: str
+    size: int
+    modtime: datetime.datetime | None = None
+
+    UNIQUE: ClassVar[tuple[str, ...]] = "folder", "name"
+
+
+@dataclasses.dataclass(slots=True)
+class Hash(Table):
+    file_id: int = dataclasses.field(metadata=REFERENCE("files", "id", "CASCADE"))
+    algorithm: str
+    hash: str
+
+    UNIQUE: ClassVar[tuple[str, ...]] = "file_id", "algorithm", "hash"
+
+
+@dataclasses.dataclass(slots=True)
+class Schema(Table):
+    __table_name__: ClassVar[str] = "schema_version"
+    version: str = dataclasses.field(metadata=PK)
+    applied_on: datetime.datetime = dataclasses.field(default_factory=_now)
+
+    UNIQUE: ClassVar[tuple[str, ...]] = ("version",)
+
+
+TABLES = (Media, Downloads, Files, Hash, Schema)
+
+if __name__ == "__main__":
+    for table in TABLES:
+        print("")  # noqa: T201
+        print(table.to_sql_schema())  # noqa: T201
diff --git a/cyberdrop_dl/database2/transfer.py b/cyberdrop_dl/database2/transfer.py
new file mode 100644
index 000000000..62e26da9d
--- /dev/null
+++ b/cyberdrop_dl/database2/transfer.py
@@ -0,0 +1,202 @@
+import datetime
+import logging
+import shutil
+import sqlite3
+import sys
+from pathlib import Path
+
+from cyberdrop_dl.database2.tables import Downloads, Files, Hash, Media
+
+logger = logging.getLogger(__name__)
+
+create_downloads = f"""
+{Downloads.to_sql_schema()}
+
+INSERT INTO
+    downloads (
+        media_id,
+        folder,
+        file_name,
+        original_file_name,
+        created_at,
+        completed_at
+    )
+SELECT
+    media.id,
+    old.download_path AS folder,
+    COALESCE(
+        NULLIF(old.download_filename, ''),
+        old.original_filename
+    ) AS file_name,
+    old.original_filename AS original_file_name,
+    COALESCE(old.created_at, datetime('now')) AS created_at,
+    old.completed_at AS completed_at
+FROM
+    old.media AS old
+    JOIN media ON media.domain = old.domain
+    AND media.url_path = old.url_path
+WHERE
+    old.download_filename IS NOT NULL
+    AND old.download_path IS NOT NULL
+ORDER BY
+    media.id,
+    COALESCE(old.created_at, ''),
+    old.rowid;
+"""
+
+_transfer_media = f"""
+{Media.to_sql_schema()}
+
+INSERT INTO
+    media (
+        domain,
+        url_path,
+        referer,
+        name,
+        album_id,
+        size,
+        duration,
+        created_at
+    )
+SELECT
+    domain,
+    url_path,
+    COALESCE(referer, '') AS referer,
+    COALESCE(original_filename, '') AS name,
+    album_id,
+    file_size AS size,
+    duration,
+    COALESCE(created_at, datetime('now')) AS created_at
+FROM
+    (
+        SELECT
+            *,
+            ROW_NUMBER() OVER (
+                PARTITION BY
+                    domain,
+                    url_path
+                ORDER BY
+                    CASE
+                        WHEN created_at IS NULL THEN 0
+                        ELSE 1
+                    END DESC,
+                    created_at DESC,
+                    rowid DESC
+            ) AS rn
+        FROM
+            old.media
+    )
+WHERE
+    rn = 1;
+"""
+
+_transfer_files = f"""
+{Files.to_sql_schema()}
+
+INSERT INTO
+    files (folder, name, size, modtime)
+SELECT
+    folder,
+    COALESCE(NULLIF(download_filename, ''), original_filename) AS name,
+    file_size AS size,
+    CASE
+        WHEN date IS NOT NULL THEN datetime(date, 'unixepoch')
+        ELSE NULL
+    END AS modtime
+FROM
+    (
+        SELECT
+            *,
+            ROW_NUMBER() OVER (
+                PARTITION BY
+                    folder,
+                    COALESCE(NULLIF(download_filename, ''), original_filename)
+                ORDER BY
+                    CASE
+                        WHEN date IS NULL THEN 0
+                        ELSE 1
+                    END DESC,
+                    date DESC,
+                    rowid DESC
+            ) AS rn
+        FROM
+            old.files
+    )
+WHERE
+    rn = 1
+    AND COALESCE(NULLIF(download_filename, ''), original_filename) IS NOT NULL
+    AND COALESCE(NULLIF(download_filename, ''), original_filename) <> '';
+"""
+
+_transfer_hash = f"""
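+-- Re-link each old hash row to the new files table by matching the old
+-- (folder, download_filename) pair against files(folder, name); hashes whose
+-- file row was filtered out above are dropped by the INNER JOIN.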
+{Hash.to_sql_schema()}
+
+INSERT INTO
+    hash (file_id, algorithm, hash)
+SELECT
+    files.id AS file_id,
+    old_hash.hash_type AS algorithm,
+    old_hash.hash AS hash
+FROM
+    old.hash AS old_hash
+    JOIN files ON files.folder = old_hash.folder
+    AND files.name = old_hash.download_filename
+ORDER BY
+    files.id;
+"""
+
+
+def migrate(old_db: Path, new_db: Path) -> None:
+    if not old_db.is_file():
+        raise FileNotFoundError(old_db)
+    if new_db.exists():
+        raise FileExistsError(new_db)
+
+    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    backup = old_db.parent / f"{old_db.stem}_{now}.bak{old_db.suffix}"
+    __ = shutil.copy2(old_db, backup)
+    logger.info(f"Created backup at '{backup}'")
+
+    new_db.parent.mkdir(parents=True, exist_ok=True)
+
+    conn = sqlite3.connect(new_db)
+
+    try:
+        with conn:
+            conn.execute("PRAGMA journal_mode = WAL;")
+            conn.execute("PRAGMA foreign_keys = OFF;")
+            conn.execute("ATTACH DATABASE ? AS old;", (str(old_db),))
+            conn.executescript(_transfer_media)
+            conn.executescript(create_downloads)
+            conn.executescript(_transfer_files)
+            conn.executescript(_transfer_hash)
+            conn.execute("DETACH DATABASE old;")
+            conn.execute("PRAGMA foreign_keys = ON;")
+    except BaseException:
+        logger.warning("Transfer failed or was cancelled; removing the partial database")
+        conn.close()
+        new_db.unlink(missing_ok=True)
+        raise
+    else:
+        conn.close()
+
+    def count(conn: sqlite3.Connection, table: str) -> int:
+        return conn.execute(f"SELECT COUNT(*) FROM {table};").fetchone()[0]
+
+    tables = "media", "files", "hash"
+    with sqlite3.connect(new_db) as new_conn:
+        rows_copied = {name: count(new_conn, name) for name in tables}
+
+    with sqlite3.connect(old_db) as old_conn:
+        rows_old = {name: count(old_conn, name) for name in tables}
+
+    for table in tables:
+        msg = f"Copied {rows_copied[table]:,} {table} rows into '{new_db}' (original db had {rows_old[table]:,})"
+        logger.info(msg)
+
+
+if __name__ == "__main__":
+    old_db = Path(sys.argv[1])
+    new_db = Path("new.db")
+    migrate(old_db, new_db)
diff --git a/cyberdrop_dl/progress/scraping.py b/cyberdrop_dl/progress/scraping.py
index 13a51cefb..8cec3a135 100644
--- a/cyberdrop_dl/progress/scraping.py
+++ b/cyberdrop_dl/progress/scraping.py
@@ -22,11 +22,11 @@ def __init__(self) -> None:
 
 
 class StatusMessage(UIComponent):
-    columns: ClassVar[ColumnsType] = (SpinnerColumn(), "[progress.description]{task.description}")
+    columns: ClassVar[ColumnsType] = (SpinnerColumn("dots6"), "[progress.description]{task.description}")
 
     def __init__(self, description: str = f"Running Cyberdrop-DL: v{__version__}") -> None:
         super().__init__()
-        self.activity: Progress = Progress(*self.columns)
+        self.activity: Progress = Progress(*self.columns, transient=True, expand=True)
         _ = self.activity.add_task(description)
         task_id = self._progress.add_task("", total=100, completed=0, visible=False)
         self._task: Task = self._progress[task_id]