diff --git a/config/config.yaml b/config/config.yaml index ca521ea..1b8c916 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -151,6 +151,14 @@ logging: # Number of backup log files to keep backup_count: 5 +# ============================================================================= +# PERSISTENCE CONFIGURATION +# ============================================================================= +persistence: + # SQLite database path for persisting alerts and beacons + # Set to null to disable persistence (in-memory only) + db_path: "./beacon_detect.db" + # ============================================================================= # WHITELIST CONFIGURATION # ============================================================================= diff --git a/control_plane/alerter.py b/control_plane/alerter.py index 23a5a3a..997d0bf 100644 --- a/control_plane/alerter.py +++ b/control_plane/alerter.py @@ -90,6 +90,19 @@ def convert_value(v): def to_json(self): return json.dumps(self.to_dict(), indent=2) + @classmethod + def from_dict(cls, data): + return cls( + alert_id=data["alert_id"], + title=data["title"], + description=data["description"], + severity=AlertSeverity(data["severity"]), + source=data["source"], + details=data.get("details", {}), + timestamp=data.get("timestamp", ""), + tags=data.get("tags", []), + ) + def to_syslog_message(self): return ( f"[{self.severity.value.upper()}] {self.title} | " @@ -293,9 +306,10 @@ def send(self, alert: Alert): class AlertManager: - def __init__(self, config=None): + def __init__(self, config=None, persistence=None): self.config = config or AlertingConfig() + self._persistence = persistence # Initialize handlers self._syslog = SyslogHandler(self.config) @@ -406,6 +420,13 @@ def _deliver_alert(self, alert: Alert): if len(self._recent_alerts) > self._max_recent_alerts: self._recent_alerts = self._recent_alerts[-self._max_recent_alerts :] + # Persist to database + if self._persistence: + try: + self._persistence.save_alert(alert.to_dict()) + except Exception as e: + logger.error(f"Failed to persist alert: {e}") + def send_alert(self, alert: Alert): if not self.config.enabled: @@ -456,6 +477,28 @@ def get_recent_alerts(self, limit: int = 50, severity=None): return [a.to_dict() for a in reversed(alerts)] + def load_historical_alerts(self): + """Load alerts from persistence on startup.""" + if not self._persistence: + return + + try: + alert_dicts = self._persistence.load_alerts(limit=self._max_recent_alerts) + for alert_dict in reversed(alert_dicts): + try: + alert = Alert.from_dict(alert_dict) + self._recent_alerts.append(alert) + except Exception as e: + logger.warning( + f"Failed to restore alert {alert_dict.get('alert_id')}: {e}" + ) + + logger.info( + f"Loaded {len(self._recent_alerts)} historical alerts from database" + ) + except Exception as e: + logger.error(f"Failed to load historical alerts: {e}") + @property def statistics(self): diff --git a/control_plane/analyzer.py b/control_plane/analyzer.py index 2adc5c8..79d719e 100644 --- a/control_plane/analyzer.py +++ b/control_plane/analyzer.py @@ -39,6 +39,7 @@ def __init__(self, run_id: str): self.pairs_analyzed = 0 self.beacons_detected = 0 self.alerts_generated = 0 + self.pairs_skipped = 0 self.errors = 0 self.results: List[DetectionResult] = [] @@ -61,6 +62,7 @@ def to_dict(self): "pairs_analyzed": int(self.pairs_analyzed), "beacons_detected": int(self.beacons_detected), "alerts_generated": int(self.alerts_generated), + "pairs_skipped": int(self.pairs_skipped), "errors": int(self.errors), } @@ -73,6 +75,8 @@ def __init__( detector: BeaconDetector, alert_manager: AlertManager, config=None, + whitelist: dict = None, + persistence=None, ): """ Initialize the analyzer. @@ -82,11 +86,15 @@ def __init__( detector: BeaconDetector instance alert_manager: AlertManager instance config: Analyzer configuration + whitelist: Whitelist configuration for filtering pairs + persistence: SQLiteStore instance for persisting beacons """ self.storage = storage self.detector = detector self.alert_manager = alert_manager self.config = config or AnalyzerConfig() + self._whitelist = whitelist or {} + self._persistence = persistence # Thread safety lock for shared state self._lock = threading.RLock() @@ -118,6 +126,29 @@ def __init__( f"ConnectionAnalyzer initialized: interval={self.config.analysis_interval}s" ) + def _is_whitelisted(self, pair) -> bool: + """Check if a connection pair matches any whitelist rule.""" + if pair.src_ip in self._whitelist.get("source_ips", []): + return True + + if pair.dst_ip in self._whitelist.get("destination_ips", []): + return True + + if pair.dst_port in self._whitelist.get("ports", []): + return True + + pair_str = f"{pair.src_ip}:{pair.dst_ip}:{pair.dst_port}" + if pair_str in self._whitelist.get("pairs", []): + return True + + return False + + def update_whitelist(self, whitelist: dict): + """Update whitelist configuration (thread-safe).""" + with self._lock: + self._whitelist = whitelist or {} + logger.info("Whitelist updated") + def start(self): if self._running: @@ -165,13 +196,23 @@ def run_analysis(self): logger.info(f"Starting analysis run: {run_id}") try: - # logger.info(f"self.config.min_connections {self.config.min_connections} and self.config.min_duration {self.config.min_duration}") # Get analyzable pairs from storage pairs = self.storage.get_analyzable_pairs( min_connections=self.config.min_connections, min_duration=self.config.min_duration, ) - # logger.info(f"pairs{pairs}") + + # Filter out whitelisted pairs + if self._whitelist: + original_count = len(pairs) + pairs = [p for p in pairs if not self._is_whitelisted(p)] + run.pairs_skipped = original_count - len(pairs) + if run.pairs_skipped > 0: + logger.info( + f"Whitelist filtered {run.pairs_skipped} pairs " + f"({original_count} -> {len(pairs)})" + ) + # Limit pairs for performance if len(pairs) > self.config.max_pairs_per_run: logger.warning( @@ -208,6 +249,17 @@ def run_analysis(self): with self._lock: self._known_beacons[result.pair_key] = result + # Persist beacon to database + if self._persistence: + try: + self._persistence.save_beacon( + result.pair_key, result.to_dict() + ) + except Exception as e: + logger.error( + f"Failed to persist beacon {result.pair_key}: {e}" + ) + except Exception as e: logger.error(f"Error generating alert for {result.pair_key}: {e}") run.errors += 1 @@ -220,6 +272,13 @@ def run_analysis(self): ] for key in stale_keys: del self._known_beacons[key] + if self._persistence: + try: + self._persistence.remove_beacon(key) + except Exception as e: + logger.error( + f"Failed to remove beacon {key} from database: {e}" + ) except Exception as e: logger.error(f"Analysis run error: {e}", exc_info=True) @@ -320,6 +379,26 @@ def get_run_history(self, limit: int = 10): runs = self._run_history[-limit:] return [r.to_dict() for r in reversed(runs)] + def load_historical_beacons(self): + """Load known beacons from persistence on startup.""" + if not self._persistence: + return + + try: + beacon_dicts = self._persistence.load_beacons() + for pair_key, detection_dict in beacon_dicts.items(): + try: + result = DetectionResult.from_dict(detection_dict) + self._known_beacons[pair_key] = result + except Exception as e: + logger.warning(f"Failed to restore beacon {pair_key}: {e}") + + logger.info( + f"Loaded {len(self._known_beacons)} historical beacons from database" + ) + except Exception as e: + logger.error(f"Failed to load historical beacons: {e}") + @property def statistics(self): """Get analyzer statistics""" @@ -331,4 +410,10 @@ def statistics(self): "total_alerts_generated": self._total_alerts_generated, "current_known_beacons": len(self._known_beacons), "active_cooldowns": len(self._alert_cooldowns), + "whitelist_rules": { + "source_ips": len(self._whitelist.get("source_ips", [])), + "destination_ips": len(self._whitelist.get("destination_ips", [])), + "ports": len(self._whitelist.get("ports", [])), + "pairs": len(self._whitelist.get("pairs", [])), + }, } diff --git a/control_plane/detector.py b/control_plane/detector.py index 92139c5..88e8056 100644 --- a/control_plane/detector.py +++ b/control_plane/detector.py @@ -46,6 +46,19 @@ def to_dict(self): "jitter": float(round(self.jitter, 3)), } + @classmethod + def from_dict(cls, data): + return cls( + count=data["count"], + mean=data["mean"], + std_dev=data["std_dev"], + cv=data["cv"], + median=data["median"], + min_interval=data["min_interval"], + max_interval=data["max_interval"], + jitter=data["jitter"], + ) + @dataclass class PeriodicityResult: @@ -68,6 +81,15 @@ def to_dict(self): ], } + @classmethod + def from_dict(cls, data): + return cls( + is_periodic=data["is_periodic"], + dominant_period=data["dominant_period"] or 0.0, + periodicity_score=data["periodicity_score"], + frequency_peaks=[(f, m) for f, m in data["frequency_peaks"]], + ) + @dataclass class DetectionResult: @@ -125,6 +147,29 @@ def to_dict(self): "analysis_time": str(self.analysis_time), } + @classmethod + def from_dict(cls, data): + return cls( + pair_key=data["pair_key"], + src_ip=data["src_ip"], + dst_ip=data["dst_ip"], + dst_port=data["dst_port"], + protocol=data["protocol"], + cv_score=data["cv_score"], + periodicity_score=data["periodicity_score"], + jitter_score=data["jitter_score"], + combined_score=data["combined_score"], + is_beacon=data["is_beacon"], + confidence=BeaconConfidence(data["confidence"]), + interval_stats=IntervalStats.from_dict(data["interval_stats"]), + periodicity_result=PeriodicityResult.from_dict(data["periodicity_result"]), + connection_count=data["connection_count"], + duration_seconds=data["duration_seconds"], + first_seen=data["first_seen"], + last_seen=data["last_seen"], + analysis_time=data.get("analysis_time", ""), + ) + @dataclass class DetectorConfig: diff --git a/control_plane/persistence.py b/control_plane/persistence.py new file mode 100644 index 0000000..b27247f --- /dev/null +++ b/control_plane/persistence.py @@ -0,0 +1,194 @@ +""" +SQLite persistence layer for beacon detection system. + +Persists alerts and detected beacons across restarts. +Raw connection data is NOT persisted (ephemeral, high-volume). +""" + +import json +import logging +import sqlite3 +import threading +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List, Optional + +logger = logging.getLogger("beacon_detect.control_plane.persistence") + + +@dataclass +class PersistenceConfig: + db_path: str = "./beacon_detect.db" + journal_mode: str = "WAL" + busy_timeout_ms: int = 5000 + + +class SQLiteStore: + """Thread-safe SQLite persistence for alerts and beacons.""" + + def __init__(self, config: PersistenceConfig = None): + self.config = config or PersistenceConfig() + self._lock = threading.RLock() + self._conn: Optional[sqlite3.Connection] = None + + def open(self): + """Open database connection and create tables.""" + db_path = Path(self.config.db_path) + db_path.parent.mkdir(parents=True, exist_ok=True) + + self._conn = sqlite3.connect( + str(db_path), + check_same_thread=False, + ) + self._conn.execute(f"PRAGMA journal_mode={self.config.journal_mode}") + self._conn.execute(f"PRAGMA busy_timeout={self.config.busy_timeout_ms}") + self._conn.row_factory = sqlite3.Row + + self._create_tables() + logger.info(f"SQLiteStore opened: {db_path}") + + def close(self): + """Close database connection.""" + if self._conn: + self._conn.close() + self._conn = None + logger.info("SQLiteStore closed") + + def _create_tables(self): + with self._lock: + self._conn.executescript( + """ + CREATE TABLE IF NOT EXISTS alerts ( + alert_id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + severity TEXT NOT NULL, + source TEXT NOT NULL, + details_json TEXT DEFAULT '{}', + timestamp TEXT NOT NULL, + tags_json TEXT DEFAULT '[]', + created_at TEXT DEFAULT (datetime('now')) + ); + + CREATE INDEX IF NOT EXISTS idx_alerts_timestamp + ON alerts(timestamp DESC); + CREATE INDEX IF NOT EXISTS idx_alerts_severity + ON alerts(severity); + + CREATE TABLE IF NOT EXISTS beacons ( + pair_key TEXT PRIMARY KEY, + detection_json TEXT NOT NULL, + first_detected TEXT NOT NULL, + last_updated TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_beacons_last_updated + ON beacons(last_updated DESC); + """ + ) + self._conn.commit() + + # -- Alert persistence -- + + def save_alert(self, alert_dict: dict): + """Persist a single alert.""" + with self._lock: + self._conn.execute( + """INSERT OR REPLACE INTO alerts + (alert_id, title, description, severity, source, + details_json, timestamp, tags_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + alert_dict["alert_id"], + alert_dict["title"], + alert_dict["description"], + alert_dict["severity"], + alert_dict["source"], + json.dumps(alert_dict.get("details", {})), + alert_dict["timestamp"], + json.dumps(alert_dict.get("tags", [])), + ), + ) + self._conn.commit() + + def load_alerts(self, limit: int = 1000) -> List[dict]: + """Load recent alerts from database.""" + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM alerts ORDER BY timestamp DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + + alerts = [] + for row in rows: + alerts.append( + { + "alert_id": row["alert_id"], + "title": row["title"], + "description": row["description"], + "severity": row["severity"], + "source": row["source"], + "details": json.loads(row["details_json"]), + "timestamp": row["timestamp"], + "tags": json.loads(row["tags_json"]), + } + ) + return alerts + + def get_alert_count(self) -> int: + with self._lock: + cursor = self._conn.execute("SELECT COUNT(*) FROM alerts") + return cursor.fetchone()[0] + + # -- Beacon persistence -- + + def save_beacon(self, pair_key: str, detection_dict: dict): + """Persist a detected beacon.""" + now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + with self._lock: + # Preserve first_detected if beacon already exists + cursor = self._conn.execute( + "SELECT first_detected FROM beacons WHERE pair_key = ?", + (pair_key,), + ) + row = cursor.fetchone() + first_detected = row["first_detected"] if row else now + + self._conn.execute( + """INSERT OR REPLACE INTO beacons + (pair_key, detection_json, first_detected, last_updated) + VALUES (?, ?, ?, ?)""", + ( + pair_key, + json.dumps(detection_dict), + first_detected, + now, + ), + ) + self._conn.commit() + + def remove_beacon(self, pair_key: str): + """Remove a beacon that is no longer detected.""" + with self._lock: + self._conn.execute( + "DELETE FROM beacons WHERE pair_key = ?", (pair_key,) + ) + self._conn.commit() + + def load_beacons(self) -> Dict[str, dict]: + """Load all known beacons. Returns {pair_key: detection_dict}.""" + with self._lock: + cursor = self._conn.execute("SELECT * FROM beacons") + rows = cursor.fetchall() + + beacons = {} + for row in rows: + beacons[row["pair_key"]] = json.loads(row["detection_json"]) + return beacons + + def get_beacon_count(self) -> int: + with self._lock: + cursor = self._conn.execute("SELECT COUNT(*) FROM beacons") + return cursor.fetchone()[0] diff --git a/control_plane/server.py b/control_plane/server.py index 5344653..233dae5 100644 --- a/control_plane/server.py +++ b/control_plane/server.py @@ -15,6 +15,7 @@ from .alerter import AlertingConfig, AlertManager, AlertSeverity from .analyzer import AnalyzerConfig, ConnectionAnalyzer from .detector import BeaconDetector, DetectorConfig +from .persistence import PersistenceConfig, SQLiteStore from .storage import ConnectionStorage # Configure logging @@ -41,6 +42,7 @@ def __init__(self, config): # Initialize components self._init_storage(config) self._init_detector(config) + self._init_persistence(config) self._init_alerter(config) self._init_analyzer(config) @@ -101,6 +103,14 @@ def _init_storage(self, config): cleanup_interval=cp_config.get("cleanup_interval", 300), ) + def _init_persistence(self, config): + + persistence_config = config.get("persistence", {}) + db_path = persistence_config.get("db_path", "./beacon_detect.db") + + self.persistence = SQLiteStore(PersistenceConfig(db_path=db_path)) + self.persistence.open() + def _init_detector(self, config): det_config = config.get("detection", {}) @@ -140,11 +150,12 @@ def _init_alerter(self, config): webhook_timeout=alert_config.get("webhook", {}).get("timeout", 10), webhook_retries=alert_config.get("webhook", {}).get("retries", 3), ) - self.alert_manager = AlertManager(alerting_config) + self.alert_manager = AlertManager(alerting_config, persistence=self.persistence) def _init_analyzer(self, config): det_config = config.get("detection", {}) + whitelist = config.get("whitelist", {}) analyzer_config = AnalyzerConfig( analysis_interval=60, # Run every minute min_connections=det_config.get("min_connections", 10), @@ -156,6 +167,8 @@ def _init_analyzer(self, config): detector=self.detector, alert_manager=self.alert_manager, config=analyzer_config, + whitelist=whitelist, + persistence=self.persistence, ) def _setup_routes(self, app: web.Application): @@ -450,6 +463,9 @@ async def _handle_set_config(self, request: web.Request) -> web.Response: "whitelist" ]["destination_ports"] + # Propagate whitelist to analyzer + self.analyzer.update_whitelist(self.config.get("whitelist", {})) + logger.info("Configuration updated via API") return web.json_response( {"status": "updated", "config": self._runtime_config} @@ -484,6 +500,11 @@ async def start(self): # Start components self.storage.start_cleanup() self.alert_manager.start() + + # Load historical data from database + self.alert_manager.load_historical_alerts() + self.analyzer.load_historical_beacons() + self.analyzer.start() # Create and configure app with CORS support @@ -567,6 +588,12 @@ async def stop(self): except Exception as e: logger.warning(f"Error cleaning up runner: {e}") + # Close persistence + try: + self.persistence.close() + except Exception as e: + logger.warning(f"Error closing persistence: {e}") + logger.info("Control plane server stopped") def request_shutdown(self): @@ -613,11 +640,20 @@ def setup_logging(config): maxBytes=log_config.get("max_size_mb", 50) * 1024 * 1024, backupCount=log_config.get("backup_count", 5), ) - handler.setFormatter( - logging.Formatter( + + log_format = log_config.get("format", "text") + if log_format == "json": + from pythonjsonlogger import jsonlogger + + formatter = jsonlogger.JsonFormatter( + "%(asctime)s %(name)s %(levelname)s %(message)s %(pathname)s %(lineno)d" + ) + else: + formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) - ) + + handler.setFormatter(formatter) logging.getLogger().addHandler(handler) except Exception as e: logger.warning(f"Could not set up file logging: {e}") diff --git a/data_plane/collector.py b/data_plane/collector.py index b104351..de3a591 100644 --- a/data_plane/collector.py +++ b/data_plane/collector.py @@ -8,6 +8,7 @@ import argparse import ctypes import logging +import logging.handlers import os import signal import sys @@ -348,6 +349,46 @@ def load_config(config_path: str): return yaml.safe_load(f) +def setup_logging(config): + + log_config = config.get("logging", {}) + + level_str = log_config.get("level", "INFO") + level = getattr(logging, level_str.upper(), logging.INFO) + + logging.getLogger().setLevel(level) + logging.getLogger("beacon_detect").setLevel(level) + + log_file = log_config.get("file") + if log_file: + try: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + handler = logging.handlers.RotatingFileHandler( + log_path, + maxBytes=log_config.get("max_size_mb", 50) * 1024 * 1024, + backupCount=log_config.get("backup_count", 5), + ) + + log_format = log_config.get("format", "text") + if log_format == "json": + from pythonjsonlogger import jsonlogger + + formatter = jsonlogger.JsonFormatter( + "%(asctime)s %(name)s %(levelname)s %(message)s %(pathname)s %(lineno)d" + ) + else: + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + handler.setFormatter(formatter) + logging.getLogger().addHandler(handler) + except Exception as e: + logger.warning(f"Could not set up file logging: {e}") + + def setup_signal_handlers(collector): def signal_handler(signum, frame): @@ -392,6 +433,9 @@ def main(): logger.error(f"Failed to load configuration: {e}") sys.exit(1) + # Set up logging from config + setup_logging(config) + # Determine interface interface = args.interface or config.get("data_plane", {}).get("interface") if not interface: diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index c65a1ab..058f260 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -400,5 +400,172 @@ def test_syslog_priority_mapping(self): assert AlertSeverity.CRITICAL.syslog_priority == logging.CRITICAL +class TestWhitelist: + + def setup_method(self): + + self.storage = ConnectionStorage() + self.detector = BeaconDetector(DetectorConfig(min_connections=5)) + self.alert_manager = Mock(spec=AlertManager) + self.alert_manager.send_alert = Mock() + + def test_whitelist_source_ip(self): + + whitelist = {"source_ips": ["192.168.1.100"]} + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist=whitelist, + ) + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + ) + assert analyzer._is_whitelisted(pair) is True + + def test_whitelist_destination_ip(self): + + whitelist = {"destination_ips": ["10.0.0.1"]} + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist=whitelist, + ) + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + ) + assert analyzer._is_whitelisted(pair) is True + + def test_whitelist_port(self): + + whitelist = {"ports": [53, 123]} + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist=whitelist, + ) + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=53, protocol="UDP" + ) + assert analyzer._is_whitelisted(pair) is True + + def test_whitelist_pair(self): + + whitelist = {"pairs": ["192.168.1.100:10.0.0.1:443"]} + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist=whitelist, + ) + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + ) + assert analyzer._is_whitelisted(pair) is True + + def test_no_whitelist_match(self): + + whitelist = {"source_ips": ["10.10.10.10"]} + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist=whitelist, + ) + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + ) + assert analyzer._is_whitelisted(pair) is False + + def test_empty_whitelist(self): + + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist={}, + ) + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + ) + assert analyzer._is_whitelisted(pair) is False + + def test_analysis_run_skips_whitelisted(self): + + whitelist = {"destination_ips": ["10.0.0.1"]} + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + config=AnalyzerConfig(min_connections=5, min_duration=60), + whitelist=whitelist, + ) + + base_time = time.time() + # Add whitelisted pair + for i in range(20): + record = ConnectionRecord( + timestamp_ns=int((base_time + i * 60) * 1e9), + timestamp_utc=f"2024-01-01T00:{i:02d}:00Z", + src_ip="192.168.1.100", + dst_ip="10.0.0.1", + src_port=54321, + dst_port=443, + packet_size=1500, + protocol=6, + protocol_name="TCP", + tcp_flags=0x10, + direction=1, + node_id="test-node", + connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", + ) + record.timestamp_epoch = base_time + i * 60 + self.storage.add_record(record) + + # Add non-whitelisted pair + for i in range(20): + record = ConnectionRecord( + timestamp_ns=int((base_time + i * 60) * 1e9), + timestamp_utc=f"2024-01-01T00:{i:02d}:00Z", + src_ip="192.168.1.200", + dst_ip="10.0.0.2", + src_port=54322, + dst_port=80, + packet_size=1500, + protocol=6, + protocol_name="TCP", + tcp_flags=0x10, + direction=1, + node_id="test-node", + connection_key="192.168.1.200:54322->10.0.0.2:80/TCP", + ) + record.timestamp_epoch = base_time + i * 60 + self.storage.add_record(record) + + run = analyzer.run_analysis() + + assert run.pairs_skipped == 1 + assert run.pairs_analyzed == 1 + + def test_update_whitelist(self): + + analyzer = ConnectionAnalyzer( + storage=self.storage, + detector=self.detector, + alert_manager=self.alert_manager, + whitelist={}, + ) + + pair = ConnectionPair( + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + ) + assert analyzer._is_whitelisted(pair) is False + + analyzer.update_whitelist({"source_ips": ["192.168.1.100"]}) + assert analyzer._is_whitelisted(pair) is True + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 0000000..32c4188 --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,289 @@ +import sys +import time +from pathlib import Path +from unittest.mock import Mock + +import pytest + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from control_plane.alerter import Alert, AlertManager, AlertSeverity +from control_plane.detector import ( + BeaconConfidence, + DetectionResult, + IntervalStats, + PeriodicityResult, +) +from control_plane.persistence import PersistenceConfig, SQLiteStore + + +@pytest.fixture +def db_store(tmp_path): + """Create a temporary SQLiteStore for testing.""" + db_path = str(tmp_path / "test.db") + store = SQLiteStore(PersistenceConfig(db_path=db_path)) + store.open() + yield store + store.close() + + +def _make_alert_dict(alert_id="test-alert-1", severity="high"): + return { + "alert_id": alert_id, + "title": "Beacon Detected", + "description": "Beaconing from 192.168.1.100 to 10.0.0.1:443", + "severity": severity, + "source": "beacon_detector", + "details": {"score": 0.85, "pair_key": "192.168.1.100->10.0.0.1:443/TCP"}, + "timestamp": "2024-01-01T00:00:00Z", + "tags": ["beacon", "tcp"], + } + + +def _make_detection_result(): + return DetectionResult( + pair_key="192.168.1.100->10.0.0.1:443/TCP", + src_ip="192.168.1.100", + dst_ip="10.0.0.1", + dst_port=443, + protocol="TCP", + cv_score=0.9, + periodicity_score=0.85, + jitter_score=0.8, + combined_score=0.87, + is_beacon=True, + confidence=BeaconConfidence.HIGH, + interval_stats=IntervalStats( + count=20, + mean=60.0, + std_dev=2.0, + cv=0.033, + median=60.0, + min_interval=58.0, + max_interval=62.0, + jitter=2.0, + ), + periodicity_result=PeriodicityResult( + is_periodic=True, + dominant_period=60.0, + periodicity_score=0.85, + frequency_peaks=[(0.0167, 0.95)], + ), + connection_count=20, + duration_seconds=1140.0, + first_seen="2024-01-01T00:00:00Z", + last_seen="2024-01-01T00:19:00Z", + ) + + +class TestAlertPersistence: + + def test_save_and_load_alert(self, db_store): + + alert_dict = _make_alert_dict() + db_store.save_alert(alert_dict) + + alerts = db_store.load_alerts() + assert len(alerts) == 1 + assert alerts[0]["alert_id"] == "test-alert-1" + assert alerts[0]["severity"] == "high" + assert alerts[0]["details"]["score"] == 0.85 + assert alerts[0]["tags"] == ["beacon", "tcp"] + + def test_alert_deduplication(self, db_store): + + alert_dict = _make_alert_dict() + db_store.save_alert(alert_dict) + db_store.save_alert(alert_dict) + + assert db_store.get_alert_count() == 1 + + def test_multiple_alerts(self, db_store): + + for i in range(5): + db_store.save_alert(_make_alert_dict(alert_id=f"alert-{i}")) + + alerts = db_store.load_alerts() + assert len(alerts) == 5 + + def test_load_alerts_limit(self, db_store): + + for i in range(10): + db_store.save_alert( + _make_alert_dict( + alert_id=f"alert-{i}", + ) + ) + + alerts = db_store.load_alerts(limit=3) + assert len(alerts) == 3 + + def test_load_empty_alerts(self, db_store): + + alerts = db_store.load_alerts() + assert len(alerts) == 0 + assert db_store.get_alert_count() == 0 + + +class TestBeaconPersistence: + + def test_save_and_load_beacon(self, db_store): + + result = _make_detection_result() + db_store.save_beacon(result.pair_key, result.to_dict()) + + beacons = db_store.load_beacons() + assert len(beacons) == 1 + assert result.pair_key in beacons + + loaded = beacons[result.pair_key] + assert loaded["combined_score"] == 0.87 + assert loaded["is_beacon"] is True + + def test_beacon_first_detected_preserved(self, db_store): + + result = _make_detection_result() + db_store.save_beacon(result.pair_key, result.to_dict()) + + # Load to check first_detected + cursor = db_store._conn.execute( + "SELECT first_detected FROM beacons WHERE pair_key = ?", + (result.pair_key,), + ) + first_detected_original = cursor.fetchone()["first_detected"] + + # Save again (update) + time.sleep(0.01) + db_store.save_beacon(result.pair_key, result.to_dict()) + + cursor = db_store._conn.execute( + "SELECT first_detected, last_updated FROM beacons WHERE pair_key = ?", + (result.pair_key,), + ) + row = cursor.fetchone() + assert row["first_detected"] == first_detected_original + + def test_remove_beacon(self, db_store): + + result = _make_detection_result() + db_store.save_beacon(result.pair_key, result.to_dict()) + assert db_store.get_beacon_count() == 1 + + db_store.remove_beacon(result.pair_key) + assert db_store.get_beacon_count() == 0 + + def test_load_empty_beacons(self, db_store): + + beacons = db_store.load_beacons() + assert len(beacons) == 0 + + +class TestDetectionResultRoundtrip: + + def test_to_dict_from_dict(self): + + original = _make_detection_result() + d = original.to_dict() + restored = DetectionResult.from_dict(d) + + assert restored.pair_key == original.pair_key + assert restored.src_ip == original.src_ip + assert restored.dst_ip == original.dst_ip + assert restored.dst_port == original.dst_port + assert restored.combined_score == pytest.approx(original.combined_score, abs=0.001) + assert restored.is_beacon == original.is_beacon + assert restored.confidence == original.confidence + assert restored.interval_stats.mean == pytest.approx(original.interval_stats.mean) + assert restored.periodicity_result.is_periodic == original.periodicity_result.is_periodic + + def test_roundtrip_via_persistence(self, db_store): + + original = _make_detection_result() + db_store.save_beacon(original.pair_key, original.to_dict()) + + beacons = db_store.load_beacons() + restored = DetectionResult.from_dict(beacons[original.pair_key]) + + assert restored.pair_key == original.pair_key + assert restored.combined_score == pytest.approx(original.combined_score, abs=0.001) + assert restored.confidence == original.confidence + + +class TestAlertRoundtrip: + + def test_to_dict_from_dict(self): + + original = Alert( + alert_id="test-1", + title="Test Alert", + description="Testing roundtrip", + severity=AlertSeverity.HIGH, + source="test", + details={"key": "value"}, + tags=["tag1", "tag2"], + ) + + d = original.to_dict() + restored = Alert.from_dict(d) + + assert restored.alert_id == original.alert_id + assert restored.title == original.title + assert restored.severity == original.severity + assert restored.details == original.details + assert restored.tags == original.tags + + def test_roundtrip_via_persistence(self, db_store): + + original = Alert( + alert_id="test-persist", + title="Persisted Alert", + description="Testing persistence roundtrip", + severity=AlertSeverity.CRITICAL, + source="beacon_detector", + details={"score": 0.95}, + tags=["critical"], + ) + + db_store.save_alert(original.to_dict()) + loaded = db_store.load_alerts() + restored = Alert.from_dict(loaded[0]) + + assert restored.alert_id == original.alert_id + assert restored.severity == original.severity + assert restored.details == {"score": 0.95} + + +class TestAlertManagerWithPersistence: + + def test_deliver_persists_alert(self, db_store): + + manager = AlertManager(persistence=db_store) + alert = Alert( + alert_id="persist-test", + title="Test", + description="Test alert", + severity=AlertSeverity.MEDIUM, + source="test", + ) + + manager._deliver_alert(alert) + + assert db_store.get_alert_count() == 1 + loaded = db_store.load_alerts() + assert loaded[0]["alert_id"] == "persist-test" + + def test_load_historical_alerts(self, db_store): + + # Pre-populate DB + for i in range(3): + db_store.save_alert(_make_alert_dict(alert_id=f"historical-{i}")) + + manager = AlertManager(persistence=db_store) + manager.load_historical_alerts() + + assert len(manager._recent_alerts) == 3 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])