diff --git a/pyproject.toml b/pyproject.toml
index 12391a7..a66e54e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,8 @@ dependencies = [
     'dask',
     'pydantic>2.9',
     'pydantic-settings>=2.8',
-    'boto3'
+    'boto3',
+    'requests'
 ]
 
 [project.optional-dependencies]
diff --git a/src/aind_data_upload_utils/trigger_co_cleanup_notification.py b/src/aind_data_upload_utils/trigger_co_cleanup_notification.py
new file mode 100644
index 0000000..3ab2b9e
--- /dev/null
+++ b/src/aind_data_upload_utils/trigger_co_cleanup_notification.py
@@ -0,0 +1,323 @@
+"""
+Job to parse CSV data and send webhook notifications.
+"""
+
+import argparse
+import csv
+import logging
+import os
+import sys
+from collections import defaultdict
+from io import StringIO
+from pathlib import Path
+from typing import Dict, List, Set, Union
+
+import boto3
+import requests
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING")
+logging.basicConfig(level=LOG_LEVEL)
+
+
+class JobSettings(BaseSettings):
+    """Job settings for WebhookNotificationJob"""
+
+    csv_file: Union[Path, str] = Field(
+        ..., description="Path to the CSV file to parse (local/S3)."
+    )
+    exclude_list_file: Union[Path, str] = Field(
+        ...,
+        description=(
+            "Path to the plain text file containing excluded "
+            "usernames or capsule URLs (one per line, local/S3)."
+        ),
+    )
+    webhook_url: str = Field(
+        ..., description="Webhook URL to send notifications to."
+    )
+    verify_ssl: bool = Field(
+        default=True,
+        description=(
+            "Whether to verify the webhook's TLS certificate. Only "
+            "disable for trusted endpoints with self-signed certificates."
+        ),
+    )
+
+
+class WebhookNotificationJob:
+    """Job to parse CSV data and send webhook notifications."""
+
+    def __init__(self, job_settings: JobSettings):
+        """
+        Class constructor for WebhookNotificationJob.
+
+        Parameters
+        ----------
+        job_settings: JobSettings
+        """
+        self.job_settings = job_settings
+
+    def _is_s3_uri(self, path: Union[Path, str]) -> bool:
+        """
+        Check if the given path is an S3 URI.
+
+        Parameters
+        ----------
+        path: Union[Path, str]
+            Path to check.
+
+        Returns
+        -------
+        bool
+            True if path is an S3 URI, False otherwise.
+        """
+        return str(path).startswith("s3://")
+
+    def _parse_s3_uri(self, s3_uri: str) -> tuple[str, str]:
+        """
+        Parse S3 URI into bucket and key.
+
+        Parameters
+        ----------
+        s3_uri: str
+            S3 URI in format s3://bucket/key.
+
+        Returns
+        -------
+        tuple[str, str]
+            Tuple of (bucket, key).
+
+        Raises
+        ------
+        ValueError
+            If the URI has no bucket or no key component.
+        """
+        bucket, _, key = s3_uri.removeprefix("s3://").partition("/")
+        if not bucket or not key:
+            raise ValueError(
+                f"Invalid S3 URI, expected s3://bucket/key: {s3_uri}"
+            )
+        return bucket, key
+
+    def _read_text(self, path: Union[Path, str], description: str) -> str:
+        """
+        Read a whole text file from the local filesystem or from S3.
+
+        Parameters
+        ----------
+        path: Union[Path, str]
+            Local path or S3 URI (s3://bucket/key) of the file.
+        description: str
+            Short label for the file, used only in debug log messages.
+
+        Returns
+        -------
+        str
+            Decoded contents of the file.
+        """
+        if self._is_s3_uri(path):
+            bucket, key = self._parse_s3_uri(str(path))
+            s3_client = boto3.client("s3")
+            try:
+                response = s3_client.get_object(Bucket=bucket, Key=key)
+                content = response["Body"].read().decode("utf-8-sig")
+            finally:
+                # Release the client even if the download or decode fails.
+                s3_client.close()
+            logging.debug(f"Read {description} from S3: s3://{bucket}/{key}")
+            return content
+
+        local_path = Path(path)
+        # utf-8-sig also reads plain UTF-8 and drops a leading BOM, which
+        # would otherwise end up inside the first header or exclude entry.
+        with open(local_path, "r", encoding="utf-8-sig") as f:
+            content = f.read()
+        logging.debug(f"Read {description} from local file: {local_path}")
+        return content
+
+    def read_exclude_list(self) -> Set[str]:
+        """
+        Reads the exclude list file and returns a set of items to exclude.
+
+        Returns
+        -------
+        Set[str]
+            Set of usernames or capsule URLs to exclude.
+        """
+        exclude_content = self._read_text(
+            self.job_settings.exclude_list_file, "exclude list"
+        )
+        # Blank lines and surrounding whitespace are ignored.
+        exclude_items = {
+            item.strip()
+            for item in exclude_content.splitlines()
+            if item.strip()
+        }
+
+        logging.debug(f"Exclude items: {exclude_items}")
+        return exclude_items
+
+    def read_csv_file(self) -> List[Dict[str, str]]:
+        """
+        Reads the CSV file and returns all rows as a list of dictionaries.
+
+        Returns
+        -------
+        List[Dict[str, str]]
+            List of dictionaries representing CSV rows.
+        """
+        csv_content = self._read_text(self.job_settings.csv_file, "CSV")
+        csv_reader = csv.DictReader(StringIO(csv_content))
+        csv_data = [dict(row) for row in csv_reader]
+
+        logging.debug(f"Read {len(csv_data)} rows from CSV file")
+        return csv_data
+
+    def filter_csv_data(
+        self, csv_data: List[Dict[str, str]], exclude_items: Set[str]
+    ) -> List[Dict[str, str]]:
+        """
+        Filters CSV data by excluding specified usernames or capsule URLs.
+
+        Parameters
+        ----------
+        csv_data: List[Dict[str, str]]
+            List of dictionaries representing CSV rows.
+        exclude_items: Set[str]
+            Set of usernames or capsule URLs to exclude.
+
+        Returns
+        -------
+        List[Dict[str, str]]
+            Filtered list of dictionaries.
+        """
+        filtered_data = []
+
+        for row_number, row in enumerate(csv_data, start=1):
+            user_email = row["user_email"]
+            capsule_url = row["capsule_url"]
+
+            if user_email in exclude_items or capsule_url in exclude_items:
+                logging.info(
+                    f"Excluding row {row_number}: {user_email} - "
+                    f"{capsule_url}"
+                )
+                continue
+
+            filtered_data.append(row)
+
+        logging.debug(f"Filtered data: {len(filtered_data)} rows remaining")
+        return filtered_data
+
+    def group_by_user(
+        self, filtered_data: List[Dict[str, str]]
+    ) -> Dict[str, List[Dict[str, str]]]:
+        """
+        Groups filtered CSV data by user email.
+
+        Parameters
+        ----------
+        filtered_data: List[Dict[str, str]]
+            Filtered list of dictionaries representing CSV rows.
+
+        Returns
+        -------
+        Dict[str, List[Dict[str, str]]]
+            Dictionary with user emails as keys and lists of capsule data.
+        """
+        user_data = defaultdict(list)
+
+        for row in filtered_data:
+            user_email = row["user_email"]
+            capsule_data = {"capsule_url": row["capsule_url"]}
+            user_data[user_email].append(capsule_data)
+
+        logging.debug(f"Grouped data for {len(user_data)} users")
+        return dict(user_data)
+
+    def send_webhook_notifications(
+        self, user_data: Dict[str, List[Dict[str, str]]]
+    ) -> None:
+        """
+        Sends POST requests to the webhook endpoint.
+
+        Parameters
+        ----------
+        user_data: Dict[str, List[Dict[str, str]]]
+            Dictionary with user emails as keys and lists of capsule data.
+
+        Raises
+        ------
+        requests.exceptions.RequestException
+            Re-raised after logging on the first failed request; users
+            later in the iteration order are then not notified.
+        """
+        webhook_url = self.job_settings.webhook_url
+        verify_ssl = self.job_settings.verify_ssl
+
+        for user_email, capsules in user_data.items():
+            # NOTE(review): the reviewed diff shows only a line break after
+            # each URL; any HTML markup in these literals may have been lost
+            # in transit - confirm against the original commit.
+            table_rows = "".join(
+                f"{capsule['capsule_url']}\n" for capsule in capsules
+            )
+            html_table = f"{table_rows}"
+            payload = {"user_email": user_email, "capsule_urls": html_table}
+
+            try:
+                response = requests.post(
+                    webhook_url,
+                    json=payload,
+                    headers={"Content-Type": "application/json"},
+                    verify=verify_ssl,
+                    timeout=30,
+                )
+                response.raise_for_status()
+                logging.info(
+                    f"Successfully sent notification for {user_email}"
+                )
+            except requests.exceptions.RequestException as e:
+                logging.error(
+                    f"Failed to send notification for {user_email}: {e}"
+                )
+                raise
+
+    def run_job(self) -> None:
+        """Main job runner."""
+        logging.info("Starting webhook notification job")
+
+        exclude_items = self.read_exclude_list()
+        csv_data = self.read_csv_file()
+        filtered_data = self.filter_csv_data(csv_data, exclude_items)
+        user_data = self.group_by_user(filtered_data)
+        self.send_webhook_notifications(user_data)
+        logging.info("Webhook notification job completed")
+
+
+if __name__ == "__main__":
+    sys_args = sys.argv[1:]
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-j",
+        "--job-settings",
+        required=False,
+        type=str,
+        help=(
+            "Instead of init args the job settings can optionally be passed "
+            "as a json string in the command line."
+        ),
+    )
+    cli_args = parser.parse_args(sys_args)
+    if cli_args.job_settings is not None:
+        main_job_settings = JobSettings.model_validate_json(
+            cli_args.job_settings
+        )
+    else:
+        # No JSON given: let pydantic-settings build the settings from the
+        # environment instead of passing None to the JSON validator.
+        main_job_settings = JobSettings()
+    main_job = WebhookNotificationJob(job_settings=main_job_settings)
+    main_job.run_job()
diff --git a/tests/resources/example_capsules.csv b/tests/resources/example_capsules.csv
new file mode 100644
index 0000000..5d9cc43
--- /dev/null
+++ b/tests/resources/example_capsules.csv
@@ -0,0 +1,5 @@
+user_email,capsule_url
+user1@example.com,https://codeocean.com/capsule/12345
+user2@example.com,https://codeocean.com/capsule/23456
+user1@example.com,https://codeocean.com/capsule/34567
+user3@example.com,https://codeocean.com/capsule/45678
diff --git a/tests/resources/exclude_list.txt b/tests/resources/exclude_list.txt
new file mode 100644
index 0000000..b91208b
--- /dev/null
+++ b/tests/resources/exclude_list.txt
@@ -0,0 +1,2 @@
+user2@example.com
+https://codeocean.com/capsule/12345
\ No newline at end of file
diff --git a/tests/test_trigger_co_cleanup_notification.py b/tests/test_trigger_co_cleanup_notification.py
new file mode 100644
index 0000000..7dfb48c
--- /dev/null
+++ b/tests/test_trigger_co_cleanup_notification.py
@@ -0,0 +1,251 @@
+"""Tests trigger_co_cleanup_notification module"""
+import os
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import requests
+
+from aind_data_upload_utils.trigger_co_cleanup_notification import (
+    JobSettings,
+    WebhookNotificationJob,
+)
+
+RESOURCES_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / "resources"
+CSV_FILE = RESOURCES_DIR / "example_capsules.csv"
+EXCLUDE_FILE = RESOURCES_DIR / "exclude_list.txt"
+
+
+class TestWebhookNotificationJob(unittest.TestCase):
+    """Test class for WebhookNotificationJob."""
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        """Sets up job settings for all tests."""
+        cls.job_settings = JobSettings(
+            csv_file=CSV_FILE,
+            exclude_list_file=EXCLUDE_FILE,
+            webhook_url="https://webhook.site/test",
+        )
+        cls.example_job = WebhookNotificationJob(job_settings=cls.job_settings)
+
+    def test_job_settings_properties(self):
+        """Tests JobSettings properties."""
+        self.assertEqual(self.job_settings.csv_file, CSV_FILE)
+        self.assertEqual(self.job_settings.exclude_list_file, EXCLUDE_FILE)
+        self.assertEqual(
+            self.job_settings.webhook_url, "https://webhook.site/test"
+        )
+
+    def test_s3_uri_methods(self):
+        """Tests S3 URI detection and parsing methods."""
+        self.assertTrue(self.example_job._is_s3_uri("s3://bucket/key"))
+        self.assertTrue(
+            self.example_job._is_s3_uri("s3://bucket/folder/file.csv")
+        )
+        self.assertFalse(
+            self.example_job._is_s3_uri("/local/path/file.csv")
+        )
+        self.assertFalse(self.example_job._is_s3_uri("file.csv"))
+        # Test _parse_s3_uri method
+        bucket, key = self.example_job._parse_s3_uri("s3://my-bucket/file.csv")
+        self.assertEqual(bucket, "my-bucket")
+        self.assertEqual(key, "file.csv")
+        bucket, key = self.example_job._parse_s3_uri(
+            "s3://aind-devops-dev/co_capsule_cleanup/list.csv"
+        )
+        self.assertEqual(bucket, "aind-devops-dev")
+        self.assertEqual(key, "co_capsule_cleanup/list.csv")
+        with self.assertRaises(ValueError):
+            self.example_job._parse_s3_uri("s3://bucket-without-key")
+
+    def test_read_exclude_list_local_file(self):
+        """Tests read_exclude_list method with local file."""
+        with self.assertLogs(level="DEBUG") as captured:
+            exclude_items = self.example_job.read_exclude_list()
+        self.assertIsInstance(exclude_items, set)
+        self.assertIn("user2@example.com", exclude_items)
+        debug_logs = [log for log in captured.output if "Exclude items" in log]
+        self.assertEqual(len(debug_logs), 1)
+
+    @patch("boto3.client")
+    def test_read_exclude_list_s3_file(self, mock_boto3_client):
+        """Tests read_exclude_list method with S3 file."""
+        mock_s3_client = MagicMock()
+        mock_response = {"Body": MagicMock()}
+        mock_response["Body"].read.return_value = (
+            b"user2@example.com\nuser3@example.com"
+        )
+        mock_s3_client.get_object.return_value = mock_response
+        mock_boto3_client.return_value = mock_s3_client
+
+        s3_job_settings = JobSettings(
+            csv_file=CSV_FILE,
+            exclude_list_file="s3://test-bucket/exclude.txt",
+            webhook_url="https://webhook.site/test"
+        )
+        s3_job = WebhookNotificationJob(job_settings=s3_job_settings)
+
+        exclude_items = s3_job.read_exclude_list()
+        self.assertIn("user2@example.com", exclude_items)
+        self.assertIn("user3@example.com", exclude_items)
+        mock_s3_client.get_object.assert_called_once_with(
+            Bucket="test-bucket", Key="exclude.txt"
+        )
+        mock_s3_client.close.assert_called_once()
+
+    def test_read_csv_file_local(self):
+        """Tests read_csv_file method with local file."""
+        with self.assertLogs(level="DEBUG") as captured:
+            csv_data = self.example_job.read_csv_file()
+        self.assertIsInstance(csv_data, list)
+        self.assertEqual(len(csv_data), 4)
+        for row in csv_data:
+            self.assertIn("user_email", row)
+            self.assertIn("capsule_url", row)
+        debug_logs = [
+            log for log in captured.output
+            if "Read" in log and "rows" in log
+        ]
+        self.assertEqual(len(debug_logs), 1)
+
+    @patch("boto3.client")
+    def test_read_csv_file_s3(self, mock_boto3_client):
+        """Tests read_csv_file method with S3 file."""
+        csv_content = (
+            "user_email,capsule_url\n"
+            "user1@example.com,https://codeocean.com/capsule/12345\n"
+            "user2@example.com,https://codeocean.com/capsule/23456"
+        )
+        mock_s3_client = MagicMock()
+        mock_response = {"Body": MagicMock()}
+        mock_response["Body"].read.return_value = csv_content.encode("utf-8")
+        mock_s3_client.get_object.return_value = mock_response
+        mock_boto3_client.return_value = mock_s3_client
+
+        s3_job_settings = JobSettings(
+            csv_file="s3://test-bucket/data.csv",
+            exclude_list_file=EXCLUDE_FILE,
+            webhook_url="https://webhook.site/test"
+        )
+        s3_job = WebhookNotificationJob(job_settings=s3_job_settings)
+
+        csv_data = s3_job.read_csv_file()
+        self.assertEqual(len(csv_data), 2)
+        mock_s3_client.get_object.assert_called_once_with(
+            Bucket="test-bucket", Key="data.csv"
+        )
+        mock_s3_client.close.assert_called_once()
+
+    def test_filter_csv_data(self):
+        """Tests filter_csv_data method."""
+        csv_data = self.example_job.read_csv_file()
+        exclude_items = {"user2@example.com"}
+        with self.assertLogs(level="INFO") as captured:
+            filtered_data = self.example_job.filter_csv_data(
+                csv_data, exclude_items
+            )
+        self.assertEqual(len(filtered_data), 3)
+        filtered_users = [row["user_email"] for row in filtered_data]
+        self.assertNotIn("user2@example.com", filtered_users)
+        info_logs = [log for log in captured.output if "Excluding row" in log]
+        self.assertEqual(len(info_logs), 1)
+
+    def test_group_by_user(self):
+        """Tests group_by_user method."""
+        filtered_data = [
+            {"user_email": "user1@example.com", "capsule_url": "url1"},
+            {"user_email": "user1@example.com", "capsule_url": "url2"},
+            {"user_email": "user3@example.com", "capsule_url": "url3"},
+        ]
+        with self.assertLogs(level="DEBUG") as captured:
+            user_data = self.example_job.group_by_user(filtered_data)
+        self.assertIn("user1@example.com", user_data)
+        self.assertIn("user3@example.com", user_data)
+        self.assertEqual(len(user_data["user1@example.com"]), 2)
+        self.assertEqual(len(user_data["user3@example.com"]), 1)
+        debug_logs = [
+            log for log in captured.output if "Grouped data" in log
+        ]
+        self.assertEqual(len(debug_logs), 1)
+
+    def test_exclude_list_integration(self):
+        """Tests exclusion by both user email and capsule URL."""
+        exclude_items = self.example_job.read_exclude_list()
+        csv_data = self.example_job.read_csv_file()
+        filtered_data = self.example_job.filter_csv_data(
+            csv_data, exclude_items
+        )
+        user_data = self.example_job.group_by_user(filtered_data)
+
+        self.assertNotIn("user2@example.com", user_data)
+        self.assertIn("user1@example.com", user_data)
+        self.assertIn("user3@example.com", user_data)
+
+        self.assertEqual(len(user_data["user1@example.com"]), 1)
+        self.assertEqual(
+            user_data["user1@example.com"][0]["capsule_url"],
+            "https://codeocean.com/capsule/34567",
+        )
+
+        self.assertEqual(len(user_data["user3@example.com"]), 1)
+        self.assertEqual(
+            user_data["user3@example.com"][0]["capsule_url"],
+            "https://codeocean.com/capsule/45678",
+        )
+
+    @patch("requests.post")
+    def test_send_webhook_notifications_success(self, mock_post: MagicMock):
+        """Tests successful webhook notifications."""
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
+        mock_post.return_value = mock_response
+
+        test_data = {
+            "user1@example.com": [{"capsule_url": "https://example.com/1"}],
+            "user2@example.com": [{"capsule_url": "https://example.com/2"}]
+        }
+        with self.assertLogs(level="INFO") as captured:
+            self.example_job.send_webhook_notifications(test_data)
+        self.assertEqual(mock_post.call_count, 2)
+        self.assertTrue(mock_post.call_args.kwargs["verify"])
+        success_logs = [
+            log for log in captured.output if "Successfully" in log
+        ]
+        self.assertEqual(len(success_logs), 2)
+
+    @patch("requests.post")
+    def test_send_webhook_notifications_failure(self, mock_post: MagicMock):
+        """Tests webhook notification failures."""
+        mock_post.side_effect = requests.exceptions.RequestException("Error")
+
+        test_data = {
+            "user1@example.com": [{"capsule_url": "https://example.com/1"}]
+        }
+        with self.assertLogs(level="ERROR") as captured:
+            with self.assertRaises(requests.exceptions.RequestException):
+                self.example_job.send_webhook_notifications(test_data)
+
+        error_logs = [
+            log for log in captured.output if "Failed to send" in log
+        ]
+        self.assertEqual(len(error_logs), 1)
+
+    @patch("requests.post")
+    def test_run_job_integration(self, mock_post: MagicMock):
+        """Tests the complete run_job workflow."""
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
+        mock_post.return_value = mock_response
+
+        with self.assertLogs(level="INFO") as captured:
+            self.example_job.run_job()
+        start_log = any("Starting webhook" in log for log in captured.output)
+        end_log = any("completed" in log for log in captured.output)
+        self.assertTrue(start_log)
+        self.assertTrue(end_log)
+        self.assertEqual(mock_post.call_count, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()