diff --git a/pyproject.toml b/pyproject.toml index dd30532..4b85ce6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "folio_data_import" -version = "0.5.0" +version = "0.5.1" description = "A python module to perform bulk import of data into a FOLIO environment. Currently supports MARC and user data import." authors = [{ name = "Brooks Travis", email = "brooks.travis@gmail.com" }] license = "MIT" @@ -48,11 +48,15 @@ folio-data-import = "folio_data_import.__main__:app" folio-marc-import = "folio_data_import.MARCDataImport:app" folio-user-import = "folio_data_import.UserImport:app" folio-batch-poster = "folio_data_import.BatchPoster:app" +folio-di-logs = "folio_data_import.DILogRetriever:app" [project.optional-dependencies] redis = [ "redis>=7.1.0", ] +postgres = [ + "psycopg2-binary>=2.9.11", +] # Build [build-system] diff --git a/src/folio_data_import/BatchPoster.py b/src/folio_data_import/BatchPoster.py index a8cc291..4e618ad 100644 --- a/src/folio_data_import/BatchPoster.py +++ b/src/folio_data_import/BatchPoster.py @@ -9,6 +9,7 @@ import glob as glob_module import json import logging +import os import sys from io import TextIOWrapper from pathlib import Path @@ -283,7 +284,15 @@ def __init__( """ self.folio_client = folio_client self.config = config - self.reporter = reporter or NoOpProgressReporter() + # Create reporter from config if not provided + if reporter is None: + self.reporter = ( + NoOpProgressReporter() + if config.no_progress + else RichProgressReporter(show_speed=True, show_time=True) + ) + else: + self.reporter = reporter self.api_info = get_api_info(config.object_type) self.stats = BatchPosterStats() @@ -306,6 +315,10 @@ def __init__( if config.upsert and not self.api_info["supports_upsert"]: raise ValueError(f"Upsert is not supported for {config.object_type}") + self.semaphore = asyncio.Semaphore( + int(os.environ.get("FOLIO_MAX_CONCURRENT_REQUESTS", 10)) + ) # Limit concurrent requests + async def 
__aenter__(self): """Async context manager entry.""" # Open the file if we own it and it's not already open @@ -555,9 +568,10 @@ async def fetch_batch(batch_ids: List[str]) -> dict: query = f"id==({' OR '.join(batch_ids)})" params = {"query": query, "limit": fetch_batch_size} try: - return await self.folio_client.folio_get_async( - query_endpoint, key=object_name, query_params=params - ) + async with self.semaphore: + return await self.folio_client.folio_get_async( + query_endpoint, key=object_name, query_params=params + ) except folioclient.FolioClientError as e: logger.error(f"FOLIO client error fetching existing records: {e}") raise @@ -1412,16 +1426,7 @@ async def run_batch_poster( """ async with folio_client: try: - # Create progress reporter - reporter = ( - NoOpProgressReporter() - if config.no_progress - else RichProgressReporter(show_speed=True, show_time=True) - ) - - poster = BatchPoster( - folio_client, config, failed_records_file=failed_records_file, reporter=reporter - ) + poster = BatchPoster(folio_client, config, failed_records_file=failed_records_file) async with poster: await poster.do_work(files_to_process) diff --git a/src/folio_data_import/DILogRetriever.py b/src/folio_data_import/DILogRetriever.py new file mode 100644 index 0000000..4a95031 --- /dev/null +++ b/src/folio_data_import/DILogRetriever.py @@ -0,0 +1,259 @@ +import csv +import json +import logging +import sys +from pathlib import Path +from typing import Annotated, Dict, List, Tuple + +import cyclopts +import folioclient +import pymarc + +from folio_data_import import ( + __version__ as app_version, +) +from folio_data_import import ( + get_folio_connection_parameters, + set_up_cli_logging, +) +from folio_data_import._postgres import ( + POSTGRES_AVAILABLE, + POSTGRES_INSTALL_MESSAGE, + PostgresConfig, + RealDictCursor, + SSHTunnelConfig, + db_session, +) +from folio_data_import._progress import ( + NoOpProgressReporter, + ProgressReporter, + RichProgressReporter, +) + +logger = 
logging.getLogger(__name__) + + +class DILogRetriever: + def __init__( + self, + folio_client: folioclient.FolioClient, + db_config: PostgresConfig, + ssh_tunnel_config: SSHTunnelConfig, + progress_reporter: ProgressReporter | None = None, + ): + self.folio_client = folio_client + self.db_config = db_config + self.ssh_tunnel_config = ssh_tunnel_config + self.progress_reporter = ( + progress_reporter if progress_reporter is not None else NoOpProgressReporter() + ) + + def retrieve_errors_with_marc(self, job_ids: list[str]) -> List[Tuple[str, pymarc.Record]]: + error_logs = [] + get_di_errors = self.progress_reporter.start_task( + "get_di_error_logs_with_marc", + total=len(job_ids), + description="Retrieving DI error logs with MARC records", + ) + with ( + db_session( + db_config=self.db_config, ssh_tunnel_config=self.ssh_tunnel_config + ) as session, + ): + tenant = self.folio_client.tenant_id + for job_id in job_ids: + query = f""" + SELECT DISTINCT ON (jr.source_id) + jr.id, + jr.job_execution_id, + jr.source_id, + jr.error, + ir.incoming_record + FROM + "{tenant}_mod_source_record_manager".journal_records AS jr + INNER JOIN + "{tenant}_mod_source_record_manager".incoming_records AS ir + ON jr.source_id = ir.id + WHERE + jr.job_execution_id = %s + AND jr.error <> ''; + """ # noqa: S608 + cur = session.cursor(cursor_factory=RealDictCursor) + cur.execute(query, (job_id,)) + result = cur.fetchall() + for row in result: + if row: + try: + incoming_record = row.get("incoming_record") + if not incoming_record or "rawRecordContent" not in incoming_record: + logger.warning( + "Skipping record %s: missing rawRecordContent", + row.get("source_id", "unknown"), + ) + continue + marc_record = pymarc.record.Record( + incoming_record["rawRecordContent"].encode("utf-8"), + force_utf8=True, + ) + error_logs.append( + ( + json.dumps(row.get("error", "")), + marc_record, + ) + ) + except Exception as e: + logger.warning( + "Failed to parse MARC record for source_id %s: %s", + 
row.get("source_id", "unknown"), + str(e), + ) + cur.close() + self.progress_reporter.update_task(get_di_errors, advance=1) + self.progress_reporter.finish_task(get_di_errors) + return error_logs + + def generate_error_report_and_marc_file( + self, + error_logs: List[Tuple[Dict, pymarc.record.Record]], + report_file_path: str, + marc_file_path: str, + ): + with ( + open(report_file_path, "w", encoding="utf-8") as report_file, + open(marc_file_path, "wb") as marc_file, + ): + marc_writer = pymarc.MARCWriter(marc_file) + csv_writer = csv.writer( + report_file, delimiter="\t", quotechar="'", quoting=csv.QUOTE_ALL + ) + csv_writer.writerow(["Error Log", "MARC Record"]) + for error_log, marc_record in error_logs: + csv_writer.writerow([error_log, marc_record.as_marc().decode("utf-8")]) + marc_writer.write(marc_record) + marc_writer.close() + + +app = cyclopts.App( + version=app_version, +) + + +@app.default +def main( + folio_url: Annotated[ + str | None, + cyclopts.Parameter(env_var="FOLIO_URL", help="FOLIO Gateway URL"), + ] = None, + folio_tenant: Annotated[ + str | None, + cyclopts.Parameter(env_var="FOLIO_TENANT", help="FOLIO Tenant ID"), + ] = None, + folio_username: Annotated[ + str | None, + cyclopts.Parameter(env_var="FOLIO_USERNAME", help="FOLIO Username"), + ] = None, + folio_password: Annotated[ + str | None, + cyclopts.Parameter(env_var="FOLIO_PASSWORD", help="FOLIO Password"), + ] = None, + db_config: Annotated[ + Path | None, + cyclopts.Parameter(help="Path to the database configuration file (JSON format)"), + ] = None, + ssh_config: Annotated[ + Path | None, + cyclopts.Parameter(help="Path to the SSH tunnel configuration file (JSON format)"), + ] = None, + job_ids_file: Annotated[ + Path, + cyclopts.Parameter( + help="Path to a text file containing Data Import job execution IDs (one per line)" + ), + ] = Path("marc_import_job_ids.txt"), + report_file_path: Annotated[ + Path, + cyclopts.Parameter(help="Path to save the error report TSV file"), + ] = 
Path("di_error_report.tsv"), + marc_file_path: Annotated[ + Path, cyclopts.Parameter(help="Path to save the MARC records file") + ] = Path("di_error_records.mrc"), + no_progress: Annotated[ + bool, + cyclopts.Parameter(help="Disable progress reporting"), + ] = False, + debug: Annotated[ + bool, + cyclopts.Parameter(help="Enable debug logging"), + ] = False, +) -> None: + """Retrieve FOLIO Data Import error logs with MARC records and generate report files. + Requires PostgreSQL access. + + Args: + folio_url (str | None): FOLIO Gateway URL. + folio_tenant (str | None): FOLIO Tenant ID. + folio_username (str | None): FOLIO Username. + folio_password (str | None): FOLIO Password. + db_config (Path | None): Path to the database configuration file (JSON format). + ssh_config (Path | None): Path to the SSH tunnel configuration file (JSON format). + job_ids_file (Path): Path to a text file containing Data Import job execution IDs. + report_file_path (Path): Path to save the error report TSV file. + marc_file_path (Path): Path to save the MARC records file. + no_progress (bool): Disable progress reporting if True. + debug (bool): Enable debug logging if True. + """ + # Check for required PostgreSQL dependencies + if not POSTGRES_AVAILABLE: + print(f"Error: {POSTGRES_INSTALL_MESSAGE}", file=sys.stderr) + sys.exit(1) + + set_up_cli_logging(logger, log_file_prefix="di_log_retriever", debug=debug) + folio_url, folio_tenant, folio_username, folio_password = get_folio_connection_parameters( + folio_url, folio_tenant, folio_username, folio_password + ) + folio_client = folioclient.FolioClient( + gateway_url=folio_url, + tenant_id=folio_tenant, + username=folio_username, + password=folio_password, + ) + job_ids: List[str] = [] + with open(job_ids_file, "r", encoding="utf-8") as f: + job_ids = [line.strip() for line in f if line.strip()] + if job_ids: + if db_config is None: + print( + "Error: --db-config is required. 
Please provide a path to the database " + "configuration file (JSON format).", + file=sys.stderr, + ) + sys.exit(1) + with open(db_config, "r", encoding="utf-8") as f: + database_config = PostgresConfig.model_validate_json(f.read()) + if ssh_config: + with open(ssh_config, "r", encoding="utf-8") as f: + ssh_tunnel_config = SSHTunnelConfig.model_validate_json(f.read()) + else: + ssh_tunnel_config = SSHTunnelConfig(ssh_tunnel=False) + progress_reporter = ( + NoOpProgressReporter() if no_progress else RichProgressReporter(enabled=True) + ) + with progress_reporter: + retriever = DILogRetriever( + folio_client=folio_client, + db_config=database_config, + ssh_tunnel_config=ssh_tunnel_config, + progress_reporter=progress_reporter, + ) + error_logs = retriever.retrieve_errors_with_marc(job_ids=job_ids) + retriever.generate_error_report_and_marc_file( + error_logs=error_logs, + report_file_path=report_file_path, + marc_file_path=marc_file_path, + ) + else: + print("No job IDs found in the specified file.") + + +if __name__ == "__main__": + app() diff --git a/src/folio_data_import/MARCDataImport.py b/src/folio_data_import/MARCDataImport.py index 273ef90..e4af143 100644 --- a/src/folio_data_import/MARCDataImport.py +++ b/src/folio_data_import/MARCDataImport.py @@ -234,7 +234,15 @@ def __init__( ) -> None: self.folio_client: folioclient.FolioClient = folio_client self.config = config - self.reporter = reporter or NoOpProgressReporter() + # Create reporter from config if not provided + if reporter is None: + self.reporter = ( + NoOpProgressReporter() + if config.no_progress + else RichProgressReporter(show_speed=True, show_time=True) + ) + else: + self.reporter = reporter self.current_retry_timeout: float | None = None self.marc_record_preprocessor: MARCPreprocessor = MARCPreprocessor( config.marc_record_preprocessors or "", **(config.preprocessors_args or {}) @@ -1086,58 +1094,54 @@ def main( if member_tenant_id: folio_client.tenant_id = member_tenant_id - # Handle file path 
expansion - marc_files = collect_marc_file_paths(marc_file_paths) - - marc_files.sort() - - if len(marc_files) == 0: - logger.critical(f"No files found matching {marc_file_paths}. Exiting.") - sys.exit(1) - else: - logger.info(marc_files) - - if preprocessors_config: - with open(preprocessors_config, "r") as f: - preprocessor_args = json.load(f) - else: - preprocessor_args = {} - - if not import_profile_name: - import_profile_name = select_import_profile(folio_client) - - job = None - try: - if config_file: + if config_file: + # Load configuration from file + try: with open(config_file, "r") as f: config_data = json.load(f) config = MARCImportJob.Config(**config_data) + except Exception as e: + logger.critical(f"Failed to load configuration file {config_file}: {e}") + sys.exit(1) + else: + # Handle file path expansion for CLI invocation + marc_files = collect_marc_file_paths(marc_file_paths) + marc_files.sort() + + if len(marc_files) == 0: + logger.critical(f"No files found matching {marc_file_paths}. 
Exiting.") + sys.exit(1) else: - config = MARCImportJob.Config( - marc_files=marc_files, - import_profile_name=import_profile_name, - batch_size=batch_size, - batch_delay=batch_delay, - marc_record_preprocessors=preprocessors, - preprocessors_args=preprocessor_args, - no_progress=no_progress, - no_summary=no_summary, - let_summary_fail=let_summary_fail, - split_files=split_files, - split_size=split_size, - split_offset=split_offset, - job_ids_file_path=Path(job_ids_file_path) if job_ids_file_path else None, - show_file_names_in_data_import_logs=file_names_in_di_logs, - ) + logger.info(marc_files) - # Create progress reporter - reporter = ( - NoOpProgressReporter() - if no_progress - else RichProgressReporter(show_speed=True, show_time=True) + if preprocessors_config: + with open(preprocessors_config, "r") as f: + preprocessor_args = json.load(f) + else: + preprocessor_args = {} + + if not import_profile_name: + import_profile_name = select_import_profile(folio_client) + + config = MARCImportJob.Config( + marc_files=marc_files, + import_profile_name=import_profile_name, + batch_size=batch_size, + batch_delay=batch_delay, + marc_record_preprocessors=preprocessors, + preprocessors_args=preprocessor_args, + no_progress=no_progress, + no_summary=no_summary, + let_summary_fail=let_summary_fail, + split_files=split_files, + split_size=split_size, + split_offset=split_offset, + job_ids_file_path=Path(job_ids_file_path) if job_ids_file_path else None, + show_file_names_in_data_import_logs=file_names_in_di_logs, ) - job = MARCImportJob(folio_client, config, reporter) + try: + job = MARCImportJob(folio_client, config) asyncio.run(run_job(job)) except Exception as e: logger.error("Could not initialize MARCImportJob: " + str(e)) diff --git a/src/folio_data_import/UserImport.py b/src/folio_data_import/UserImport.py index fe955ee..4017c46 100644 --- a/src/folio_data_import/UserImport.py +++ b/src/folio_data_import/UserImport.py @@ -156,7 +156,15 @@ def __init__( ) -> None: 
self.config = config self.folio_client: folioclient.FolioClient = folio_client - self.reporter = reporter or NoOpProgressReporter() + # Create reporter from config if not provided + if reporter is None: + self.reporter = ( + NoOpProgressReporter() + if config.no_progress + else RichProgressReporter(show_speed=True, show_time=True) + ) + else: + self.reporter = reporter self.limit_simultaneous_requests = asyncio.Semaphore(config.limit_simultaneous_requests) # Build reference data maps (these need processing) self.patron_group_map: dict = self.build_ref_data_id_map( @@ -1161,16 +1169,47 @@ def main( if member_tenant_id: folio_client.tenant_id = member_tenant_id - if not library_name: - raise ValueError("library_name is required") + report_file_base_path = report_file_base_path or Path.cwd() + error_file_path = ( + report_file_base_path / f"failed_user_import_{dt.now(utc).strftime('%Y%m%d_%H%M%S')}.txt" + ) - if not user_file_paths: - raise ValueError( - "You must provide at least one user file path using --user-file-paths or " - "--user-file-path." 
+ config_data = {} + if config_file: + try: + with open(config_file, "r") as f: + config_data = json.load(f) + config = UserImporter.Config(**config_data) + except Exception as e: + logger.critical(f"Failed to load configuration file {config_file}: {e}") + sys.exit(1) + else: + # Expand any glob patterns in file paths + expanded_paths = pathify_user_file_paths(user_file_paths) + + # Convert to single Path or List[Path] for Config + file_paths_list = expanded_paths if len(expanded_paths) > 1 else expanded_paths[0] + + config = UserImporter.Config( + library_name=library_name, + batch_size=batch_size, + user_match_key=user_match_key, + only_update_present_fields=update_only_present_fields, + default_preferred_contact_type=default_preferred_contact_type, + fields_to_protect=protect_fields, + limit_simultaneous_requests=limit_async_requests, + user_file_paths=file_paths_list, + no_progress=no_progress, ) + try: + importer = UserImporter(folio_client, config) + asyncio.run(run_user_importer(importer, error_file_path)) + except Exception as ee: + logger.critical(f"An unknown error occurred: {ee}") + sys.exit(1) + - # Expand any glob patterns in file paths +def pathify_user_file_paths(user_file_paths): expanded_paths = [] for path_arg in user_file_paths: path_str = str(path_arg) @@ -1185,45 +1224,7 @@ def main( expanded_paths.append(path_arg) else: expanded_paths.append(path_arg) - - # Convert to single Path or List[Path] for Config - file_paths_list = expanded_paths if len(expanded_paths) > 1 else expanded_paths[0] - - report_file_base_path = report_file_base_path or Path.cwd() - error_file_path = ( - report_file_base_path / f"failed_user_import_{dt.now(utc).strftime('%Y%m%d_%H%M%S')}.txt" - ) - try: - # Create UserImporter.Config object - if config_file: - with open(config_file, "r") as f: - config_data = json.load(f) - config = UserImporter.Config(**config_data) - else: - config = UserImporter.Config( - library_name=library_name, - batch_size=batch_size, - 
user_match_key=user_match_key, - only_update_present_fields=update_only_present_fields, - default_preferred_contact_type=default_preferred_contact_type, - fields_to_protect=protect_fields, - limit_simultaneous_requests=limit_async_requests, - user_file_paths=file_paths_list, - no_progress=no_progress, - ) - - # Create progress reporter - reporter = ( - NoOpProgressReporter() - if no_progress - else RichProgressReporter(show_speed=True, show_time=True) - ) - - importer = UserImporter(folio_client, config, reporter) - asyncio.run(run_user_importer(importer, error_file_path)) - except Exception as ee: - logger.critical(f"An unknown error occurred: {ee}") - sys.exit(1) + return expanded_paths async def run_user_importer(importer: UserImporter, error_file_path: Path): diff --git a/src/folio_data_import/__main__.py b/src/folio_data_import/__main__.py index b91637f..93db4c3 100644 --- a/src/folio_data_import/__main__.py +++ b/src/folio_data_import/__main__.py @@ -9,6 +9,8 @@ app.command("folio_data_import.MARCDataImport:main", name="marc") app.command("folio_data_import.UserImport:main", name="users") app.command("folio_data_import.BatchPoster:main", name="batch-poster") +app.command("folio_data_import.DILogRetriever:main", name="get-di-logs") + if __name__ == "__main__": app() diff --git a/src/folio_data_import/_postgres.py b/src/folio_data_import/_postgres.py new file mode 100644 index 0000000..7ce922b --- /dev/null +++ b/src/folio_data_import/_postgres.py @@ -0,0 +1,174 @@ +import logging +import socket +import subprocess +import time +from contextlib import contextmanager +from typing import Any, Iterator, Optional, Protocol + +from pydantic import BaseModel + +try: + import psycopg2 + from psycopg2.extras import RealDictCursor + + POSTGRES_AVAILABLE = True +except ImportError: + psycopg2 = None # type: ignore[assignment] + RealDictCursor = None # type: ignore[assignment] + POSTGRES_AVAILABLE = False + + +logger = logging.getLogger(__name__) + + 
+POSTGRES_INSTALL_MESSAGE = ( + "PostgreSQL support requires the 'postgres' optional dependencies.\n" + "Install with: pip install 'folio_data_import[postgres]'\n" + " or: uv add 'folio_data_import[postgres]'" +) + + +def require_postgres() -> None: + """Raise ImportError with helpful message if psycopg2 is not available.""" + if not POSTGRES_AVAILABLE: + raise ImportError(POSTGRES_INSTALL_MESSAGE) + + +class DatabaseCursor(Protocol): + """Protocol describing the database cursor interface we use.""" + + def execute(self, query: str, vars: Any = None) -> None: ... + def fetchall(self) -> list[dict[str, Any]]: ... + def fetchone(self) -> Optional[dict[str, Any]]: ... + def close(self) -> None: ... + + +class DatabaseConnection(Protocol): + """Protocol describing the database connection interface we use.""" + + def cursor(self, cursor_factory: Any = None) -> DatabaseCursor: ... + def commit(self) -> None: ... + def rollback(self) -> None: ... + def close(self) -> None: ... + + +class PostgresConfig(BaseModel): + host: str + port: int = 5432 + database: str + user: str + password: Optional[str] = None + + +class SSHTunnelConfig(BaseModel): + ssh_path: Optional[str] = "ssh" + ssh_tunnel: bool = False + use_ssh_config: bool = False + ssh_host: Optional[str] = None + ssh_user: Optional[str] = None + ssh_private_key_path: Optional[str] = None + + +def connect_postgres(cfg): + require_postgres() + return psycopg2.connect( + host=cfg.host, + port=cfg.port, + dbname=cfg.database, + user=cfg.user, + password=cfg.password, + connect_timeout=5, + gssencmode="disable", + ) + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + try: + with socket.create_connection((host, port), timeout=0.2): + return + except OSError: + time.sleep(0.05) + raise TimeoutError(f"SSH tunnel did not 
open port {port}") + + +@contextmanager +def ssh_tunnel( + *, + ssh_path: Optional[str], + ssh_host: str, + remote_host: str, + remote_port: int, +) -> Iterator[int]: + local_port = _free_port() + + proc = subprocess.Popen( + [ + ssh_path or "ssh", + "-N", + "-L", + f"{local_port}:{remote_host}:{remote_port}", + "-o", + "ExitOnForwardFailure=yes", + ssh_host, + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True, + ) + + try: + _wait_for_port("127.0.0.1", local_port) + yield local_port + finally: + proc.terminate() + proc.wait() + + +@contextmanager +def db_session( + *, + db_config: PostgresConfig, + ssh_tunnel_config: Optional[SSHTunnelConfig] = None, +) -> Iterator[DatabaseConnection]: + conn: Optional[DatabaseConnection] = None + + try: + if ssh_tunnel_config and ssh_tunnel_config.ssh_tunnel: + if not ssh_tunnel_config.ssh_host: + raise ValueError("ssh_host is required when ssh_tunnel is enabled") + + with ssh_tunnel( + ssh_path=ssh_tunnel_config.ssh_path, + ssh_host=ssh_tunnel_config.ssh_host, + remote_host=db_config.host, + remote_port=db_config.port, + ) as local_port: + pg_cfg = db_config.model_copy() + pg_cfg.host = "127.0.0.1" + pg_cfg.port = local_port + logger.info("Tunnel listening on port %s", local_port) + + conn = connect_postgres(pg_cfg) + yield conn + conn.commit() + else: + conn = connect_postgres(db_config) + yield conn + conn.commit() + + except Exception: + if conn: + conn.rollback() + raise + + finally: + if conn: + conn.close() diff --git a/src/folio_data_import/marc_preprocessors/_preprocessors.py b/src/folio_data_import/marc_preprocessors/_preprocessors.py index bb0663f..2cd6b8c 100644 --- a/src/folio_data_import/marc_preprocessors/_preprocessors.py +++ b/src/folio_data_import/marc_preprocessors/_preprocessors.py @@ -521,6 +521,31 @@ def mark_deleted(record: Record, **kwargs) -> Record: return record +def remove_non_numeric_fields(record: Record, **kwargs) -> Record: + """ + Remove all fields from the record that have 
non-numeric tags (not matching pattern 001-999). +    Also removes field 000, which is invalid in MARC records. + +    Args: +        record (Record): The MARC record to preprocess. + +    Returns: +        Record: The preprocessed MARC record. +    """ +    for field in list(record.get_fields()): +        if not re.fullmatch(r"\d{3}", field.tag) or field.tag == "000": +            reason = "invalid tag 000" if field.tag == "000" else f"non-numeric tag {field.tag}" +            logger.log( +                26, +                "DATA ISSUE\t%s\t%s\t%s", +                record["001"].value() if record["001"] else "unknown", +                f"Field removed: {reason}", +                field, +            ) +            record.remove_field(field) +    return record + + def ordinal(n: int) -> str: s = ("th", "st", "nd", "rd") + ("th",) * 10 v = n % 100 diff --git a/tests/test_di_log_retriever.py b/tests/test_di_log_retriever.py new file mode 100644 index 0000000..093f77d --- /dev/null +++ b/tests/test_di_log_retriever.py @@ -0,0 +1,444 @@ +"""Tests for the DILogRetriever module.""" + +import csv +import json +import pytest +from io import BytesIO, StringIO +from unittest.mock import MagicMock, Mock, patch, mock_open + +import pymarc + +from folio_data_import.DILogRetriever import DILogRetriever +from folio_data_import._postgres import PostgresConfig, SSHTunnelConfig +from folio_data_import._progress import NoOpProgressReporter + + +class TestDILogRetriever: +    """Tests for DILogRetriever class.""" + +    @pytest.fixture +    def mock_folio_client(self): +        """Create a mock FolioClient.""" +        client = MagicMock() +        client.tenant_id = "test_tenant" +        return client + +    @pytest.fixture +    def db_config(self): +        """Create a test database configuration.""" +        return PostgresConfig( +            host="localhost", +            port=5432, +            database="folio", +            user="folio", +            password="test_password", +        ) + +    @pytest.fixture +    def ssh_tunnel_config(self): +        """Create a test SSH tunnel configuration (disabled).""" +        return SSHTunnelConfig(ssh_tunnel=False) + +    @pytest.fixture +    def retriever(self, mock_folio_client, db_config, ssh_tunnel_config): +        """Create a DILogRetriever instance for testing.""" + 
return DILogRetriever( + folio_client=mock_folio_client, + db_config=db_config, + ssh_tunnel_config=ssh_tunnel_config, + progress_reporter=NoOpProgressReporter(), + ) + + def test_init_with_defaults(self, mock_folio_client, db_config, ssh_tunnel_config): + """Test DILogRetriever initialization with default progress reporter.""" + retriever = DILogRetriever( + folio_client=mock_folio_client, + db_config=db_config, + ssh_tunnel_config=ssh_tunnel_config, + ) + assert retriever.folio_client == mock_folio_client + assert retriever.db_config == db_config + assert retriever.ssh_tunnel_config == ssh_tunnel_config + assert retriever.progress_reporter is not None + + def test_init_with_custom_progress_reporter( + self, mock_folio_client, db_config, ssh_tunnel_config + ): + """Test DILogRetriever initialization with custom progress reporter.""" + progress_reporter = NoOpProgressReporter() + retriever = DILogRetriever( + folio_client=mock_folio_client, + db_config=db_config, + ssh_tunnel_config=ssh_tunnel_config, + progress_reporter=progress_reporter, + ) + assert retriever.progress_reporter == progress_reporter + + +class TestRetrieveErrorsWithMarc: + """Tests for retrieve_errors_with_marc method.""" + + @pytest.fixture + def mock_folio_client(self): + """Create a mock FolioClient.""" + client = MagicMock() + client.tenant_id = "test_tenant" + return client + + @pytest.fixture + def db_config(self): + """Create a test database configuration.""" + return PostgresConfig( + host="localhost", + port=5432, + database="folio", + user="folio", + password="test_password", + ) + + @pytest.fixture + def ssh_tunnel_config(self): + """Create a test SSH tunnel configuration (disabled).""" + return SSHTunnelConfig(ssh_tunnel=False) + + @pytest.fixture + def retriever(self, mock_folio_client, db_config, ssh_tunnel_config): + """Create a DILogRetriever instance for testing.""" + return DILogRetriever( + folio_client=mock_folio_client, + db_config=db_config, + 
ssh_tunnel_config=ssh_tunnel_config, + progress_reporter=NoOpProgressReporter(), + ) + + @pytest.fixture + def sample_marc_record(self): + """Create a sample MARC record for testing.""" + record = pymarc.Record() + record.add_field( + pymarc.Field( + tag="245", + indicators=["0", "0"], + subfields=[pymarc.Subfield(code="a", value="Test Title")], + ) + ) + return record + + @patch("folio_data_import.DILogRetriever.db_session") + def test_retrieve_errors_with_marc_success( + self, mock_db_session, retriever, sample_marc_record + ): + """Test successful retrieval of error logs with MARC records.""" + # Create mock cursor and session + mock_cursor = MagicMock() + mock_session = MagicMock() + mock_session.cursor.return_value = mock_cursor + + # Create sample row data using dict format (RealDictCursor) + raw_marc = sample_marc_record.as_marc().decode("utf-8") + mock_cursor.fetchall.return_value = [ + { + "id": "record-id-1", + "job_execution_id": "job-1", + "source_id": "source-1", + "error": "Some error message", + "incoming_record": {"rawRecordContent": raw_marc}, + } + ] + + mock_db_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_db_session.return_value.__exit__ = Mock(return_value=False) + + result = retriever.retrieve_errors_with_marc(["job-1"]) + + assert len(result) == 1 + error_log, marc_record = result[0] + assert json.loads(error_log) == "Some error message" + assert isinstance(marc_record, pymarc.Record) + + @patch("folio_data_import.DILogRetriever.db_session") + def test_retrieve_errors_with_marc_empty_results(self, mock_db_session, retriever): + """Test retrieval when no error records are found.""" + mock_cursor = MagicMock() + mock_session = MagicMock() + mock_session.cursor.return_value = mock_cursor + mock_cursor.fetchall.return_value = [] + + mock_db_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_db_session.return_value.__exit__ = Mock(return_value=False) + + result = 
retriever.retrieve_errors_with_marc(["job-1"]) + + assert len(result) == 0 + + @patch("folio_data_import.DILogRetriever.db_session") + def test_retrieve_errors_with_marc_missing_raw_content(self, mock_db_session, retriever): + """Test handling of records with missing rawRecordContent.""" + mock_cursor = MagicMock() + mock_session = MagicMock() + mock_session.cursor.return_value = mock_cursor + + # Row with missing rawRecordContent + mock_cursor.fetchall.return_value = [ + { + "id": "record-id-1", + "job_execution_id": "job-1", + "source_id": "source-1", + "error": "Some error message", + "incoming_record": {}, # Missing rawRecordContent + } + ] + + mock_db_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_db_session.return_value.__exit__ = Mock(return_value=False) + + result = retriever.retrieve_errors_with_marc(["job-1"]) + + # Should skip the record and return empty list + assert len(result) == 0 + + @patch("folio_data_import.DILogRetriever.db_session") + def test_retrieve_errors_with_marc_null_incoming_record(self, mock_db_session, retriever): + """Test handling of records with null incoming_record.""" + mock_cursor = MagicMock() + mock_session = MagicMock() + mock_session.cursor.return_value = mock_cursor + + mock_cursor.fetchall.return_value = [ + { + "id": "record-id-1", + "job_execution_id": "job-1", + "source_id": "source-1", + "error": "Some error message", + "incoming_record": None, + } + ] + + mock_db_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_db_session.return_value.__exit__ = Mock(return_value=False) + + result = retriever.retrieve_errors_with_marc(["job-1"]) + + # Should skip the record and return empty list + assert len(result) == 0 + + @patch("folio_data_import.DILogRetriever.db_session") + def test_retrieve_errors_with_marc_malformed_marc(self, mock_db_session, retriever): + """Test handling of malformed MARC data.""" + mock_cursor = MagicMock() + mock_session = MagicMock() + 
mock_session.cursor.return_value = mock_cursor + + mock_cursor.fetchall.return_value = [ + { + "id": "record-id-1", + "job_execution_id": "job-1", + "source_id": "source-1", + "error": "Some error message", + "incoming_record": {"rawRecordContent": "not valid marc data"}, + } + ] + + mock_db_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_db_session.return_value.__exit__ = Mock(return_value=False) + + # Should not raise, should log warning and skip + result = retriever.retrieve_errors_with_marc(["job-1"]) + + # May or may not parse depending on pymarc's tolerance + # The important thing is it doesn't raise an exception + assert isinstance(result, list) + + @patch("folio_data_import.DILogRetriever.db_session") + def test_retrieve_errors_multiple_jobs(self, mock_db_session, retriever, sample_marc_record): + """Test retrieval across multiple job IDs.""" + mock_cursor = MagicMock() + mock_session = MagicMock() + mock_session.cursor.return_value = mock_cursor + + raw_marc = sample_marc_record.as_marc().decode("utf-8") + + # Return different results for each call + mock_cursor.fetchall.side_effect = [ + [ + { + "id": "record-1", + "job_execution_id": "job-1", + "source_id": "source-1", + "error": "Error 1", + "incoming_record": {"rawRecordContent": raw_marc}, + } + ], + [ + { + "id": "record-2", + "job_execution_id": "job-2", + "source_id": "source-2", + "error": "Error 2", + "incoming_record": {"rawRecordContent": raw_marc}, + } + ], + ] + + mock_db_session.return_value.__enter__ = Mock(return_value=mock_session) + mock_db_session.return_value.__exit__ = Mock(return_value=False) + + result = retriever.retrieve_errors_with_marc(["job-1", "job-2"]) + + assert len(result) == 2 + + +class TestGenerateErrorReportAndMarcFile: + """Tests for generate_error_report_and_marc_file method.""" + + @pytest.fixture + def mock_folio_client(self): + """Create a mock FolioClient.""" + client = MagicMock() + client.tenant_id = "test_tenant" + return client + + 
@pytest.fixture + def db_config(self): + """Create a test database configuration.""" + return PostgresConfig( + host="localhost", + port=5432, + database="folio", + user="folio", + password="test_password", + ) + + @pytest.fixture + def ssh_tunnel_config(self): + """Create a test SSH tunnel configuration (disabled).""" + return SSHTunnelConfig(ssh_tunnel=False) + + @pytest.fixture + def retriever(self, mock_folio_client, db_config, ssh_tunnel_config): + """Create a DILogRetriever instance for testing.""" + return DILogRetriever( + folio_client=mock_folio_client, + db_config=db_config, + ssh_tunnel_config=ssh_tunnel_config, + progress_reporter=NoOpProgressReporter(), + ) + + @pytest.fixture + def sample_marc_record(self): + """Create a sample MARC record for testing.""" + record = pymarc.Record() + record.add_field( + pymarc.Field( + tag="245", + indicators=["0", "0"], + subfields=[pymarc.Subfield(code="a", value="Test Title")], + ) + ) + return record + + def test_generate_error_report_and_marc_file(self, retriever, sample_marc_record, tmp_path): + """Test generation of error report TSV and MARC file.""" + error_logs = [ + (json.dumps("Error message 1"), sample_marc_record), + (json.dumps("Error message 2"), sample_marc_record), + ] + + report_path = tmp_path / "report.tsv" + marc_path = tmp_path / "records.mrc" + + retriever.generate_error_report_and_marc_file( + error_logs=error_logs, + report_file_path=str(report_path), + marc_file_path=str(marc_path), + ) + + # Verify report file was created + assert report_path.exists() + with open(report_path, "r", encoding="utf-8") as f: + reader = csv.reader(f, delimiter="\t", quotechar="'") + rows = list(reader) + assert len(rows) == 3 # Header + 2 data rows + assert rows[0] == ["Error Log", "MARC Record"] + + # Verify MARC file was created + assert marc_path.exists() + with open(marc_path, "rb") as f: + reader = pymarc.MARCReader(f) + records = list(reader) + assert len(records) == 2 + + def 
test_generate_error_report_empty_logs(self, retriever, tmp_path): + """Test generation with empty error logs.""" + report_path = tmp_path / "report.tsv" + marc_path = tmp_path / "records.mrc" + + retriever.generate_error_report_and_marc_file( + error_logs=[], + report_file_path=str(report_path), + marc_file_path=str(marc_path), + ) + + # Verify report file was created with only header + assert report_path.exists() + with open(report_path, "r", encoding="utf-8") as f: + reader = csv.reader(f, delimiter="\t", quotechar="'") + rows = list(reader) + assert len(rows) == 1 # Only header + assert rows[0] == ["Error Log", "MARC Record"] + + +class TestPostgresConfig: + """Tests for PostgresConfig model.""" + + def test_postgres_config_defaults(self): + """Test PostgresConfig with default values.""" + config = PostgresConfig( + host="localhost", + database="folio", + user="folio", + ) + assert config.port == 5432 + assert config.password is None + + def test_postgres_config_full(self): + """Test PostgresConfig with all values.""" + config = PostgresConfig( + host="db.example.com", + port=5433, + database="folio_db", + user="admin", + password="secret", + ) + assert config.host == "db.example.com" + assert config.port == 5433 + assert config.database == "folio_db" + assert config.user == "admin" + assert config.password == "secret" + + +class TestSSHTunnelConfig: + """Tests for SSHTunnelConfig model.""" + + def test_ssh_tunnel_config_defaults(self): + """Test SSHTunnelConfig with default values.""" + config = SSHTunnelConfig() + assert config.ssh_path == "ssh" + assert config.ssh_tunnel is False + assert config.use_ssh_config is False + assert config.ssh_host is None + assert config.ssh_user is None + assert config.ssh_private_key_path is None + + def test_ssh_tunnel_config_enabled(self): + """Test SSHTunnelConfig with tunnel enabled.""" + config = SSHTunnelConfig( + ssh_tunnel=True, + ssh_host="bastion.example.com", + ssh_user="tunnel_user", + 
ssh_private_key_path="~/.ssh/id_rsa", + ) + assert config.ssh_tunnel is True + assert config.ssh_host == "bastion.example.com" + assert config.ssh_user == "tunnel_user" + assert config.ssh_private_key_path == "~/.ssh/id_rsa" diff --git a/tests/test_marc_data_import.py b/tests/test_marc_data_import.py index 5af2972..b217256 100644 --- a/tests/test_marc_data_import.py +++ b/tests/test_marc_data_import.py @@ -115,7 +115,8 @@ async def test_wrap_up_removes_empty_files(self, tmp_path, folio_client): marc_record_preprocessors="", preprocessors_args={}, job_ids_file_path=str(job_ids_file), - marc_files=[tmp_path / "test.mrc"] + marc_files=[tmp_path / "test.mrc"], + no_progress=True, ) with patch('folio_data_import.MARCDataImport.logger') as mock_logger: @@ -164,7 +165,8 @@ async def test_wrap_up_keeps_non_empty_files(self, tmp_path, folio_client): marc_record_preprocessors="", preprocessors_args={}, job_ids_file_path=str(job_ids_file), - marc_files=[tmp_path / "test.mrc"] + marc_files=[tmp_path / "test.mrc"], + no_progress=True, ) with patch('folio_data_import.MARCDataImport.logger'): @@ -276,7 +278,8 @@ async def test_wrap_up_removes_empty_job_ids_file(self, tmp_path, folio_client): marc_record_preprocessors="", preprocessors_args={}, job_ids_file_path=str(job_ids_file), - marc_files=[tmp_path / "test.mrc"] + marc_files=[tmp_path / "test.mrc"], + no_progress=True, ) with patch('folio_data_import.MARCDataImport.logger') as mock_logger: