From a96bd4fc735026709f8cfabe822ed956ad40c387 Mon Sep 17 00:00:00 2001 From: litemars Date: Wed, 28 Jan 2026 19:42:47 +0100 Subject: [PATCH] add pipeline and code formatting --- .github/workflows/lint-and-test.yml | 60 +++ .gitignore | 37 +- CONTRIBUTING.md | 68 +++ control_plane/alerter.py | 270 +++++----- control_plane/analyzer.py | 192 ++++---- control_plane/cli.py | 731 +++++++++++++++++----------- control_plane/detector.py | 257 +++++----- control_plane/server.py | 721 ++++++++++++++------------- control_plane/storage.py | 193 ++++---- data_plane/collector.py | 252 +++++----- data_plane/exporter.py | 176 +++---- data_plane/telemetry.py | 217 +++++---- tests/test_analyzer.py | 247 +++++----- tests/test_detector.py | 232 ++++----- 14 files changed, 1978 insertions(+), 1675 deletions(-) create mode 100644 .github/workflows/lint-and-test.yml create mode 100644 CONTRIBUTING.md diff --git a/.github/workflows/lint-and-test.yml b/.github/workflows/lint-and-test.yml new file mode 100644 index 0000000..c1646ef --- /dev/null +++ b/.github/workflows/lint-and-test.yml @@ -0,0 +1,60 @@ +name: Code Quality & Linting + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + lint: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pylint black isort + + - name: Check code formatting with black + run: | + black --check control_plane data_plane tests + + - name: Check import sorting with isort + run: | + isort --check-only control_plane data_plane tests + + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest pytest-cov pytest-asyncio + + - name: Run tests with pytest + run: | + pytest tests diff --git a/.gitignore b/.gitignore index 26ab1a5..7d2933e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,37 @@ +# Virtual environments venv +vevn +env +ENV +.venv + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so + + + +# Testing pytest.ini .pytest* -tests/__pycache__ -data_plane/__pycache__ -control_plane/__pycache__ \ No newline at end of file +.coverage +.coverage.* +htmlcov/ + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Environment +.env +.env.local +*.log + +# OS +.DS_Store \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..608b60c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,68 @@ +# Contributing to eBPF Beacon Detection System + +Thank you for your interest in contributing! This document provides guidelines and instructions for contributing to this project. + +## Getting Started + +1. **Fork the repository** on GitHub +2. **Clone your fork** locally: + ```bash + git clone https://github.com/yourusername/BeaconDetectionSystemGit.git + cd BeaconDetectionSystemGit + ``` + +3. **Set up the development environment**: + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + pip install -r requirements.txt + pip install -r requirements-dev.txt + ``` + +## Development Workflow + +1. **Create a feature branch**: + ```bash + git checkout -b feature/your-feature-name + ``` + +2. **Make your changes** and commit them with clear messages: + ```bash + git commit -m "Add: brief description of changes" + ``` + +3. **Run tests and linting**: + ```bash + pytest + pylint control_plane data_plane + black --check control_plane data_plane + ``` + +4. **Push to your fork**: + ```bash + git push origin feature/your-feature-name + ``` + +5. **Create a Pull Request** with a clear description of your changes + +## Testing + +- Write tests for new features in the `tests/` directory +- Ensure all tests pass: `pytest` +- Aim for high test coverage, especially for critical components +- Use descriptive test names that explain what is being tested + +## Reporting Issues + +- Use GitHub Issues for bug reports and feature requests +- Provide clear steps to reproduce bugs +- Include relevant system information and logs + +## Documentation + +- Update relevant documentation for new features +- Keep the README.md up to date + +## Questions? + +Feel free to open an issue or discussion for questions. diff --git a/control_plane/alerter.py b/control_plane/alerter.py index ff4c396..23a5a3a 100644 --- a/control_plane/alerter.py +++ b/control_plane/alerter.py @@ -14,7 +14,7 @@ import requests from tenacity import retry, stop_after_attempt, wait_exponential -logger = logging.getLogger('beacon_detect.control_plane.alerter') +logger = logging.getLogger("beacon_detect.control_plane.alerter") class AlertSeverity(Enum): @@ -24,16 +24,16 @@ class AlertSeverity(Enum): MEDIUM = "medium" HIGH = "high" CRITICAL = "critical" - + @property def syslog_priority(self): mapping = { - 'info': logging.INFO, - 'low': logging.WARNING, - 'medium': logging.WARNING, - 'high': logging.ERROR, - 'critical': logging.CRITICAL + "info": logging.INFO, + "low": logging.WARNING, + "medium": logging.WARNING, + "high": logging.ERROR, + "critical": logging.CRITICAL, } return mapping.get(self.value, logging.WARNING) @@ -47,16 +47,20 @@ class Alert: severity: AlertSeverity source: str details: Dict = field(default_factory=dict) - timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')) + timestamp: str = field( + default_factory=lambda: datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z") + ) tags: List = field(default_factory=list) - + def to_dict(self): def convert_value(v): if v is None: return None - if hasattr(v, 'item'): # numpy scalar + if hasattr(v, "item"): # numpy scalar return v.item() if isinstance(v, dict): return {k: convert_value(val) for k, val in v.items()} @@ -68,24 +72,24 @@ def convert_value(v): return v # Try to convert unknown types try: - return float(v) if '.' in str(v) else int(v) + return float(v) if "." in str(v) else int(v) except (ValueError, TypeError): return str(v) - + return { - 'alert_id': str(self.alert_id), - 'title': str(self.title), - 'description': str(self.description), - 'severity': str(self.severity.value), - 'source': str(self.source), - 'details': convert_value(self.details), - 'timestamp': str(self.timestamp), - 'tags': [str(t) for t in self.tags] + "alert_id": str(self.alert_id), + "title": str(self.title), + "description": str(self.description), + "severity": str(self.severity.value), + "source": str(self.source), + "details": convert_value(self.details), + "timestamp": str(self.timestamp), + "tags": [str(t) for t in self.tags], } - + def to_json(self): return json.dumps(self.to_dict(), indent=2) - + def to_syslog_message(self): return ( f"[{self.severity.value.upper()}] {self.title} | " @@ -99,18 +103,18 @@ def to_syslog_message(self): class AlertingConfig: enabled: bool = True - + # Syslog settings syslog_enabled: bool = True syslog_facility: str = "local0" syslog_address: str = "/dev/log" - + # File settings file_enabled: bool = True file_path: str = "/var/log/beacon-detect/alerts.json" file_max_size_mb: int = 100 file_backup_count: int = 5 - + # Webhook settings webhook_enabled: bool = False webhook_url: str = "" @@ -120,67 +124,64 @@ class AlertingConfig: class SyslogHandler: - + def __init__(self, config): self.config = config self._logger = None - + if config.syslog_enabled: self._setup_syslog() - + def _setup_syslog(self): - self._logger = logging.getLogger('beacon_detect.alerts.syslog') + self._logger = logging.getLogger("beacon_detect.alerts.syslog") self._logger.setLevel(logging.DEBUG) - + # Remove existing handlers self._logger.handlers = [] - + # Get facility facility_map = { - 'local0': logging.handlers.SysLogHandler.LOG_LOCAL0, - 'local1': logging.handlers.SysLogHandler.LOG_LOCAL1, - 'local2': logging.handlers.SysLogHandler.LOG_LOCAL2, - 'local3': logging.handlers.SysLogHandler.LOG_LOCAL3, - 'local4': logging.handlers.SysLogHandler.LOG_LOCAL4, - 'local5': logging.handlers.SysLogHandler.LOG_LOCAL5, - 'local6': logging.handlers.SysLogHandler.LOG_LOCAL6, - 'local7': logging.handlers.SysLogHandler.LOG_LOCAL7, + "local0": logging.handlers.SysLogHandler.LOG_LOCAL0, + "local1": logging.handlers.SysLogHandler.LOG_LOCAL1, + "local2": logging.handlers.SysLogHandler.LOG_LOCAL2, + "local3": logging.handlers.SysLogHandler.LOG_LOCAL3, + "local4": logging.handlers.SysLogHandler.LOG_LOCAL4, + "local5": logging.handlers.SysLogHandler.LOG_LOCAL5, + "local6": logging.handlers.SysLogHandler.LOG_LOCAL6, + "local7": logging.handlers.SysLogHandler.LOG_LOCAL7, } - facility = facility_map.get(self.config.syslog_facility, - logging.handlers.SysLogHandler.LOG_LOCAL0) - + facility = facility_map.get( + self.config.syslog_facility, logging.handlers.SysLogHandler.LOG_LOCAL0 + ) + try: # Try Unix socket first if os.path.exists(self.config.syslog_address): handler = logging.handlers.SysLogHandler( - address=self.config.syslog_address, - facility=facility + address=self.config.syslog_address, facility=facility ) else: # Fall back to localhost:514 handler = logging.handlers.SysLogHandler( - address=('localhost', 514), - facility=facility + address=("localhost", 514), facility=facility ) - - formatter = logging.Formatter( - 'beacon-detect[%(process)d]: %(message)s' - ) + + formatter = logging.Formatter("beacon-detect[%(process)d]: %(message)s") handler.setFormatter(formatter) self._logger.addHandler(handler) - + logger.info("Syslog handler initialized") except Exception as e: logger.error(f"Failed to initialize syslog: {e}") self._logger = None - + def send(self, alert): if not self._logger: return - + try: message = alert.to_syslog_message() self._logger.log(alert.severity.syslog_priority, message) @@ -190,52 +191,51 @@ def send(self, alert): class FileHandler: - - def __init__(self, config ): + def __init__(self, config): self.config = config self._file_path: Path = None self._file_handler = None self._logger = None - + if config.file_enabled: self._setup_file() - + def _setup_file(self): try: self._file_path = Path(self.config.file_path) - + # Create directory if needed self._file_path.parent.mkdir(parents=True, exist_ok=True) - + # Set up rotating file handler - self._logger = logging.getLogger('beacon_detect.alerts.file') + self._logger = logging.getLogger("beacon_detect.alerts.file") self._logger.setLevel(logging.DEBUG) self._logger.handlers = [] - + max_bytes = self.config.file_max_size_mb * 1024 * 1024 - + self._file_handler = logging.handlers.RotatingFileHandler( self._file_path, maxBytes=max_bytes, - backupCount=self.config.file_backup_count + backupCount=self.config.file_backup_count, ) - + # Use a simple formatter that just writes the message - self._file_handler.setFormatter(logging.Formatter('%(message)s')) + self._file_handler.setFormatter(logging.Formatter("%(message)s")) self._logger.addHandler(self._file_handler) - + logger.info(f"File alert handler initialized: {self._file_path}") except Exception as e: logger.error(f"Failed to initialize file handler: {e}") self._logger = None - + def send(self, alert: Alert): if not self._logger: return - + try: # Write JSON on single line json_str = json.dumps(alert.to_dict()) @@ -246,50 +246,43 @@ def send(self, alert: Alert): class WebhookHandler: - def __init__(self, config: AlertingConfig): self.config = config self._session = None - + if config.webhook_enabled and config.webhook_url: self._setup_session() - + def _setup_session(self): self._session = requests.Session() - + # Set default headers - headers = { - 'Content-Type': 'application/json', - 'User-Agent': 'BeaconDetect/1.0' - } + headers = {"Content-Type": "application/json", "User-Agent": "BeaconDetect/1.0"} headers.update(self.config.webhook_headers) self._session.headers.update(headers) - + logger.info(f"Webhook handler initialized: {self.config.webhook_url}") - + @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=10) + stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10) ) def _send_request(self, payload: str): if not self._session: return - + response = self._session.post( - self.config.webhook_url, - data=payload, - timeout=self.config.webhook_timeout + self.config.webhook_url, data=payload, timeout=self.config.webhook_timeout ) response.raise_for_status() - + def send(self, alert: Alert): if not self._session or not self.config.webhook_url: return - + try: payload = alert.to_json() self._send_request(payload) @@ -299,51 +292,50 @@ def send(self, alert: Alert): class AlertManager: - - def __init__(self, config = None): + + def __init__(self, config=None): self.config = config or AlertingConfig() - + # Initialize handlers self._syslog = SyslogHandler(self.config) self._file = FileHandler(self.config) self._webhook = WebhookHandler(self.config) - + # Alert queue for async processing self._alert_queue: queue.Queue = queue.Queue(maxsize=1000) - + # Processing thread self._running = False self._processor_thread = None - + # Statistics self._alerts_sent = 0 self._alerts_failed = 0 self._alerts_by_severity: Dict[str, int] = {} - + # Alert history (for deduplication and review) self._recent_alerts: List[Alert] = [] self._max_recent_alerts = 1000 - + logger.info("AlertManager initialized") - + def start(self): if self._running: return - + self._running = True self._processor_thread = threading.Thread( - target=self._process_alerts, - daemon=True + target=self._process_alerts, daemon=True ) self._processor_thread.start() logger.info("Alert processing started") - + def stop(self): self._running = False - + # Process remaining alerts while not self._alert_queue.empty(): try: @@ -351,12 +343,12 @@ def stop(self): self._deliver_alert(alert) except queue.Empty: break - + if self._processor_thread: self._processor_thread.join(timeout=5) - + logger.info("Alert processing stopped") - + def _process_alerts(self): while self._running: @@ -369,11 +361,11 @@ def _process_alerts(self): continue except Exception as e: logger.error(f"Alert processing error: {e}") - + def _deliver_alert(self, alert: Alert): success = True - + # Syslog if self.config.syslog_enabled: try: @@ -381,7 +373,7 @@ def _deliver_alert(self, alert: Alert): except Exception as e: logger.error(f"Syslog delivery failed: {e}") success = False - + # File if self.config.file_enabled: try: @@ -389,7 +381,7 @@ def _deliver_alert(self, alert: Alert): except Exception as e: logger.error(f"File delivery failed: {e}") success = False - + # Webhook if self.config.webhook_enabled: try: @@ -397,51 +389,51 @@ def _deliver_alert(self, alert: Alert): except Exception as e: logger.error(f"Webhook delivery failed: {e}") success = False - + # Update statistics if success: self._alerts_sent += 1 else: self._alerts_failed += 1 - + severity_key = alert.severity.value self._alerts_by_severity[severity_key] = ( self._alerts_by_severity.get(severity_key, 0) + 1 ) - + # Store in recent alerts self._recent_alerts.append(alert) if len(self._recent_alerts) > self._max_recent_alerts: - self._recent_alerts = self._recent_alerts[-self._max_recent_alerts:] - + self._recent_alerts = self._recent_alerts[-self._max_recent_alerts :] + def send_alert(self, alert: Alert): if not self.config.enabled: logger.debug("Alerting disabled, discarding alert") return - + try: self._alert_queue.put_nowait(alert) logger.debug(f"Alert queued: {alert.alert_id}") except queue.Full: logger.error("Alert queue full, discarding alert") self._alerts_failed += 1 - + def send_alert_sync(self, alert): if not self.config.enabled: return - + self._deliver_alert(alert) - + def create_and_send( self, title: str, description: str, severity: AlertSeverity, - source = "beacon_detector", - details = None, - tags = None + source="beacon_detector", + details=None, + tags=None, ): alert = Alert( alert_id=f"alert-{int(time.time())}-{self._alerts_sent}", @@ -450,37 +442,33 @@ def create_and_send( severity=severity, source=source, details=details or {}, - tags=tags or [] + tags=tags or [], ) - + self.send_alert(alert) return alert - - def get_recent_alerts( - self, - limit: int = 50, - severity = None - ): + + def get_recent_alerts(self, limit: int = 50, severity=None): alerts = self._recent_alerts[-limit:] - + if severity: alerts = [a for a in alerts if a.severity == severity] - + return [a.to_dict() for a in reversed(alerts)] - + @property def statistics(self): return { - 'enabled': self.config.enabled, - 'alerts_sent': self._alerts_sent, - 'alerts_failed': self._alerts_failed, - 'alerts_by_severity': self._alerts_by_severity.copy(), - 'queue_size': self._alert_queue.qsize(), - 'recent_alerts_count': len(self._recent_alerts), - 'channels': { - 'syslog': self.config.syslog_enabled, - 'file': self.config.file_enabled, - 'webhook': self.config.webhook_enabled - } + "enabled": self.config.enabled, + "alerts_sent": self._alerts_sent, + "alerts_failed": self._alerts_failed, + "alerts_by_severity": self._alerts_by_severity.copy(), + "queue_size": self._alert_queue.qsize(), + "recent_alerts_count": len(self._recent_alerts), + "channels": { + "syslog": self.config.syslog_enabled, + "file": self.config.file_enabled, + "webhook": self.config.webhook_enabled, + }, } diff --git a/control_plane/analyzer.py b/control_plane/analyzer.py index 29e2e1a..2adc5c8 100644 --- a/control_plane/analyzer.py +++ b/control_plane/analyzer.py @@ -5,34 +5,33 @@ from datetime import datetime, timezone from typing import Dict, List -from .storage import ConnectionStorage +from .alerter import Alert, AlertManager, AlertSeverity from .detector import BeaconDetector, DetectionResult -from .alerter import AlertManager, Alert, AlertSeverity +from .storage import ConnectionStorage -logger = logging.getLogger('beacon_detect.control_plane.analyzer') +logger = logging.getLogger("beacon_detect.control_plane.analyzer") @dataclass class AnalyzerConfig: # How often to run analysis (seconds) analysis_interval: int = 60 - + # Minimum connections required for analysis min_connections: int = 10 - + # Minimum duration for a pair to be analyzed min_duration: float = 300.0 # 5 minutes - + # Alert cooldown - don't re-alert same pair within this time alert_cooldown: int = 300 # 5 minutes - + # Maximum pairs to analyze per run (for performance) max_pairs_per_run: int = 10000 class AnalysisRun: - def __init__(self, run_id: str): self.run_id = run_id self.start_time = datetime.now(timezone.utc) @@ -42,43 +41,42 @@ def __init__(self, run_id: str): self.alerts_generated = 0 self.errors = 0 self.results: List[DetectionResult] = [] - + def complete(self): self.end_time = datetime.now(timezone.utc) - + @property def duration_seconds(self): end = self.end_time or datetime.now(timezone.utc) return (end - self.start_time).total_seconds() - + def to_dict(self): return { - 'run_id': str(self.run_id), - 'start_time': self.start_time.isoformat() + 'Z', - 'end_time': self.end_time.isoformat() + 'Z' if self.end_time else None, - 'duration_seconds': float(round(self.duration_seconds, 2)), - 'pairs_analyzed': int(self.pairs_analyzed), - 'beacons_detected': int(self.beacons_detected), - 'alerts_generated': int(self.alerts_generated), - 'errors': int(self.errors) + "run_id": str(self.run_id), + "start_time": self.start_time.isoformat() + "Z", + "end_time": self.end_time.isoformat() + "Z" if self.end_time else None, + "duration_seconds": float(round(self.duration_seconds, 2)), + "pairs_analyzed": int(self.pairs_analyzed), + "beacons_detected": int(self.beacons_detected), + "alerts_generated": int(self.alerts_generated), + "errors": int(self.errors), } class ConnectionAnalyzer: - def __init__( self, storage: ConnectionStorage, detector: BeaconDetector, alert_manager: AlertManager, - config = None + config=None, ): """ Initialize the analyzer. - + Args: storage: ConnectionStorage instance detector: BeaconDetector instance @@ -89,90 +87,91 @@ def __init__( self.detector = detector self.alert_manager = alert_manager self.config = config or AnalyzerConfig() - + # Thread safety lock for shared state self._lock = threading.RLock() - + # Track alert cooldowns: pair_key -> last_alert_time self._alert_cooldowns: Dict[str, float] = {} - + # Track known beacons for monitoring self._known_beacons: Dict[str, DetectionResult] = {} - + # Analysis run history self._run_history: List[AnalysisRun] = [] self._max_run_history = 100 - + # Running state self._running = False self._analysis_thread = None self._stop_event = threading.Event() - + # Statistics self._total_runs = 0 self._total_beacons_detected = 0 self._total_alerts_generated = 0 - + # Run counter for IDs self._run_counter = 0 - - logger.info(f"ConnectionAnalyzer initialized: interval={self.config.analysis_interval}s") - + + logger.info( + f"ConnectionAnalyzer initialized: interval={self.config.analysis_interval}s" + ) + def start(self): if self._running: logger.warning("Analyzer already running") return - + self._running = True self._stop_event.clear() - + self._analysis_thread = threading.Thread( - target=self._analysis_loop, - daemon=True + target=self._analysis_loop, daemon=True ) self._analysis_thread.start() - + logger.info("ConnectionAnalyzer started") - + def stop(self): self._running = False self._stop_event.set() - + if self._analysis_thread: self._analysis_thread.join(timeout=10) - + logger.info("ConnectionAnalyzer stopped") - + def _analysis_loop(self): logger.info("Analysis loop started") - + while not self._stop_event.wait(timeout=self.config.analysis_interval): try: self.run_analysis() except Exception as e: logger.error(f"Analysis run failed: {e}", exc_info=True) - + logger.info("Analysis loop stopped") - + def run_analysis(self): self._run_counter += 1 run_id = f"run-{self._run_counter}-{int(time.time())}" run = AnalysisRun(run_id) - + logger.info(f"Starting analysis run: {run_id}") - + try: - # logger.info(f"self.config.min_connections {self.config.min_connections} and self.config.min_duration {self.config.min_duration}") + # logger.info(f"self.config.min_connections {self.config.min_connections} and self.config.min_duration {self.config.min_duration}") # Get analyzable pairs from storage pairs = self.storage.get_analyzable_pairs( min_connections=self.config.min_connections, - min_duration=self.config.min_duration + min_duration=self.config.min_duration, ) - # logger.info(f"pairs{pairs}") + # logger.info(f"pairs{pairs}") # Limit pairs for performance if len(pairs) > self.config.max_pairs_per_run: logger.warning( @@ -181,21 +180,21 @@ def run_analysis(self): ) # Prioritize pairs with more connections pairs.sort(key=lambda p: p.connection_count, reverse=True) - pairs = pairs[:self.config.max_pairs_per_run] - + pairs = pairs[: self.config.max_pairs_per_run] + run.pairs_analyzed = len(pairs) logger.info(f"Analyzing {len(pairs)} connection pairs") - + # Run detection results = self.detector.batch_analyze(pairs) run.results = results - + # Process results beacons = [r for r in results if r.is_beacon] run.beacons_detected = len(beacons) - + logger.info(f"Detection complete: {len(beacons)} beacons found") - + # Generate alerts for new/updated beacons for result in beacons: try: @@ -204,86 +203,89 @@ def run_analysis(self): run.alerts_generated += 1 with self._lock: self._alert_cooldowns[result.pair_key] = time.time() - + # Update known beacons with self._lock: self._known_beacons[result.pair_key] = result - + except Exception as e: logger.error(f"Error generating alert for {result.pair_key}: {e}") run.errors += 1 - + # Clean up cooldowns for pairs no longer seen as beacons with self._lock: current_beacon_keys = {r.pair_key for r in beacons} - stale_keys = [k for k in self._known_beacons if k not in current_beacon_keys] + stale_keys = [ + k for k in self._known_beacons if k not in current_beacon_keys + ] for key in stale_keys: del self._known_beacons[key] - + except Exception as e: logger.error(f"Analysis run error: {e}", exc_info=True) run.errors += 1 - + run.complete() - + # Update statistics self._total_runs += 1 self._total_beacons_detected += run.beacons_detected self._total_alerts_generated += run.alerts_generated - + # Store run history self._run_history.append(run) if len(self._run_history) > self._max_run_history: - self._run_history = self._run_history[-self._max_run_history:] - + self._run_history = self._run_history[-self._max_run_history :] + logger.info( f"Analysis run complete: {run_id} - " f"{run.pairs_analyzed} pairs, {run.beacons_detected} beacons, " f"{run.alerts_generated} alerts, {run.duration_seconds:.2f}s" ) - + return run - + def _should_alert(self, result): pair_key = result.pair_key - + with self._lock: # Check cooldown last_alert = self._alert_cooldowns.get(pair_key, 0) if time.time() - last_alert < self.config.alert_cooldown: logger.debug(f"Skipping alert for {pair_key}: cooldown active") return False - + # Check if this is a new detection or significant change previous = self._known_beacons.get(pair_key) if previous is None: # New beacon return True - + # Check for significant score increase if result.combined_score - previous.combined_score > 0.1: return True - + # Check for confidence upgrade - conf_order = ['none', 'low', 'medium', 'high', 'critical'] - if (conf_order.index(result.confidence.value) > - conf_order.index(previous.confidence.value)): + conf_order = ["none", "low", "medium", "high", "critical"] + if conf_order.index(result.confidence.value) > conf_order.index( + previous.confidence.value + ): return True - + return False - + def _generate_alert(self, result): # Map confidence to severity severity_map = { - 'none': AlertSeverity.INFO, - 'low': AlertSeverity.LOW, - 'medium': AlertSeverity.MEDIUM, - 'high': AlertSeverity.HIGH, - 'critical': AlertSeverity.CRITICAL + "none": AlertSeverity.INFO, + "low": AlertSeverity.LOW, + "medium": AlertSeverity.MEDIUM, + "high": AlertSeverity.HIGH, + "critical": AlertSeverity.CRITICAL, } severity = severity_map.get(result.confidence.value, AlertSeverity.MEDIUM) - + # Create alert alert = Alert( alert_id=f"beacon-{result.pair_key}-{int(time.time())}", @@ -298,35 +300,35 @@ def _generate_alert(self, result): ), severity=severity, source="beacon_detector", - details=result.to_dict() + details=result.to_dict(), ) - + # Send alert through alert manager self.alert_manager.send_alert(alert) - + logger.warning( f"Alert generated: {alert.alert_id} - " f"{result.pair_key} (score={result.combined_score:.3f})" ) - + def get_known_beacons(self): """Get list of known beacons (thread-safe).""" with self._lock: return list(self._known_beacons.values()) - + def get_run_history(self, limit: int = 10): runs = self._run_history[-limit:] return [r.to_dict() for r in reversed(runs)] - + @property def statistics(self): """Get analyzer statistics""" return { - 'running': self._running, - 'analysis_interval': self.config.analysis_interval, - 'total_runs': self._total_runs, - 'total_beacons_detected': self._total_beacons_detected, - 'total_alerts_generated': self._total_alerts_generated, - 'current_known_beacons': len(self._known_beacons), - 'active_cooldowns': len(self._alert_cooldowns) + "running": self._running, + "analysis_interval": self.config.analysis_interval, + "total_runs": self._total_runs, + "total_beacons_detected": self._total_beacons_detected, + "total_alerts_generated": self._total_alerts_generated, + "current_known_beacons": len(self._known_beacons), + "active_cooldowns": len(self._alert_cooldowns), } diff --git a/control_plane/cli.py b/control_plane/cli.py index c03f78f..0e76262 100644 --- a/control_plane/cli.py +++ b/control_plane/cli.py @@ -22,48 +22,49 @@ import sys import time from datetime import datetime -from urllib.request import urlopen, Request -from urllib.error import URLError, HTTPError +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + # ANSI color codes class Colors: - RESET = '\033[0m' - BOLD = '\033[1m' - DIM = '\033[2m' - + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + # Foreground colors - BLACK = '\033[30m' - RED = '\033[31m' - GREEN = '\033[32m' - YELLOW = '\033[33m' - BLUE = '\033[34m' - MAGENTA = '\033[35m' - CYAN = '\033[36m' - WHITE = '\033[37m' - + BLACK = "\033[30m" + RED = "\033[31m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + BLUE = "\033[34m" + MAGENTA = "\033[35m" + CYAN = "\033[36m" + WHITE = "\033[37m" + # Bright foreground - BRIGHT_RED = '\033[91m' - BRIGHT_GREEN = '\033[92m' - BRIGHT_YELLOW = '\033[93m' - BRIGHT_BLUE = '\033[94m' - BRIGHT_MAGENTA = '\033[95m' - BRIGHT_CYAN = '\033[96m' - BRIGHT_WHITE = '\033[97m' - + BRIGHT_RED = "\033[91m" + BRIGHT_GREEN = "\033[92m" + BRIGHT_YELLOW = "\033[93m" + BRIGHT_BLUE = "\033[94m" + BRIGHT_MAGENTA = "\033[95m" + BRIGHT_CYAN = "\033[96m" + BRIGHT_WHITE = "\033[97m" + # Background colors - BG_RED = '\033[41m' - BG_GREEN = '\033[42m' - BG_YELLOW = '\033[43m' - BG_BLUE = '\033[44m' + BG_RED = "\033[41m" + BG_GREEN = "\033[42m" + BG_YELLOW = "\033[43m" + BG_BLUE = "\033[44m" def supports_color(): - if os.getenv('NO_COLOR'): + if os.getenv("NO_COLOR"): return False - if os.getenv('FORCE_COLOR'): + if os.getenv("FORCE_COLOR"): return True - return hasattr(sys.stdout, 'isatty') and sys.stdout.isatty() + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() USE_COLOR = supports_color() @@ -73,12 +74,12 @@ def c(text: str, *colors): if not USE_COLOR or not colors: return text - return ''.join(colors) + str(text) + Colors.RESET + return "".join(colors) + str(text) + Colors.RESET def clear_screen(): - os.system('cls' if os.name == 'nt' else 'clear') + os.system("cls" if os.name == "nt" else "clear") def get_terminal_width(): @@ -90,7 +91,7 @@ def get_terminal_width(): class BeaconCLI: - + BANNER = r""" ____ ____ __ __ / __ )___ ____ __________ ____ / __ \___ / /____ _____/ /_ @@ -99,29 +100,29 @@ class BeaconCLI: /_____/\___/\__,_/\___/\____/_/ /_//_____/\___/\__/\___/\___/\__/ """ - - def __init__(self, host: str = 'localhost', port: int = 9090): + + def __init__(self, host: str = "localhost", port: int = 9090): self.host = host self.port = port - self.base_url = f'http://{host}:{port}' - + self.base_url = f"http://{host}:{port}" + def _api_request(self, endpoint: str): try: - url = f'{self.base_url}{endpoint}' - req = Request(url, headers={'Accept': 'application/json'}) + url = f"{self.base_url}{endpoint}" + req = Request(url, headers={"Accept": "application/json"}) with urlopen(req, timeout=5) as response: return json.loads(response.read().decode()) except (URLError, HTTPError, json.JSONDecodeError) as e: return None - + def print_banner(self): print(c(self.BANNER, Colors.CYAN)) print(c(" Real-time Network Beacon Detection System", Colors.DIM)) print(c(f" Server: {self.base_url}", Colors.DIM)) print() - + def print_header(self, title: str): width = get_terminal_width() @@ -130,9 +131,8 @@ def print_header(self, title: str): print(c(f" {title}", Colors.BOLD, Colors.BRIGHT_WHITE)) print(c("═" * width, Colors.BLUE)) print() - - def print_table(self, headers, rows, - col_widths = None): + + def print_table(self, headers, rows, col_widths=None): if not col_widths: # Calculate column widths @@ -141,19 +141,21 @@ def print_table(self, headers, rows, for i, cell in enumerate(row): if i < len(col_widths): col_widths[i] = max(col_widths[i], len(str(cell))) - + # Print header header_line = " " for i, h in enumerate(headers): - header_line += c(h.ljust(col_widths[i] + 2), Colors.BOLD, Colors.BRIGHT_CYAN) + header_line += c( + h.ljust(col_widths[i] + 2), Colors.BOLD, Colors.BRIGHT_CYAN + ) print(header_line) - + # Print separator sep_line = " " for w in col_widths: sep_line += c("-" * w + " ", Colors.DIM) print(sep_line) - + # Print rows for row in rows: row_line = " " @@ -161,7 +163,7 @@ def print_table(self, headers, rows, if i < len(col_widths): row_line += str(cell).ljust(col_widths[i] + 2) print(row_line) - + def format_score(self, score: float): pct = score * 100 @@ -173,7 +175,7 @@ def format_score(self, score: float): return c(f"{pct:5.1f}%", Colors.YELLOW) else: return c(f"{pct:5.1f}%", Colors.GREEN) - + def format_severity(self, score: float): if score >= 0.9: @@ -184,7 +186,7 @@ def format_severity(self, score: float): return c("MEDIUM ", Colors.YELLOW) else: return c("LOW ", Colors.GREEN) - + def format_duration(self, seconds: float): if seconds < 60: @@ -195,7 +197,7 @@ def format_duration(self, seconds: float): return f"{seconds/3600:.1f}h" else: return f"{seconds/86400:.1f}d" - + def format_count(self, count: int): if count >= 10000: @@ -204,348 +206,472 @@ def format_count(self, count: int): return c(f"{count:,}", Colors.YELLOW) else: return f"{count:,}" - + def cmd_status(self): self.print_banner() self.print_header("SYSTEM STATUS") - + # Check health - health = self._api_request('/api/v1/health') - status = self._api_request('/api/v1/status') - + health = self._api_request("/api/v1/health") + status = self._api_request("/api/v1/status") + if not health or not status: print(c(" ✗ OFFLINE", Colors.BRIGHT_RED, Colors.BOLD)) print(c(f" Cannot connect to server at {self.base_url}", Colors.DIM)) print() return False - + # System is live print(c(" ✓ LIVE", Colors.BRIGHT_GREEN, Colors.BOLD)) print() - + # Print stats - uptime = status.get('uptime_seconds', 0) - uptime_str = self.format_duration(uptime) if uptime else 'N/A' - + uptime = status.get("uptime_seconds", 0) + uptime_str = self.format_duration(uptime) if uptime else "N/A" + stats = [ ("Server", f"{self.host}:{self.port}"), ("Status", c("Running", Colors.GREEN)), ("Uptime", uptime_str), ("Events Received", f"{status.get('events_received', 0):,}"), - ("Connection Pairs", f"{status.get('storage', {}).get('pairs_count', 0):,}"), - ("Beacons Detected", f"{status.get('analyzer', {}).get('beacons_detected', 0):,}"), - ("Alerts Generated", f"{status.get('alerter', {}).get('alerts_sent', 0):,}"), - ("Analysis Runs", f"{status.get('analyzer', {}).get('analysis_runs', 0):,}"), + ( + "Connection Pairs", + f"{status.get('storage', {}).get('pairs_count', 0):,}", + ), + ( + "Beacons Detected", + f"{status.get('analyzer', {}).get('beacons_detected', 0):,}", + ), + ( + "Alerts Generated", + f"{status.get('alerter', {}).get('alerts_sent', 0):,}", + ), + ( + "Analysis Runs", + f"{status.get('analyzer', {}).get('analysis_runs', 0):,}", + ), ] - + for label, value in stats: print(f" {c(label + ':', Colors.DIM):30} {value}") - + print() return True - - def cmd_beacons(self, min_score: float = 0.0, limit: int = 50, csv_output: bool = False): + + def cmd_beacons( + self, min_score: float = 0.0, limit: int = 50, csv_output: bool = False + ): if not csv_output: self.print_banner() self.print_header("DETECTED BEACONS") - - data = self._api_request('/api/v1/beacons') + + data = self._api_request("/api/v1/beacons") if not data: if not csv_output: print(c(" ✗ Cannot connect to server", Colors.RED)) return - - beacons = data.get('beacons', []) - + + beacons = data.get("beacons", []) + # Filter by score - beacons = [b for b in beacons if b.get('combined_score', 0) >= min_score] - + beacons = [b for b in beacons if b.get("combined_score", 0) >= min_score] + # Sort by score descending - beacons.sort(key=lambda x: x.get('combined_score', 0), reverse=True) - + beacons.sort(key=lambda x: x.get("combined_score", 0), reverse=True) + # Limit results beacons = beacons[:limit] - + if not beacons: if not csv_output: print(c(" No beacons detected", Colors.DIM)) print() return - + if csv_output: # CSV output - print("Score,Severity,Source_IP,Source_Port,Dest_IP,Dest_Port,Protocol,Connections,Interval,Jitter") + print( + "Score,Severity,Source_IP,Source_Port,Dest_IP,Dest_Port,Protocol,Connections,Interval,Jitter" + ) for b in beacons: - score = b.get('combined_score', 0) - severity = "CRITICAL" if score >= 0.9 else "HIGH" if score >= 0.8 else "MEDIUM" if score >= 0.7 else "LOW" - interval = b.get('interval_stats', {}).get('mean', 0) - jitter = b.get('interval_stats', {}).get('jitter', 0) - print(f"{score:.4f},{severity},{b.get('src_ip', '-')},-,{b.get('dst_ip', '-')},{b.get('dst_port', '-')},{b.get('protocol', 'TCP')},{b.get('connection_count', 0)},{interval:.2f},{jitter:.2f}") + score = b.get("combined_score", 0) + severity = ( + "CRITICAL" + if score >= 0.9 + else "HIGH" if score >= 0.8 else "MEDIUM" if score >= 0.7 else "LOW" + ) + interval = b.get("interval_stats", {}).get("mean", 0) + jitter = b.get("interval_stats", {}).get("jitter", 0) + print( + f"{score:.4f},{severity},{b.get('src_ip', '-')},-,{b.get('dst_ip', '-')},{b.get('dst_port', '-')},{b.get('protocol', 'TCP')},{b.get('connection_count', 0)},{interval:.2f},{jitter:.2f}" + ) else: # Table output - print(f" Found {c(str(len(beacons)), Colors.BOLD)} beacon(s) with score >= {min_score:.1%}") + print( + f" Found {c(str(len(beacons)), Colors.BOLD)} beacon(s) with score >= {min_score:.1%}" + ) print() - - headers = ["SCORE", "SEVERITY", "SOURCE IP", "DEST IP", "PORT", "PROTO", "CONNS", "INTERVAL", "JITTER"] + + headers = [ + "SCORE", + "SEVERITY", + "SOURCE IP", + "DEST IP", + "PORT", + "PROTO", + "CONNS", + "INTERVAL", + "JITTER", + ] rows = [] - + for b in beacons: - score = b.get('combined_score', 0) - interval = b.get('interval_stats', {}).get('mean', 0) - jitter = b.get('interval_stats', {}).get('jitter', 0) - - rows.append([ - self.format_score(score), - self.format_severity(score), - b.get('src_ip', '-'), - b.get('dst_ip', '-'), - str(b.get('dst_port', '-')), - b.get('protocol', 'TCP'), - self.format_count(b.get('connection_count', 0)), - f"{interval:.1f}s", - f"{jitter:.2f}s" - ]) - + score = b.get("combined_score", 0) + interval = b.get("interval_stats", {}).get("mean", 0) + jitter = b.get("interval_stats", {}).get("jitter", 0) + + rows.append( + [ + self.format_score(score), + self.format_severity(score), + b.get("src_ip", "-"), + b.get("dst_ip", "-"), + str(b.get("dst_port", "-")), + b.get("protocol", "TCP"), + self.format_count(b.get("connection_count", 0)), + f"{interval:.1f}s", + f"{jitter:.2f}s", + ] + ) + self.print_table(headers, rows, [7, 10, 18, 18, 6, 5, 8, 10, 8]) print() - - def cmd_long_connections(self, min_duration = 3600, limit = 50, csv_output = False): + + def cmd_long_connections(self, min_duration=3600, limit=50, csv_output=False): if not csv_output: self.print_banner() self.print_header("LONG CONNECTIONS") - - data = self._api_request(f'/api/v1/connections?limit=500') + + data = self._api_request(f"/api/v1/connections?limit=500") if not data: if not csv_output: print(c(" ✗ Cannot connect to server", Colors.RED)) return - - pairs = data.get('pairs', []) - + + pairs = data.get("pairs", []) + # Filter by duration - pairs = [p for p in pairs if p.get('duration_seconds', 0) >= min_duration] - + pairs = [p for p in pairs if p.get("duration_seconds", 0) >= min_duration] + # Sort by duration descending - pairs.sort(key=lambda x: x.get('duration_seconds', 0), reverse=True) - + pairs.sort(key=lambda x: x.get("duration_seconds", 0), reverse=True) + # Limit results pairs = pairs[:limit] - + if not pairs: if not csv_output: - print(c(f" No connections longer than {self.format_duration(min_duration)}", Colors.DIM)) + print( + c( + f" No connections longer than {self.format_duration(min_duration)}", + Colors.DIM, + ) + ) print() return - + if csv_output: - print("Duration,Source_IP,Dest_IP,Dest_Port,Protocol,Connections,First_Seen,Last_Seen") + print( + "Duration,Source_IP,Dest_IP,Dest_Port,Protocol,Connections,First_Seen,Last_Seen" + ) for p in pairs: - duration = p.get('duration_seconds', 0) - print(f"{duration:.0f},{p.get('src_ip', '-')},{p.get('dst_ip', '-')},{p.get('dst_port', '-')},{p.get('protocol', 'TCP')},{p.get('connection_count', 0)},{p.get('first_seen', '-')},{p.get('last_seen', '-')}") + duration = p.get("duration_seconds", 0) + print( + f"{duration:.0f},{p.get('src_ip', '-')},{p.get('dst_ip', '-')},{p.get('dst_port', '-')},{p.get('protocol', 'TCP')},{p.get('connection_count', 0)},{p.get('first_seen', '-')},{p.get('last_seen', '-')}" + ) else: - print(f" Found {c(str(len(pairs)), Colors.BOLD)} connection(s) longer than {self.format_duration(min_duration)}") + print( + f" Found {c(str(len(pairs)), Colors.BOLD)} connection(s) longer than {self.format_duration(min_duration)}" + ) print() - - headers = ["DURATION", "SOURCE IP", "DEST IP", "PORT", "PROTO", "CONNS", "FIRST SEEN", "LAST SEEN"] + + headers = [ + "DURATION", + "SOURCE IP", + "DEST IP", + "PORT", + "PROTO", + "CONNS", + "FIRST SEEN", + "LAST SEEN", + ] rows = [] - + for p in pairs: - duration = p.get('duration_seconds', 0) - first_seen = p.get('first_seen', '-') - last_seen = p.get('last_seen', '-') - + duration = p.get("duration_seconds", 0) + first_seen = p.get("first_seen", "-") + last_seen = p.get("last_seen", "-") + # Format timestamps - if first_seen and first_seen != '-': + if first_seen and first_seen != "-": try: - dt = datetime.fromisoformat(first_seen.replace('Z', '+00:00')) - first_seen = dt.strftime('%Y-%m-%d %H:%M') + dt = datetime.fromisoformat(first_seen.replace("Z", "+00:00")) + first_seen = dt.strftime("%Y-%m-%d %H:%M") except: pass - - if last_seen and last_seen != '-': + + if last_seen and last_seen != "-": try: - dt = datetime.fromisoformat(last_seen.replace('Z', '+00:00')) - last_seen = dt.strftime('%Y-%m-%d %H:%M') + dt = datetime.fromisoformat(last_seen.replace("Z", "+00:00")) + last_seen = dt.strftime("%Y-%m-%d %H:%M") except: pass - + # Color code duration dur_str = self.format_duration(duration) if duration >= 86400: # > 1 day dur_str = c(dur_str, Colors.BRIGHT_RED, Colors.BOLD) elif duration >= 3600: # > 1 hour dur_str = c(dur_str, Colors.YELLOW) - - rows.append([ - dur_str, - p.get('src_ip', '-'), - p.get('dst_ip', '-'), - str(p.get('dst_port', '-')), - p.get('protocol', 'TCP'), - str(p.get('connection_count', 0)), - first_seen, - last_seen - ]) - + + rows.append( + [ + dur_str, + p.get("src_ip", "-"), + p.get("dst_ip", "-"), + str(p.get("dst_port", "-")), + p.get("protocol", "TCP"), + str(p.get("connection_count", 0)), + first_seen, + last_seen, + ] + ) + self.print_table(headers, rows, [10, 18, 18, 6, 5, 8, 18, 18]) print() - + def cmd_connections(self, limit: int = 100, csv_output: bool = False): if not csv_output: self.print_banner() self.print_header("TRACKED CONNECTIONS") - - data = self._api_request(f'/api/v1/connections?limit={limit}') + + data = self._api_request(f"/api/v1/connections?limit={limit}") if not data: if not csv_output: print(c(" ✗ Cannot connect to server", Colors.RED)) return - - pairs = data.get('pairs', []) - + + pairs = data.get("pairs", []) + # Sort by connection count descending - pairs.sort(key=lambda x: x.get('connection_count', 0), reverse=True) - + pairs.sort(key=lambda x: x.get("connection_count", 0), reverse=True) + if not pairs: if not csv_output: print(c(" No connections tracked", Colors.DIM)) print() return - + if csv_output: - print("Source_IP,Dest_IP,Dest_Port,Protocol,Connections,Duration,First_Seen,Last_Seen") + print( + "Source_IP,Dest_IP,Dest_Port,Protocol,Connections,Duration,First_Seen,Last_Seen" + ) for p in pairs: - print(f"{p.get('src_ip', '-')},{p.get('dst_ip', '-')},{p.get('dst_port', '-')},{p.get('protocol', 'TCP')},{p.get('connection_count', 0)},{p.get('duration_seconds', 0):.0f},{p.get('first_seen', '-')},{p.get('last_seen', '-')}") + print( + f"{p.get('src_ip', '-')},{p.get('dst_ip', '-')},{p.get('dst_port', '-')},{p.get('protocol', 'TCP')},{p.get('connection_count', 0)},{p.get('duration_seconds', 0):.0f},{p.get('first_seen', '-')},{p.get('last_seen', '-')}" + ) else: print(f" Showing {c(str(len(pairs)), Colors.BOLD)} connection pair(s)") print() - - headers = ["SOURCE IP", "DEST IP", "PORT", "PROTO", "CONNS", "DURATION", "LAST SEEN"] + + headers = [ + "SOURCE IP", + "DEST IP", + "PORT", + "PROTO", + "CONNS", + "DURATION", + "LAST SEEN", + ] rows = [] - + for p in pairs[:limit]: - duration = p.get('duration_seconds', 0) - last_seen = p.get('last_seen', '-') - - if last_seen and last_seen != '-': + duration = p.get("duration_seconds", 0) + last_seen = p.get("last_seen", "-") + + if last_seen and last_seen != "-": try: - dt = datetime.fromisoformat(last_seen.replace('Z', '+00:00')) - last_seen = dt.strftime('%m-%d %H:%M') + dt = datetime.fromisoformat(last_seen.replace("Z", "+00:00")) + last_seen = dt.strftime("%m-%d %H:%M") except: pass - - rows.append([ - p.get('src_ip', '-'), - p.get('dst_ip', '-'), - str(p.get('dst_port', '-')), - p.get('protocol', 'TCP'), - self.format_count(p.get('connection_count', 0)), - self.format_duration(duration), - last_seen - ]) - + + rows.append( + [ + p.get("src_ip", "-"), + p.get("dst_ip", "-"), + str(p.get("dst_port", "-")), + p.get("protocol", "TCP"), + self.format_count(p.get("connection_count", 0)), + self.format_duration(duration), + last_seen, + ] + ) + self.print_table(headers, rows, [18, 18, 6, 5, 8, 10, 14]) print() - + def cmd_watch(self, interval: int = 5): try: while True: clear_screen() self.print_banner() - + # Check if live - health = self._api_request('/api/v1/health') - status = self._api_request('/api/v1/status') - + health = self._api_request("/api/v1/health") + status = self._api_request("/api/v1/status") + if not health or not status: print(c(" ✗ SYSTEM OFFLINE", Colors.BRIGHT_RED, Colors.BOLD)) - print(c(f" Waiting for connection to {self.base_url}...", Colors.DIM)) + print( + c( + f" Waiting for connection to {self.base_url}...", + Colors.DIM, + ) + ) else: # Status line - uptime = status.get('uptime_seconds', 0) - events = status.get('events_received', 0) - pairs = status.get('storage', {}).get('pairs_count', 0) - beacons_count = status.get('analyzer', {}).get('beacons_detected', 0) - - print(f" {c('●', Colors.BRIGHT_GREEN)} {c('LIVE', Colors.BOLD, Colors.BRIGHT_GREEN)} │ " - f"Uptime: {self.format_duration(uptime)} │ " - f"Events: {events:,} │ " - f"Pairs: {pairs:,} │ " - f"Beacons: {c(str(beacons_count), Colors.YELLOW if beacons_count else Colors.GREEN)}") - + uptime = status.get("uptime_seconds", 0) + events = status.get("events_received", 0) + pairs = status.get("storage", {}).get("pairs_count", 0) + beacons_count = status.get("analyzer", {}).get( + "beacons_detected", 0 + ) + + print( + f" {c('●', Colors.BRIGHT_GREEN)} {c('LIVE', Colors.BOLD, Colors.BRIGHT_GREEN)} │ " + f"Uptime: {self.format_duration(uptime)} │ " + f"Events: {events:,} │ " + f"Pairs: {pairs:,} │ " + f"Beacons: {c(str(beacons_count), Colors.YELLOW if beacons_count else Colors.GREEN)}" + ) + # Show beacons self.print_header("DETECTED BEACONS (Score >= 70%)") - - data = self._api_request('/api/v1/beacons') + + data = self._api_request("/api/v1/beacons") if data: - beacons = [b for b in data.get('beacons', []) if b.get('combined_score', 0) >= 0.7] - beacons.sort(key=lambda x: x.get('combined_score', 0), reverse=True) - + beacons = [ + b + for b in data.get("beacons", []) + if b.get("combined_score", 0) >= 0.7 + ] + beacons.sort( + key=lambda x: x.get("combined_score", 0), reverse=True + ) + if beacons: - headers = ["SCORE", "SEVERITY", "SOURCE IP", "→", "DEST IP:PORT", "CONNS", "INTERVAL"] + headers = [ + "SCORE", + "SEVERITY", + "SOURCE IP", + "→", + "DEST IP:PORT", + "CONNS", + "INTERVAL", + ] rows = [] - + for b in beacons[:15]: - score = b.get('combined_score', 0) - interval_val = b.get('interval_stats', {}).get('mean', 0) - dest = f"{b.get('dst_ip', '-')}:{b.get('dst_port', '-')}" - - rows.append([ - self.format_score(score), - self.format_severity(score), - b.get('src_ip', '-'), - c("→", Colors.DIM), - dest, - str(b.get('connection_count', 0)), - f"{interval_val:.1f}s" - ]) - + score = b.get("combined_score", 0) + interval_val = b.get("interval_stats", {}).get( + "mean", 0 + ) + dest = ( + f"{b.get('dst_ip', '-')}:{b.get('dst_port', '-')}" + ) + + rows.append( + [ + self.format_score(score), + self.format_severity(score), + b.get("src_ip", "-"), + c("→", Colors.DIM), + dest, + str(b.get("connection_count", 0)), + f"{interval_val:.1f}s", + ] + ) + self.print_table(headers, rows, [7, 10, 18, 2, 25, 8, 10]) else: - print(c(" No high-confidence beacons detected", Colors.DIM)) - + print( + c(" No high-confidence beacons detected", Colors.DIM) + ) + # Show long connections self.print_header("LONG CONNECTIONS (> 1 hour)") - - data = self._api_request('/api/v1/connections?limit=200') + + data = self._api_request("/api/v1/connections?limit=200") if data: - long_conns = [p for p in data.get('pairs', []) if p.get('duration_seconds', 0) >= 3600] - long_conns.sort(key=lambda x: x.get('duration_seconds', 0), reverse=True) - + long_conns = [ + p + for p in data.get("pairs", []) + if p.get("duration_seconds", 0) >= 3600 + ] + long_conns.sort( + key=lambda x: x.get("duration_seconds", 0), reverse=True + ) + if long_conns: - headers = ["DURATION", "SOURCE IP", "→", "DEST IP:PORT", "CONNS"] + headers = [ + "DURATION", + "SOURCE IP", + "→", + "DEST IP:PORT", + "CONNS", + ] rows = [] - + for p in long_conns[:10]: - duration = p.get('duration_seconds', 0) + duration = p.get("duration_seconds", 0) dur_str = self.format_duration(duration) if duration >= 86400: dur_str = c(dur_str, Colors.BRIGHT_RED, Colors.BOLD) elif duration >= 3600: dur_str = c(dur_str, Colors.YELLOW) - - dest = f"{p.get('dst_ip', '-')}:{p.get('dst_port', '-')}" - - rows.append([ - dur_str, - p.get('src_ip', '-'), - c("→", Colors.DIM), - dest, - str(p.get('connection_count', 0)) - ]) - + + dest = ( + f"{p.get('dst_ip', '-')}:{p.get('dst_port', '-')}" + ) + + rows.append( + [ + dur_str, + p.get("src_ip", "-"), + c("→", Colors.DIM), + dest, + str(p.get("connection_count", 0)), + ] + ) + self.print_table(headers, rows, [10, 18, 2, 25, 8]) else: print(c(" No long connections detected", Colors.DIM)) - + print() - print(c(f" Refreshing every {interval}s... Press Ctrl+C to exit", Colors.DIM)) + print( + c( + f" Refreshing every {interval}s... Press Ctrl+C to exit", + Colors.DIM, + ) + ) time.sleep(interval) - + except KeyboardInterrupt: print() print(c(" Monitoring stopped", Colors.DIM)) @@ -554,9 +680,9 @@ def cmd_watch(self, interval: int = 5): def main(): parser = argparse.ArgumentParser( - description='Beacon Detection CLI', + description="Beacon Detection CLI", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=''' + epilog=""" Examples: %(prog)s status Show system status %(prog)s beacons Show all detected beacons @@ -566,75 +692,88 @@ def main(): %(prog)s long-conns --min-duration 7200 Show connections > 2 hours %(prog)s watch Live monitoring mode %(prog)s watch --interval 10 Refresh every 10 seconds - ''' + """, + ) + + parser.add_argument( + "--host", default="localhost", help="Control plane host (default: localhost)" + ) + parser.add_argument( + "--port", type=int, default=9090, help="Control plane port (default: 9090)" ) - - parser.add_argument('--host', default='localhost', - help='Control plane host (default: localhost)') - parser.add_argument('--port', type=int, default=9090, - help='Control plane port (default: 9090)') - parser.add_argument('--no-color', action='store_true', - help='Disable colored output') - - subparsers = parser.add_subparsers(dest='command', help='Command to run') - + parser.add_argument( + "--no-color", action="store_true", help="Disable colored output" + ) + + subparsers = parser.add_subparsers(dest="command", help="Command to run") + # Status command - subparsers.add_parser('status', help='Show system status') - + subparsers.add_parser("status", help="Show system status") + # Beacons command - beacons_parser = subparsers.add_parser('beacons', help='Show detected beacons') - beacons_parser.add_argument('--min-score', type=float, default=0.0, - help='Minimum beacon score (0.0-1.0)') - beacons_parser.add_argument('--limit', type=int, default=50, - help='Maximum results to show') - beacons_parser.add_argument('--csv', '-o', action='store_true', - help='Output as CSV') - + beacons_parser = subparsers.add_parser("beacons", help="Show detected beacons") + beacons_parser.add_argument( + "--min-score", type=float, default=0.0, help="Minimum beacon score (0.0-1.0)" + ) + beacons_parser.add_argument( + "--limit", type=int, default=50, help="Maximum results to show" + ) + beacons_parser.add_argument( + "--csv", "-o", action="store_true", help="Output as CSV" + ) + # Long connections command - long_parser = subparsers.add_parser('long-conns', help='Show long connections') - long_parser.add_argument('--min-duration', type=int, default=3600, - help='Minimum duration in seconds (default: 3600)') - long_parser.add_argument('--limit', type=int, default=50, - help='Maximum results to show') - long_parser.add_argument('--csv', '-o', action='store_true', - help='Output as CSV') - + long_parser = subparsers.add_parser("long-conns", help="Show long connections") + long_parser.add_argument( + "--min-duration", + type=int, + default=3600, + help="Minimum duration in seconds (default: 3600)", + ) + long_parser.add_argument( + "--limit", type=int, default=50, help="Maximum results to show" + ) + long_parser.add_argument("--csv", "-o", action="store_true", help="Output as CSV") + # Connections command - conns_parser = subparsers.add_parser('connections', help='Show all connections') - conns_parser.add_argument('--limit', type=int, default=100, - help='Maximum results to show') - conns_parser.add_argument('--csv', '-o', action='store_true', - help='Output as CSV') - + conns_parser = subparsers.add_parser("connections", help="Show all connections") + conns_parser.add_argument( + "--limit", type=int, default=100, help="Maximum results to show" + ) + conns_parser.add_argument("--csv", "-o", action="store_true", help="Output as CSV") + # Watch command - watch_parser = subparsers.add_parser('watch', help='Live monitoring mode') - watch_parser.add_argument('--interval', type=int, default=5, - help='Refresh interval in seconds') - + watch_parser = subparsers.add_parser("watch", help="Live monitoring mode") + watch_parser.add_argument( + "--interval", type=int, default=5, help="Refresh interval in seconds" + ) + args = parser.parse_args() - + # Handle no-color flag global USE_COLOR if args.no_color: USE_COLOR = False - + # Create CLI instance cli = BeaconCLI(host=args.host, port=args.port) - + # Execute command - if args.command == 'status' or args.command is None: + if args.command == "status" or args.command is None: cli.cmd_status() - elif args.command == 'beacons': + elif args.command == "beacons": cli.cmd_beacons(min_score=args.min_score, limit=args.limit, csv_output=args.csv) - elif args.command == 'long-conns': - cli.cmd_long_connections(min_duration=args.min_duration, limit=args.limit, csv_output=args.csv) - elif args.command == 'connections': + elif args.command == "long-conns": + cli.cmd_long_connections( + min_duration=args.min_duration, limit=args.limit, csv_output=args.csv + ) + elif args.command == "connections": cli.cmd_connections(limit=args.limit, csv_output=args.csv) - elif args.command == 'watch': + elif args.command == "watch": cli.cmd_watch(interval=args.interval) else: parser.print_help() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/control_plane/detector.py b/control_plane/detector.py index c5e444b..92139c5 100644 --- a/control_plane/detector.py +++ b/control_plane/detector.py @@ -10,7 +10,7 @@ from .storage import ConnectionPair -logger = logging.getLogger('beacon_detect.control_plane.detector') +logger = logging.getLogger("beacon_detect.control_plane.detector") class BeaconConfidence(Enum): @@ -33,17 +33,17 @@ class IntervalStats: min_interval: float max_interval: float jitter: float # Max deviation from median - + def to_dict(self): return { - 'count': int(self.count), - 'mean': float(round(self.mean, 3)), - 'std_dev': float(round(self.std_dev, 3)), - 'cv': float(round(self.cv, 4)), - 'median': float(round(self.median, 3)), - 'min_interval': float(round(self.min_interval, 3)), - 'max_interval': float(round(self.max_interval, 3)), - 'jitter': float(round(self.jitter, 3)) + "count": int(self.count), + "mean": float(round(self.mean, 3)), + "std_dev": float(round(self.std_dev, 3)), + "cv": float(round(self.cv, 4)), + "median": float(round(self.median, 3)), + "min_interval": float(round(self.min_interval, 3)), + "max_interval": float(round(self.max_interval, 3)), + "jitter": float(round(self.jitter, 3)), } @@ -54,15 +54,18 @@ class PeriodicityResult: dominant_period: float # In seconds periodicity_score: float # 0.0 to 1.0 frequency_peaks: List[Tuple[float, float]] # (frequency, magnitude) pairs - + def to_dict(self): return { - 'is_periodic': bool(self.is_periodic), - 'dominant_period': float(round(self.dominant_period, 3)) if self.dominant_period else None, - 'periodicity_score': float(round(self.periodicity_score, 4)), - 'frequency_peaks': [ - (float(round(f, 6)), float(round(m, 4))) for f, m in self.frequency_peaks - ] + "is_periodic": bool(self.is_periodic), + "dominant_period": ( + float(round(self.dominant_period, 3)) if self.dominant_period else None + ), + "periodicity_score": float(round(self.periodicity_score, 4)), + "frequency_peaks": [ + (float(round(f, 6)), float(round(m, 4))) + for f, m in self.frequency_peaks + ], } @@ -74,48 +77,52 @@ class DetectionResult: dst_ip: str dst_port: int protocol: str - + # Detection scores (0.0 to 1.0, higher = more beacon-like) cv_score: float periodicity_score: float jitter_score: float combined_score: float - + # Detection outcome is_beacon: bool confidence: BeaconConfidence - + # Supporting data interval_stats: IntervalStats periodicity_result: PeriodicityResult - + # Metadata connection_count: int duration_seconds: float first_seen: str last_seen: str - analysis_time: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')) - + analysis_time: str = field( + default_factory=lambda: datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z") + ) + def to_dict(self): return { - 'pair_key': str(self.pair_key), - 'src_ip': str(self.src_ip), - 'dst_ip': str(self.dst_ip), - 'dst_port': int(self.dst_port), - 'protocol': str(self.protocol), - 'cv_score': float(round(self.cv_score, 4)), - 'periodicity_score': float(round(self.periodicity_score, 4)), - 'jitter_score': float(round(self.jitter_score, 4)), - 'combined_score': float(round(self.combined_score, 4)), - 'is_beacon': bool(self.is_beacon), - 'confidence': str(self.confidence.value), - 'interval_stats': self.interval_stats.to_dict(), - 'periodicity_result': self.periodicity_result.to_dict(), - 'connection_count': int(self.connection_count), - 'duration_seconds': float(round(self.duration_seconds, 2)), - 'first_seen': str(self.first_seen), - 'last_seen': str(self.last_seen), - 'analysis_time': str(self.analysis_time) + "pair_key": str(self.pair_key), + "src_ip": str(self.src_ip), + "dst_ip": str(self.dst_ip), + "dst_port": int(self.dst_port), + "protocol": str(self.protocol), + "cv_score": float(round(self.cv_score, 4)), + "periodicity_score": float(round(self.periodicity_score, 4)), + "jitter_score": float(round(self.jitter_score, 4)), + "combined_score": float(round(self.combined_score, 4)), + "is_beacon": bool(self.is_beacon), + "confidence": str(self.confidence.value), + "interval_stats": self.interval_stats.to_dict(), + "periodicity_result": self.periodicity_result.to_dict(), + "connection_count": int(self.connection_count), + "duration_seconds": float(round(self.duration_seconds, 2)), + "first_seen": str(self.first_seen), + "last_seen": str(self.last_seen), + "analysis_time": str(self.analysis_time), } @@ -125,44 +132,46 @@ class DetectorConfig: # Minimum data requirements min_connections: int = 10 time_window: int = 3600 # seconds - + # CV threshold (lower = more regular, beacon-like) cv_threshold: float = 0.15 - + # Periodicity threshold (higher = more periodic) periodicity_threshold: float = 0.7 - + # Jitter threshold in seconds (lower = more consistent) jitter_threshold: float = 5.0 - + # Interval bounds min_beacon_interval: float = 10.0 # seconds max_beacon_interval: float = 3600.0 # seconds - + # Score weights (should sum to 1.0) cv_weight: float = 0.4 periodicity_weight: float = 0.4 jitter_weight: float = 0.2 - + # Final threshold for beacon classification alert_threshold: float = 0.7 class BeaconDetector: - - def __init__(self, config = None): + + def __init__(self, config=None): self.config = config or DetectorConfig() - + # Validate weights - total_weight = (self.config.cv_weight + - self.config.periodicity_weight + - self.config.jitter_weight) + total_weight = ( + self.config.cv_weight + + self.config.periodicity_weight + + self.config.jitter_weight + ) if not 0.99 <= total_weight <= 1.01: logger.warning(f"Score weights sum to {total_weight}, should be 1.0") - + logger.info(f"BeaconDetector initialized with config: {self.config}") - + def analyze(self, pair: ConnectionPair): # Check minimum data requirements if pair.connection_count < self.config.min_connections: @@ -171,52 +180,50 @@ def analyze(self, pair: ConnectionPair): f"{pair.connection_count} < {self.config.min_connections}" ) return None - + # Get intervals intervals = pair.get_intervals() if len(intervals) < self.config.min_connections - 1: return None - + # Filter intervals within bounds intervals = [ - i for i in intervals + i + for i in intervals if self.config.min_beacon_interval <= i <= self.config.max_beacon_interval ] - + if len(intervals) < self.config.min_connections - 1: logger.debug(f"Insufficient valid intervals for {pair.pair_key}") return None - + # Calculate interval statistics interval_stats = self._calculate_interval_stats(intervals) - + # Calculate CV score (lower CV = higher score) cv_score = self._calculate_cv_score(interval_stats.cv) - + # Perform periodicity analysis periodicity_result = self._analyze_periodicity(intervals) - + # Calculate jitter score (lower jitter = higher score) jitter_score = self._calculate_jitter_score(interval_stats.jitter) - # logger.info(f"interval {interval_stats}, cv_score {cv_score}, periodicity {periodicity_result}") + # logger.info(f"interval {interval_stats}, cv_score {cv_score}, periodicity {periodicity_result}") # Calculate combined score combined_score = ( - self.config.cv_weight * cv_score + - self.config.periodicity_weight * periodicity_result.periodicity_score + - self.config.jitter_weight * jitter_score + self.config.cv_weight * cv_score + + self.config.periodicity_weight * periodicity_result.periodicity_score + + self.config.jitter_weight * jitter_score ) - + # Determine if beacon is_beacon = combined_score >= self.config.alert_threshold - + # Determine confidence level confidence = self._determine_confidence( - combined_score, - cv_score, - periodicity_result.periodicity_score, - jitter_score + combined_score, cv_score, periodicity_result.periodicity_score, jitter_score ) - + # Create result result = DetectionResult( pair_key=pair.pair_key, @@ -234,31 +241,39 @@ def analyze(self, pair: ConnectionPair): periodicity_result=periodicity_result, connection_count=pair.connection_count, duration_seconds=pair.duration_seconds, - first_seen=datetime.fromtimestamp(pair.first_seen).isoformat() + 'Z' if pair.first_seen else '', - last_seen=datetime.fromtimestamp(pair.last_seen).isoformat() + 'Z' if pair.last_seen else '' + first_seen=( + datetime.fromtimestamp(pair.first_seen).isoformat() + "Z" + if pair.first_seen + else "" + ), + last_seen=( + datetime.fromtimestamp(pair.last_seen).isoformat() + "Z" + if pair.last_seen + else "" + ), ) - + if is_beacon: logger.warning( f"Beacon detected: {pair.pair_key} " f"(score={combined_score:.3f}, confidence={confidence.value})" ) - + return result - + def _calculate_interval_stats(self, intervals): arr = np.array(intervals) - + mean = float(np.mean(arr)) std_dev = float(np.std(arr)) median = float(np.median(arr)) - + # Coefficient of variation - cv = std_dev / mean if mean > 0 else float('inf') - + cv = std_dev / mean if mean > 0 else float("inf") + # Jitter = maximum deviation from median jitter = float(np.max(np.abs(arr - median))) - + return IntervalStats( count=len(intervals), mean=mean, @@ -267,24 +282,24 @@ def _calculate_interval_stats(self, intervals): median=median, min_interval=float(np.min(arr)), max_interval=float(np.max(arr)), - jitter=jitter + jitter=jitter, ) - + def _calculate_cv_score(self, cv: float): if cv <= 0: return 1.0 - + # Sigmoid transformation # At cv = threshold, score ≈ 0.5 # cv << threshold: score → 1.0 # cv >> threshold: score → 0.0 threshold = self.config.cv_threshold k = 10.0 / threshold # Steepness factor - + score = 1.0 / (1.0 + math.exp(k * (cv - threshold))) return score - + def _analyze_periodicity(self, intervals: List[float]): if len(intervals) < 4: @@ -292,44 +307,44 @@ def _analyze_periodicity(self, intervals: List[float]): is_periodic=False, dominant_period=None, periodicity_score=0.0, - frequency_peaks=[] + frequency_peaks=[], ) - + arr = np.array(intervals) n = len(arr) - + # Remove mean (DC component) arr_centered = arr - np.mean(arr) - + # Perform FFT fft_result = fft.fft(arr_centered) frequencies = fft.fftfreq(n, d=np.mean(arr)) - + # Get magnitude spectrum (positive frequencies only) - magnitude = np.abs(fft_result[:n//2]) - freq_positive = frequencies[:n//2] - + magnitude = np.abs(fft_result[: n // 2]) + freq_positive = frequencies[: n // 2] + # Normalize magnitude magnitude_norm = magnitude / (np.sum(magnitude) + 1e-10) - + # Find peaks peaks = [] for i in range(1, len(magnitude) - 1): - if magnitude[i] > magnitude[i-1] and magnitude[i] > magnitude[i+1]: + if magnitude[i] > magnitude[i - 1] and magnitude[i] > magnitude[i + 1]: if freq_positive[i] > 0: # Skip DC peaks.append((freq_positive[i], magnitude_norm[i])) - + # Sort peaks by magnitude peaks.sort(key=lambda x: x[1], reverse=True) top_peaks = peaks[:5] - + # Calculate periodicity score # Based on how dominant the main frequency is if len(top_peaks) > 0: dominant_magnitude = top_peaks[0][1] dominant_freq = top_peaks[0][0] dominant_period = 1.0 / dominant_freq if dominant_freq > 0 else None - + # Score is based on dominance of primary frequency # A strong single frequency indicates regular periodicity if len(top_peaks) > 1: @@ -341,39 +356,39 @@ def _analyze_periodicity(self, intervals: List[float]): else: dominant_period = None periodicity_score = 0.0 - + # Determine if periodic based on threshold is_periodic = periodicity_score >= self.config.periodicity_threshold - + return PeriodicityResult( is_periodic=is_periodic, dominant_period=dominant_period, periodicity_score=periodicity_score, - frequency_peaks=top_peaks + frequency_peaks=top_peaks, ) - + def _calculate_jitter_score(self, jitter: float): if jitter <= 0: return 1.0 - + threshold = self.config.jitter_threshold - + # Linear scaling with cutoff if jitter <= threshold: score = 1.0 - (jitter / threshold) * 0.5 # Score 0.5-1.0 else: # Exponential decay beyond threshold score = 0.5 * math.exp(-(jitter - threshold) / threshold) - + return max(0.0, min(1.0, score)) - + def _determine_confidence( self, combined_score: float, cv_score: float, periodicity_score: float, - jitter_score: float + jitter_score: float, ): if combined_score < 0.3: @@ -389,13 +404,10 @@ def _determine_confidence( if cv_score > 0.7 and periodicity_score > 0.7 and jitter_score > 0.7: return BeaconConfidence.CRITICAL return BeaconConfidence.HIGH - - def batch_analyze( - self, - pairs: List[ConnectionPair] - ): + + def batch_analyze(self, pairs: List[ConnectionPair]): results = [] - + for pair in pairs: try: result = self.analyze(pair) @@ -403,16 +415,13 @@ def batch_analyze( results.append(result) except Exception as e: logger.error(f"Error analyzing {pair.pair_key}: {e}") - + # Sort by combined score descending results.sort(key=lambda r: r.combined_score, reverse=True) - + return results - - def get_beacons( - self, - pairs: List[ConnectionPair] - ): + + def get_beacons(self, pairs: List[ConnectionPair]): all_results = self.batch_analyze(pairs) return [r for r in all_results if r.is_beacon] diff --git a/control_plane/server.py b/control_plane/server.py index 2a7173f..5344653 100644 --- a/control_plane/server.py +++ b/control_plane/server.py @@ -12,18 +12,16 @@ import yaml from aiohttp import web -from .storage import ConnectionStorage +from .alerter import AlertingConfig, AlertManager, AlertSeverity +from .analyzer import AnalyzerConfig, ConnectionAnalyzer from .detector import BeaconDetector, DetectorConfig -from .analyzer import ConnectionAnalyzer, AnalyzerConfig -from .alerter import AlertManager, AlertingConfig, AlertSeverity - +from .storage import ConnectionStorage # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) -logger = logging.getLogger('beacon_detect.control_plane.server') +logger = logging.getLogger("beacon_detect.control_plane.server") class ControlPlaneServer: @@ -32,529 +30,554 @@ def __init__(self, config): self.config = config self._runtime_config = self._build_runtime_config(config) - cp_config = config.get('control_plane', {}) - - self.host = cp_config.get('listen_address', '0.0.0.0') - self.port = cp_config.get('listen_port', 9090) - + cp_config = config.get("control_plane", {}) + + self.host = cp_config.get("listen_address", "0.0.0.0") + self.port = cp_config.get("listen_port", 9090) + # Shutdown event for clean termination self._shutdown_event = asyncio.Event() - + # Initialize components self._init_storage(config) self._init_detector(config) self._init_alerter(config) self._init_analyzer(config) - + # HTTP app self._app = None self._runner = None self._site = None - + # Statistics self._start_time = None self._requests_received = 0 self._batches_processed = 0 self._events_received = 0 - + logger.info(f"ControlPlaneServer initialized: {self.host}:{self.port}") - + def _build_runtime_config(self, config): - det = config.get('detection', {}) - alert = config.get('alerting', {}) - whitelist = config.get('whitelist', {}) - + det = config.get("detection", {}) + alert = config.get("alerting", {}) + whitelist = config.get("whitelist", {}) + return { - 'detection': { - 'min_connections': det.get('min_connections', 10), - 'cv_threshold': det.get('cv_threshold', 0.15), - 'alert_threshold': det.get('alert_threshold', 0.7), - 'jitter_threshold': det.get('jitter_threshold', 5.0), - 'analysis_interval': det.get('analysis_interval', 60), - 'alert_cooldown': det.get('alert_cooldown', 300), + "detection": { + "min_connections": det.get("min_connections", 10), + "cv_threshold": det.get("cv_threshold", 0.15), + "alert_threshold": det.get("alert_threshold", 0.7), + "jitter_threshold": det.get("jitter_threshold", 5.0), + "analysis_interval": det.get("analysis_interval", 60), + "alert_cooldown": det.get("alert_cooldown", 300), }, - 'weights': { - 'cv': det.get('cv_weight', 0.4), - 'periodicity': det.get('periodicity_weight', 0.4), - 'jitter': det.get('jitter_weight', 0.2), + "weights": { + "cv": det.get("cv_weight", 0.4), + "periodicity": det.get("periodicity_weight", 0.4), + "jitter": det.get("jitter_weight", 0.2), }, - 'alerting': { - 'syslog_enabled': alert.get('syslog', {}).get('enabled', True), - 'file_enabled': alert.get('file', {}).get('enabled', True), - 'file_path': alert.get('file', {}).get('path', '/var/log/beacon-detect/alerts.json'), - 'webhook_enabled': alert.get('webhook', {}).get('enabled', False), - 'webhook_url': alert.get('webhook', {}).get('url', ''), + "alerting": { + "syslog_enabled": alert.get("syslog", {}).get("enabled", True), + "file_enabled": alert.get("file", {}).get("enabled", True), + "file_path": alert.get("file", {}).get( + "path", "/var/log/beacon-detect/alerts.json" + ), + "webhook_enabled": alert.get("webhook", {}).get("enabled", False), + "webhook_url": alert.get("webhook", {}).get("url", ""), + }, + "whitelist": { + "source_ips": whitelist.get("source_ips", []), + "destination_ips": whitelist.get("destination_ips", []), + "destination_ports": whitelist.get("ports", []), }, - 'whitelist': { - 'source_ips': whitelist.get('source_ips', []), - 'destination_ips': whitelist.get('destination_ips', []), - 'destination_ports': whitelist.get('ports', []), - } } - + def _init_storage(self, config): - cp_config = config.get('control_plane', {}) + cp_config = config.get("control_plane", {}) self.storage = ConnectionStorage( - retention_seconds=cp_config.get('data_retention', 7200), - cleanup_interval=cp_config.get('cleanup_interval', 300) + retention_seconds=cp_config.get("data_retention", 7200), + cleanup_interval=cp_config.get("cleanup_interval", 300), ) - + def _init_detector(self, config): - det_config = config.get('detection', {}) + det_config = config.get("detection", {}) detector_config = DetectorConfig( - min_connections=det_config.get('min_connections', 10), - time_window=det_config.get('time_window', 3600), - cv_threshold=det_config.get('cv_threshold', 0.15), - periodicity_threshold=det_config.get('periodicity_threshold', 0.7), - jitter_threshold=det_config.get('jitter_threshold', 5.0), - min_beacon_interval=det_config.get('min_beacon_interval', 10.0), - max_beacon_interval=det_config.get('max_beacon_interval', 3600.0), - cv_weight=det_config.get('cv_weight', 0.4), - periodicity_weight=det_config.get('periodicity_weight', 0.4), - jitter_weight=det_config.get('jitter_weight', 0.2), - alert_threshold=det_config.get('alert_threshold', 0.7) + min_connections=det_config.get("min_connections", 10), + time_window=det_config.get("time_window", 3600), + cv_threshold=det_config.get("cv_threshold", 0.15), + periodicity_threshold=det_config.get("periodicity_threshold", 0.7), + jitter_threshold=det_config.get("jitter_threshold", 5.0), + min_beacon_interval=det_config.get("min_beacon_interval", 10.0), + max_beacon_interval=det_config.get("max_beacon_interval", 3600.0), + cv_weight=det_config.get("cv_weight", 0.4), + periodicity_weight=det_config.get("periodicity_weight", 0.4), + jitter_weight=det_config.get("jitter_weight", 0.2), + alert_threshold=det_config.get("alert_threshold", 0.7), ) self.detector = BeaconDetector(detector_config) - + def _init_alerter(self, config): - alert_config = config.get('alerting', {}) - + alert_config = config.get("alerting", {}) + # Build alerting config alerting_config = AlertingConfig( - enabled=alert_config.get('enabled', True), - syslog_enabled=alert_config.get('syslog', {}).get('enabled', True), - syslog_facility=alert_config.get('syslog', {}).get('facility', 'local0'), - file_enabled=alert_config.get('file', {}).get('enabled', True), - file_path=alert_config.get('file', {}).get('path', '/var/log/beacon-detect/alerts.json'), - file_max_size_mb=alert_config.get('file', {}).get('max_size_mb', 100), - file_backup_count=alert_config.get('file', {}).get('backup_count', 5), - webhook_enabled=alert_config.get('webhook', {}).get('enabled', False), - webhook_url=alert_config.get('webhook', {}).get('url', ''), - webhook_headers=alert_config.get('webhook', {}).get('headers', {}), - webhook_timeout=alert_config.get('webhook', {}).get('timeout', 10), - webhook_retries=alert_config.get('webhook', {}).get('retries', 3) + enabled=alert_config.get("enabled", True), + syslog_enabled=alert_config.get("syslog", {}).get("enabled", True), + syslog_facility=alert_config.get("syslog", {}).get("facility", "local0"), + file_enabled=alert_config.get("file", {}).get("enabled", True), + file_path=alert_config.get("file", {}).get( + "path", "/var/log/beacon-detect/alerts.json" + ), + file_max_size_mb=alert_config.get("file", {}).get("max_size_mb", 100), + file_backup_count=alert_config.get("file", {}).get("backup_count", 5), + webhook_enabled=alert_config.get("webhook", {}).get("enabled", False), + webhook_url=alert_config.get("webhook", {}).get("url", ""), + webhook_headers=alert_config.get("webhook", {}).get("headers", {}), + webhook_timeout=alert_config.get("webhook", {}).get("timeout", 10), + webhook_retries=alert_config.get("webhook", {}).get("retries", 3), ) self.alert_manager = AlertManager(alerting_config) - + def _init_analyzer(self, config): - det_config = config.get('detection', {}) + det_config = config.get("detection", {}) analyzer_config = AnalyzerConfig( analysis_interval=60, # Run every minute - min_connections=det_config.get('min_connections', 10), + min_connections=det_config.get("min_connections", 10), min_duration=30.0, - alert_cooldown=det_config.get('alert_cooldown', 300) + alert_cooldown=det_config.get("alert_cooldown", 300), ) self.analyzer = ConnectionAnalyzer( storage=self.storage, detector=self.detector, alert_manager=self.alert_manager, - config=analyzer_config + config=analyzer_config, ) - + def _setup_routes(self, app: web.Application): # API routes - app.router.add_get('/', self._handle_info) - app.router.add_post('/api/v1/telemetry', self._handle_telemetry) - app.router.add_get('/api/v1/health', self._handle_health) - app.router.add_get('/api/v1/status', self._handle_status) - app.router.add_get('/api/v1/statistics', self._handle_statistics) - app.router.add_get('/api/v1/beacons', self._handle_beacons) - app.router.add_get('/api/v1/alerts', self._handle_alerts) - app.router.add_delete('/api/v1/alerts', self._handle_clear_alerts) - app.router.add_get('/api/v1/connections', self._handle_connections) - app.router.add_post('/api/v1/analyze', self._handle_manual_analyze) - app.router.add_get('/api/v1/config', self._handle_get_config) - app.router.add_post('/api/v1/config', self._handle_set_config) - app.router.add_delete('/api/v1/beacons', self._handle_clear_beacons) - + app.router.add_get("/", self._handle_info) + app.router.add_post("/api/v1/telemetry", self._handle_telemetry) + app.router.add_get("/api/v1/health", self._handle_health) + app.router.add_get("/api/v1/status", self._handle_status) + app.router.add_get("/api/v1/statistics", self._handle_statistics) + app.router.add_get("/api/v1/beacons", self._handle_beacons) + app.router.add_get("/api/v1/alerts", self._handle_alerts) + app.router.add_delete("/api/v1/alerts", self._handle_clear_alerts) + app.router.add_get("/api/v1/connections", self._handle_connections) + app.router.add_post("/api/v1/analyze", self._handle_manual_analyze) + app.router.add_get("/api/v1/config", self._handle_get_config) + app.router.add_post("/api/v1/config", self._handle_set_config) + app.router.add_delete("/api/v1/beacons", self._handle_clear_beacons) + # CORS preflight handler for all API routes - app.router.add_route('OPTIONS', '/api/v1/{path:.*}', self._handle_options) - + app.router.add_route("OPTIONS", "/api/v1/{path:.*}", self._handle_options) + async def _handle_info(self, request: web.Request) -> web.Response: - return web.json_response({ - 'name': 'Beacon Detection Control Plane', - 'version': '1.0.0', - 'description': 'Use the CLI for monitoring: python3 -m control_plane.cli', - 'endpoints': { - 'GET /api/v1/health': 'Health check', - 'GET /api/v1/status': 'Server status', - 'GET /api/v1/beacons': 'List detected beacons', - 'GET /api/v1/alerts': 'List alerts', - 'GET /api/v1/connections': 'List connection pairs', - 'GET /api/v1/config': 'Get configuration', - 'POST /api/v1/config': 'Update configuration', - 'POST /api/v1/telemetry': 'Receive telemetry data', - 'POST /api/v1/analyze': 'Trigger analysis', + return web.json_response( + { + "name": "Beacon Detection Control Plane", + "version": "1.0.0", + "description": "Use the CLI for monitoring: python3 -m control_plane.cli", + "endpoints": { + "GET /api/v1/health": "Health check", + "GET /api/v1/status": "Server status", + "GET /api/v1/beacons": "List detected beacons", + "GET /api/v1/alerts": "List alerts", + "GET /api/v1/connections": "List connection pairs", + "GET /api/v1/config": "Get configuration", + "POST /api/v1/config": "Update configuration", + "POST /api/v1/telemetry": "Receive telemetry data", + "POST /api/v1/analyze": "Trigger analysis", + }, } - }) - + ) + async def _handle_options(self, request: web.Request) -> web.Response: return web.Response(status=200) - + async def _handle_health(self, request: web.Request) -> web.Response: - return web.json_response({ - 'status': 'healthy', - 'timestamp': datetime.now(timezone.utc).isoformat() + 'Z' - }) - + return web.json_response( + { + "status": "healthy", + "timestamp": datetime.now(timezone.utc).isoformat() + "Z", + } + ) + async def _handle_status(self, request: web.Request) -> web.Response: uptime = None if self._start_time: uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds() - - return web.json_response({ - 'status': 'running', - 'uptime_seconds': uptime, - 'start_time': self._start_time.isoformat() + 'Z' if self._start_time else None, - 'requests_received': self._requests_received, - 'batches_processed': self._batches_processed, - 'events_received': self._events_received, - 'storage': self.storage.statistics, - 'analyzer': self.analyzer.statistics, - 'alerter': self.alert_manager.statistics - }) - + + return web.json_response( + { + "status": "running", + "uptime_seconds": uptime, + "start_time": ( + self._start_time.isoformat() + "Z" if self._start_time else None + ), + "requests_received": self._requests_received, + "batches_processed": self._batches_processed, + "events_received": self._events_received, + "storage": self.storage.statistics, + "analyzer": self.analyzer.statistics, + "alerter": self.alert_manager.statistics, + } + ) + async def _handle_statistics(self, request: web.Request) -> web.Response: - return web.json_response({ - 'server': { - 'requests_received': self._requests_received, - 'batches_processed': self._batches_processed, - 'events_received': self._events_received - }, - 'storage': self.storage.statistics, - 'analyzer': self.analyzer.statistics, - 'alerter': self.alert_manager.statistics - }) - + return web.json_response( + { + "server": { + "requests_received": self._requests_received, + "batches_processed": self._batches_processed, + "events_received": self._events_received, + }, + "storage": self.storage.statistics, + "analyzer": self.analyzer.statistics, + "alerter": self.alert_manager.statistics, + } + ) + async def _handle_telemetry(self, request: web.Request) -> web.Response: self._requests_received += 1 - + try: # Get request body body = await request.read() - - data = json.loads(body.decode('utf-8')) - + + data = json.loads(body.decode("utf-8")) + # Validate batch structure - if 'events' not in data: - return web.json_response( - {'error': 'Missing events field'}, - status=400 - ) - - events = data['events'] - node_id = data.get('node_id', 'unknown') - batch_id = data.get('batch_id', 'unknown') - + if "events" not in data: + return web.json_response({"error": "Missing events field"}, status=400) + + events = data["events"] + node_id = data.get("node_id", "unknown") + batch_id = data.get("batch_id", "unknown") + # Add events to storage self.storage.add_batch(events) - + self._batches_processed += 1 self._events_received += len(events) - + logger.info( - f"Received batch {batch_id} from {node_id}: " - f"{len(events)} events" + f"Received batch {batch_id} from {node_id}: " f"{len(events)} events" ) - - return web.json_response({ - 'status': 'accepted', - 'batch_id': batch_id, - 'events_received': len(events) - }) - - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON in telemetry: {e}") + return web.json_response( - {'error': 'Invalid JSON'}, - status=400 + { + "status": "accepted", + "batch_id": batch_id, + "events_received": len(events), + } ) + + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON in telemetry: {e}") + return web.json_response({"error": "Invalid JSON"}, status=400) except Exception as e: logger.error(f"Error processing telemetry: {e}") - return web.json_response( - {'error': str(e)}, - status=500 - ) - + return web.json_response({"error": str(e)}, status=500) + async def _handle_beacons(self, request: web.Request) -> web.Response: beacons = self.analyzer.get_known_beacons() - return web.json_response({ - 'count': len(beacons), - 'beacons': [b.to_dict() for b in beacons] - }) - + return web.json_response( + {"count": len(beacons), "beacons": [b.to_dict() for b in beacons]} + ) + async def _handle_alerts(self, request: web.Request) -> web.Response: - limit = int(request.query.get('limit', 50)) - severity = request.query.get('severity') - + limit = int(request.query.get("limit", 50)) + severity = request.query.get("severity") + sev_filter = None if severity: try: sev_filter = AlertSeverity(severity.lower()) except ValueError: pass - + alerts = self.alert_manager.get_recent_alerts(limit=limit, severity=sev_filter) - - return web.json_response({ - 'count': len(alerts), - 'alerts': alerts - }) - + + return web.json_response({"count": len(alerts), "alerts": alerts}) + async def _handle_connections(self, request: web.Request) -> web.Response: - src_ip = request.query.get('src_ip') - dst_ip = request.query.get('dst_ip') - limit = int(request.query.get('limit', 100)) - + src_ip = request.query.get("src_ip") + dst_ip = request.query.get("dst_ip") + limit = int(request.query.get("limit", 100)) + if src_ip: pairs = self.storage.get_pairs_by_src(src_ip) elif dst_ip: pairs = self.storage.get_pairs_by_dst(dst_ip) else: pairs = self.storage.get_all_pairs() - + # Sort by connection count and limit pairs.sort(key=lambda p: p.connection_count, reverse=True) pairs = pairs[:limit] - - return web.json_response({ - 'count': len(pairs), - 'pairs': [ - { - 'pair_key': p.pair_key, - 'src_ip': p.src_ip, - 'dst_ip': p.dst_ip, - 'dst_port': p.dst_port, - 'protocol': p.protocol, - 'connection_count': p.connection_count, - 'duration_seconds': p.duration_seconds, - 'first_seen': datetime.fromtimestamp(p.first_seen).isoformat() + 'Z' if p.first_seen else None, - 'last_seen': datetime.fromtimestamp(p.last_seen).isoformat() + 'Z' if p.last_seen else None - } - for p in pairs - ] - }) - + + return web.json_response( + { + "count": len(pairs), + "pairs": [ + { + "pair_key": p.pair_key, + "src_ip": p.src_ip, + "dst_ip": p.dst_ip, + "dst_port": p.dst_port, + "protocol": p.protocol, + "connection_count": p.connection_count, + "duration_seconds": p.duration_seconds, + "first_seen": ( + datetime.fromtimestamp(p.first_seen).isoformat() + "Z" + if p.first_seen + else None + ), + "last_seen": ( + datetime.fromtimestamp(p.last_seen).isoformat() + "Z" + if p.last_seen + else None + ), + } + for p in pairs + ], + } + ) + async def _handle_manual_analyze(self, request: web.Request) -> web.Response: try: run = self.analyzer.run_analysis() - return web.json_response({ - 'status': 'completed', - 'run': run.to_dict(), - 'beacons_found': [r.to_dict() for r in run.results if r.is_beacon] - }) - except Exception as e: - logger.error(f"Manual analysis failed: {e}") return web.json_response( - {'error': str(e)}, - status=500 + { + "status": "completed", + "run": run.to_dict(), + "beacons_found": [r.to_dict() for r in run.results if r.is_beacon], + } ) - + except Exception as e: + logger.error(f"Manual analysis failed: {e}") + return web.json_response({"error": str(e)}, status=500) + async def _handle_get_config(self, request: web.Request) -> web.Response: return web.json_response(self._runtime_config) - + async def _handle_set_config(self, request: web.Request) -> web.Response: try: data = await request.json() - + # Update runtime config - if 'detection' in data: - self._runtime_config['detection'].update(data['detection']) + if "detection" in data: + self._runtime_config["detection"].update(data["detection"]) # Update detector config - if hasattr(self.detector, 'config'): - det = data['detection'] - if 'min_connections' in det: - self.detector.config.min_connections = det['min_connections'] - if 'cv_threshold' in det: - self.detector.config.cv_threshold = det['cv_threshold'] - if 'alert_threshold' in det: - self.detector.config.alert_threshold = det['alert_threshold'] - if 'jitter_threshold' in det: - self.detector.config.jitter_threshold = det['jitter_threshold'] + if hasattr(self.detector, "config"): + det = data["detection"] + if "min_connections" in det: + self.detector.config.min_connections = det["min_connections"] + if "cv_threshold" in det: + self.detector.config.cv_threshold = det["cv_threshold"] + if "alert_threshold" in det: + self.detector.config.alert_threshold = det["alert_threshold"] + if "jitter_threshold" in det: + self.detector.config.jitter_threshold = det["jitter_threshold"] # Update analyzer config - if hasattr(self.analyzer, 'config'): - if 'analysis_interval' in det: - self.analyzer.config.analysis_interval = det.get('analysis_interval', 60) - if 'alert_cooldown' in det: - self.analyzer.config.alert_cooldown = det.get('alert_cooldown', 300) - - if 'weights' in data: - self._runtime_config['weights'].update(data['weights']) + if hasattr(self.analyzer, "config"): + if "analysis_interval" in det: + self.analyzer.config.analysis_interval = det.get( + "analysis_interval", 60 + ) + if "alert_cooldown" in det: + self.analyzer.config.alert_cooldown = det.get( + "alert_cooldown", 300 + ) + + if "weights" in data: + self._runtime_config["weights"].update(data["weights"]) # Update detector weights - if hasattr(self.detector, 'config'): - w = data['weights'] - if 'cv' in w: - self.detector.config.cv_weight = w['cv'] - if 'periodicity' in w: - self.detector.config.periodicity_weight = w['periodicity'] - if 'jitter' in w: - self.detector.config.jitter_weight = w['jitter'] - - if 'alerting' in data: - self._runtime_config['alerting'].update(data['alerting']) + if hasattr(self.detector, "config"): + w = data["weights"] + if "cv" in w: + self.detector.config.cv_weight = w["cv"] + if "periodicity" in w: + self.detector.config.periodicity_weight = w["periodicity"] + if "jitter" in w: + self.detector.config.jitter_weight = w["jitter"] + + if "alerting" in data: + self._runtime_config["alerting"].update(data["alerting"]) # Update alert manager config - if hasattr(self.alert_manager, 'config'): - a = data['alerting'] - if 'syslog_enabled' in a: - self.alert_manager.config.syslog_enabled = a['syslog_enabled'] - if 'file_enabled' in a: - self.alert_manager.config.file_enabled = a['file_enabled'] - if 'webhook_enabled' in a: - self.alert_manager.config.webhook_enabled = a['webhook_enabled'] - if 'webhook_url' in a: - self.alert_manager.config.webhook_url = a['webhook_url'] - - if 'whitelist' in data: - self._runtime_config['whitelist'].update(data['whitelist']) + if hasattr(self.alert_manager, "config"): + a = data["alerting"] + if "syslog_enabled" in a: + self.alert_manager.config.syslog_enabled = a["syslog_enabled"] + if "file_enabled" in a: + self.alert_manager.config.file_enabled = a["file_enabled"] + if "webhook_enabled" in a: + self.alert_manager.config.webhook_enabled = a["webhook_enabled"] + if "webhook_url" in a: + self.alert_manager.config.webhook_url = a["webhook_url"] + + if "whitelist" in data: + self._runtime_config["whitelist"].update(data["whitelist"]) # Update whitelist in main config for filtering - if 'source_ips' in data['whitelist']: - self.config.setdefault('whitelist', {})['source_ips'] = data['whitelist']['source_ips'] - if 'destination_ips' in data['whitelist']: - self.config.setdefault('whitelist', {})['destination_ips'] = data['whitelist']['destination_ips'] - if 'destination_ports' in data['whitelist']: - self.config.setdefault('whitelist', {})['ports'] = data['whitelist']['destination_ports'] - + if "source_ips" in data["whitelist"]: + self.config.setdefault("whitelist", {})["source_ips"] = data[ + "whitelist" + ]["source_ips"] + if "destination_ips" in data["whitelist"]: + self.config.setdefault("whitelist", {})["destination_ips"] = data[ + "whitelist" + ]["destination_ips"] + if "destination_ports" in data["whitelist"]: + self.config.setdefault("whitelist", {})["ports"] = data[ + "whitelist" + ]["destination_ports"] + logger.info("Configuration updated via API") - return web.json_response({'status': 'updated', 'config': self._runtime_config}) - + return web.json_response( + {"status": "updated", "config": self._runtime_config} + ) + except Exception as e: logger.error(f"Config update failed: {e}") - return web.json_response({'error': str(e)}, status=500) - + return web.json_response({"error": str(e)}, status=500) + async def _handle_clear_alerts(self, request: web.Request) -> web.Response: try: - if hasattr(self.alert_manager, '_recent_alerts'): + if hasattr(self.alert_manager, "_recent_alerts"): self.alert_manager._recent_alerts = [] - return web.json_response({'status': 'cleared'}) + return web.json_response({"status": "cleared"}) except Exception as e: - return web.json_response({'error': str(e)}, status=500) - + return web.json_response({"error": str(e)}, status=500) + async def _handle_clear_beacons(self, request: web.Request) -> web.Response: try: - if hasattr(self.analyzer, '_known_beacons'): + if hasattr(self.analyzer, "_known_beacons"): self.analyzer._known_beacons = {} - return web.json_response({'status': 'cleared'}) + return web.json_response({"status": "cleared"}) except Exception as e: - return web.json_response({'error': str(e)}, status=500) - + return web.json_response({"error": str(e)}, status=500) + async def start(self): logger.info(f"Starting control plane server on {self.host}:{self.port}") - + # Start components self.storage.start_cleanup() self.alert_manager.start() self.analyzer.start() - + # Create and configure app with CORS support self._app = web.Application(middlewares=[self._cors_middleware]) self._setup_routes(self._app) - + # Create runner self._runner = web.AppRunner(self._app) await self._runner.setup() - + # Create site self._site = web.TCPSite(self._runner, self.host, self.port) await self._site.start() - + self._start_time = datetime.now(timezone.utc) - + logger.info(f"Control plane server started on {self.host}:{self.port}") - + # Send startup alert self.alert_manager.create_and_send( title="Beacon Detection Control Plane Started", description=f"Control plane server started on {self.host}:{self.port}", severity=AlertSeverity.INFO, - source="control_plane" + source="control_plane", ) - + @web.middleware async def _cors_middleware(self, request: web.Request, handler): # Handle preflight OPTIONS requests - if request.method == 'OPTIONS': + if request.method == "OPTIONS": response = web.Response() else: try: response = await handler(request) except web.HTTPException as ex: response = ex - + # Add CORS headers - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS' - response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' - response.headers['Access-Control-Max-Age'] = '3600' - + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, OPTIONS" + ) + response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + response.headers["Access-Control-Max-Age"] = "3600" + return response - + async def stop(self): logger.info("Stopping control plane server...") - + # Signal shutdown self._shutdown_event.set() - + # Stop components try: self.analyzer.stop() except Exception as e: logger.warning(f"Error stopping analyzer: {e}") - + try: self.alert_manager.stop() except Exception as e: logger.warning(f"Error stopping alert manager: {e}") - + try: self.storage.stop_cleanup() except Exception as e: logger.warning(f"Error stopping storage cleanup: {e}") - + # Stop HTTP server if self._site: try: await self._site.stop() except Exception as e: logger.warning(f"Error stopping site: {e}") - + if self._runner: try: await self._runner.cleanup() except Exception as e: logger.warning(f"Error cleaning up runner: {e}") - + logger.info("Control plane server stopped") - + def request_shutdown(self): logger.info("Shutdown requested") self._shutdown_event.set() - + async def run_forever(self): await self.start() - + # Wait for shutdown signal try: await self._shutdown_event.wait() @@ -565,34 +588,36 @@ async def run_forever(self): def load_config(config_path: str): - with open(config_path, 'r') as f: + with open(config_path, "r") as f: return yaml.safe_load(f) def setup_logging(config): - log_config = config.get('logging', {}) - - level_str = log_config.get('level', 'INFO') + log_config = config.get("logging", {}) + + level_str = log_config.get("level", "INFO") level = getattr(logging, level_str.upper(), logging.INFO) - + logging.getLogger().setLevel(level) - logging.getLogger('beacon_detect').setLevel(level) - + logging.getLogger("beacon_detect").setLevel(level) + # Add file handler if configured - log_file = log_config.get('file') + log_file = log_config.get("file") if log_file: try: log_path = Path(log_file) log_path.parent.mkdir(parents=True, exist_ok=True) - + handler = logging.handlers.RotatingFileHandler( log_path, - maxBytes=log_config.get('max_size_mb', 50) * 1024 * 1024, - backupCount=log_config.get('backup_count', 5) + maxBytes=log_config.get("max_size_mb", 50) * 1024 * 1024, + backupCount=log_config.get("backup_count", 5), + ) + handler.setFormatter( + logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) ) - handler.setFormatter(logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - )) logging.getLogger().addHandler(handler) except Exception as e: logger.warning(f"Could not set up file logging: {e}") @@ -601,55 +626,51 @@ def setup_logging(config): async def main_async(config): server = ControlPlaneServer(config) - + # Set up signal handlers using the shutdown event loop = asyncio.get_running_loop() - + def handle_signal(): logger.info("Signal received, initiating shutdown...") server.request_shutdown() - + for sig in (signal.SIGINT, signal.SIGTERM): try: loop.add_signal_handler(sig, handle_signal) except NotImplementedError: # Windows doesn't support add_signal_handler pass - + await server.run_forever() def main(): parser = argparse.ArgumentParser( - description='eBPF Beaconing Detection - Control Plane Server' + description="eBPF Beaconing Detection - Control Plane Server" ) parser.add_argument( - '-c', '--config', - required=True, - help='Path to configuration file' + "-c", "--config", required=True, help="Path to configuration file" ) parser.add_argument( - '-v', '--verbose', - action='store_true', - help='Enable verbose logging' + "-v", "--verbose", action="store_true", help="Enable verbose logging" ) - + args = parser.parse_args() - + # Load configuration try: config = load_config(args.config) except Exception as e: logger.error(f"Failed to load configuration: {e}") sys.exit(1) - + # Set up logging setup_logging(config) - + if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + # Run server try: asyncio.run(main_async(config)) @@ -660,5 +681,5 @@ def main(): sys.exit(1) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/control_plane/storage.py b/control_plane/storage.py index ef6bec1..8675c68 100644 --- a/control_plane/storage.py +++ b/control_plane/storage.py @@ -7,12 +7,12 @@ from datetime import datetime from typing import Dict, List, Set -logger = logging.getLogger('beacon_detect.control_plane.storage') +logger = logging.getLogger("beacon_detect.control_plane.storage") @dataclass class ConnectionRecord: - + timestamp_ns: int timestamp_utc: str src_ip: str @@ -26,113 +26,113 @@ class ConnectionRecord: direction: int node_id: str connection_key: str - + # Computed fields timestamp_epoch: float = field(init=False) - + def __post_init__(self): try: - dt = datetime.fromisoformat(self.timestamp_utc.replace('Z', '+00:00')) + dt = datetime.fromisoformat(self.timestamp_utc.replace("Z", "+00:00")) self.timestamp_epoch = dt.timestamp() except (ValueError, AttributeError): self.timestamp_epoch = time.time() - + @classmethod def from_dict(cls, data): return cls( - timestamp_ns=data['timestamp_ns'], - timestamp_utc=data['timestamp_utc'], - src_ip=data['src_ip'], - dst_ip=data['dst_ip'], - src_port=data['src_port'], - dst_port=data['dst_port'], - packet_size=data['packet_size'], - protocol=data['protocol'], - protocol_name=data['protocol_name'], - tcp_flags=data.get('tcp_flags', 0), - direction=data.get('direction', 0), - node_id=data.get('node_id', 'unknown'), - connection_key=data['connection_key'] + timestamp_ns=data["timestamp_ns"], + timestamp_utc=data["timestamp_utc"], + src_ip=data["src_ip"], + dst_ip=data["dst_ip"], + src_port=data["src_port"], + dst_port=data["dst_port"], + packet_size=data["packet_size"], + protocol=data["protocol"], + protocol_name=data["protocol_name"], + tcp_flags=data.get("tcp_flags", 0), + direction=data.get("direction", 0), + node_id=data.get("node_id", "unknown"), + connection_key=data["connection_key"], ) -@dataclass +@dataclass class ConnectionPair: src_ip: str dst_ip: str dst_port: int protocol: str - + # Timestamps of all observed connections (epoch seconds) timestamps: List[float] = field(default_factory=list) - + # Packet sizes packet_sizes: List[int] = field(default_factory=list) - + # Source ports (may vary for same destination) src_ports: Set[int] = field(default_factory=set) - + # Nodes that reported this connection nodes: Set[str] = field(default_factory=set) - + # First and last seen first_seen = None - last_seen= None - + last_seen = None + @property def pair_key(self): return f"{self.src_ip}->{self.dst_ip}:{self.dst_port}/{self.protocol}" - + @property def connection_count(self): return len(self.timestamps) - + @property def duration_seconds(self): if self.first_seen and self.last_seen: return self.last_seen - self.first_seen return 0.0 - + def add_connection(self, record): ts = record.timestamp_epoch - + # Find insertion index to maintain sorted order idx = bisect.bisect_right(self.timestamps, ts) self.timestamps.insert(idx, ts) self.packet_sizes.insert(idx, record.packet_size) - + self.src_ports.add(record.src_port) self.nodes.add(record.node_id) - + if self.first_seen is None or ts < self.first_seen: self.first_seen = ts if self.last_seen is None or ts > self.last_seen: self.last_seen = ts - + def get_intervals(self): if len(self.timestamps) < 2: return [] - + intervals = [] for i in range(1, len(self.timestamps)): - interval = self.timestamps[i] - self.timestamps[i-1] + interval = self.timestamps[i] - self.timestamps[i - 1] intervals.append(interval) - + return intervals - + def prune_old(self, cutoff_time): # Find index of first timestamp >= cutoff_time idx = bisect.bisect_left(self.timestamps, cutoff_time) - + if idx > 0: self.timestamps = self.timestamps[idx:] # Can't easily prune packet_sizes without matching indices @@ -141,7 +141,7 @@ def prune_old(self, cutoff_time): keep_ratio = len(self.timestamps) / (len(self.timestamps) + idx) keep_count = max(1, int(len(self.packet_sizes) * keep_ratio)) self.packet_sizes = self.packet_sizes[-keep_count:] - + # Update first_seen if self.timestamps: self.first_seen = self.timestamps[0] @@ -151,90 +151,83 @@ def prune_old(self, cutoff_time): class ConnectionStorage: - - def __init__( - self, - retention_seconds: int = 7200, - cleanup_interval: int = 300 - ): + + def __init__(self, retention_seconds: int = 7200, cleanup_interval: int = 300): self.retention_seconds = retention_seconds self.cleanup_interval = cleanup_interval - + # Primary storage: connection pairs indexed by pair key self._pairs: Dict[str, ConnectionPair] = {} - + # Index by source IP for efficient lookup self._by_src_ip: Dict[str, Set[str]] = defaultdict(set) - + # Index by destination IP self._by_dst_ip: Dict[str, Set[str]] = defaultdict(set) - + # Index by destination port self._by_dst_port: Dict[int, Set[str]] = defaultdict(set) - + # Thread safety self._lock = threading.RLock() - + # Statistics self._records_added = 0 self._records_expired = 0 self._batches_received = 0 - + # Cleanup thread self._cleanup_thread = None self._stop_cleanup = threading.Event() - + logger.info( f"ConnectionStorage initialized: retention={retention_seconds}s, " f"cleanup_interval={cleanup_interval}s" ) - + def start_cleanup(self): self._stop_cleanup.clear() - self._cleanup_thread = threading.Thread( - target=self._cleanup_loop, - daemon=True - ) + self._cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self._cleanup_thread.start() logger.info("Storage cleanup thread started") - + def stop_cleanup(self): self._stop_cleanup.set() if self._cleanup_thread: self._cleanup_thread.join(timeout=5) logger.info("Storage cleanup thread stopped") - + def _cleanup_loop(self): while not self._stop_cleanup.wait(timeout=self.cleanup_interval): self.cleanup_expired() - + def add_record(self, record): with self._lock: # Create pair key (src -> dst:port/proto) pair_key = f"{record.src_ip}->{record.dst_ip}:{record.dst_port}/{record.protocol_name}" - + # Get or create connection pair if pair_key not in self._pairs: self._pairs[pair_key] = ConnectionPair( src_ip=record.src_ip, dst_ip=record.dst_ip, dst_port=record.dst_port, - protocol=record.protocol_name + protocol=record.protocol_name, ) - + # Update indices self._by_src_ip[record.src_ip].add(pair_key) self._by_dst_ip[record.dst_ip].add(pair_key) self._by_dst_port[record.dst_port].add(pair_key) - + # Add connection to pair self._pairs[pair_key].add_connection(record) self._records_added += 1 - + def add_batch(self, records): with self._lock: @@ -245,69 +238,71 @@ def add_batch(self, records): self.add_record(record) except Exception as e: logger.warning(f"Failed to add record: {e}") - + def get_pair(self, pair_key): with self._lock: return self._pairs.get(pair_key) - + def get_pairs_by_src(self, src_ip): with self._lock: pair_keys = self._by_src_ip.get(src_ip, set()) return [self._pairs[k] for k in pair_keys if k in self._pairs] - + def get_pairs_by_dst(self, dst_ip): with self._lock: pair_keys = self._by_dst_ip.get(dst_ip, set()) return [self._pairs[k] for k in pair_keys if k in self._pairs] - + def get_pairs_by_port(self, dst_port): with self._lock: pair_keys = self._by_dst_port.get(dst_port, set()) return [self._pairs[k] for k in pair_keys if k in self._pairs] - + def get_all_pairs(self): with self._lock: return list(self._pairs.values()) - - def get_analyzable_pairs( self, min_connections = 10, min_duration = 300): + + def get_analyzable_pairs(self, min_connections=10, min_duration=300): with self._lock: result = [] for pair in self._pairs.values(): - if (pair.connection_count >= min_connections and - pair.duration_seconds >= min_duration): + if ( + pair.connection_count >= min_connections + and pair.duration_seconds >= min_duration + ): result.append(pair) return result - + def cleanup_expired(self): cutoff_time = time.time() - self.retention_seconds removed_count = 0 - + with self._lock: # Prune old timestamps from pairs pairs_to_remove = [] - + for pair_key, pair in self._pairs.items(): original_count = pair.connection_count pair.prune_old(cutoff_time) removed_count += original_count - pair.connection_count - + # Mark empty pairs for removal if pair.connection_count == 0: pairs_to_remove.append(pair_key) - + # Remove empty pairs for pair_key in pairs_to_remove: pair = self._pairs.pop(pair_key) - + # Update indices self._by_src_ip[pair.src_ip].discard(pair_key) self._by_dst_ip[pair.dst_ip].discard(pair_key) self._by_dst_port[pair.dst_port].discard(pair_key) - + # Clean up empty index entries if not self._by_src_ip[pair.src_ip]: del self._by_src_ip[pair.src_ip] @@ -315,36 +310,38 @@ def cleanup_expired(self): del self._by_dst_ip[pair.dst_ip] if not self._by_dst_port[pair.dst_port]: del self._by_dst_port[pair.dst_port] - + if removed_count > 0 or pairs_to_remove: self._records_expired += removed_count logger.info( f"Cleanup complete: removed {removed_count} records, " f"{len(pairs_to_remove)} empty pairs" ) - + @property def statistics(self): with self._lock: total_connections = sum(p.connection_count for p in self._pairs.values()) return { - 'pair_count': len(self._pairs), - 'total_connections': total_connections, - 'unique_src_ips': len(self._by_src_ip), - 'unique_dst_ips': len(self._by_dst_ip), - 'unique_dst_ports': len(self._by_dst_port), - 'records_added': self._records_added, - 'records_expired': self._records_expired, - 'batches_received': self._batches_received, - 'retention_seconds': self.retention_seconds + "pair_count": len(self._pairs), + "total_connections": total_connections, + "unique_src_ips": len(self._by_src_ip), + "unique_dst_ips": len(self._by_dst_ip), + "unique_dst_ports": len(self._by_dst_port), + "records_added": self._records_added, + "records_expired": self._records_expired, + "batches_received": self._batches_received, + "retention_seconds": self.retention_seconds, } - + def __len__(self): with self._lock: return len(self._pairs) - + def __repr__(self): stats = self.statistics - return (f"ConnectionStorage(pairs={stats['pair_count']}, " - f"connections={stats['total_connections']})") + return ( + f"ConnectionStorage(pairs={stats['pair_count']}, " + f"connections={stats['total_connections']})" + ) diff --git a/data_plane/collector.py b/data_plane/collector.py index 7c6e12e..8d2dd6e 100644 --- a/data_plane/collector.py +++ b/data_plane/collector.py @@ -27,195 +27,189 @@ print("Install with: sudo apt-get install bpfcc-tools python3-bcc") sys.exit(1) -from .telemetry import ( - ConnectionEvent, - ConnectionEventCType, - TelemetryBuffer, - DataPlaneStats -) from .exporter import ExporterConfig, SyncTelemetryExporter - +from .telemetry import (ConnectionEvent, ConnectionEventCType, DataPlaneStats, + TelemetryBuffer) # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) -logger = logging.getLogger('beacon_detect.data_plane.collector') +logger = logging.getLogger("beacon_detect.data_plane.collector") class DataPlaneCollector: """ Main data plane collector class. - + Responsible for: - Loading and managing the eBPF program - Processing events from the ring buffer - Managing the telemetry buffer and export schedule """ - + # Path to the eBPF program source - EBPF_PROGRAM_PATH = Path(__file__).parent / 'ebpf_program.c' - - def __init__( - self, - interface: str, - config: Dict, - node_id: str = None - ): + EBPF_PROGRAM_PATH = Path(__file__).parent / "ebpf_program.c" + + def __init__(self, interface: str, config: Dict, node_id: str = None): self.interface = interface self.config = config self.node_id = node_id or self._generate_node_id() - + # eBPF components self._bpf = None - + # Telemetry components - dp_config = config.get('data_plane', {}) + dp_config = config.get("data_plane", {}) self._buffer = TelemetryBuffer( - max_size=dp_config.get('max_buffer_size', 100000) + max_size=dp_config.get("max_buffer_size", 100000) ) - + # Exporter configuration exporter_config = ExporterConfig( - control_plane_host=dp_config.get('control_plane_host', '127.0.0.1'), - control_plane_port=dp_config.get('control_plane_port', 9090), - connection_timeout=dp_config.get('connection_timeout', 10.0), - node_id=self.node_id + control_plane_host=dp_config.get("control_plane_host", "127.0.0.1"), + control_plane_port=dp_config.get("control_plane_port", 9090), + connection_timeout=dp_config.get("connection_timeout", 10.0), + node_id=self.node_id, ) self._exporter = SyncTelemetryExporter(exporter_config) - + # Export timing - self._export_interval = dp_config.get('export_interval', 60) + self._export_interval = dp_config.get("export_interval", 60) self._last_export_time = time.time() - + # Statistics self._stats = DataPlaneStats() - + # Control flags self._running = False self._shutdown_event = threading.Event() - + logger.info( f"DataPlaneCollector initialized: interface={interface}, " f"node_id={self.node_id}, export_interval={self._export_interval}s" ) - + def _generate_node_id(self): import socket + hostname = socket.gethostname() unique_suffix = uuid.uuid4().hex[:8] return f"dp-{hostname}-{unique_suffix}" - + def _load_ebpf_program(self): logger.info(f"Loading eBPF program from {self.EBPF_PROGRAM_PATH}") - + # Read the eBPF program source - with open(self.EBPF_PROGRAM_PATH, 'r') as f: + with open(self.EBPF_PROGRAM_PATH, "r") as f: bpf_source = f.read() - + # Apply configuration flags - dp_config = self.config.get('data_plane', {}) - if not dp_config.get('track_tcp', True): - bpf_source = bpf_source.replace('#define TRACK_TCP 1', '#define TRACK_TCP 0') - if not dp_config.get('track_udp', True): - bpf_source = bpf_source.replace('#define TRACK_UDP 1', '#define TRACK_UDP 0') - + dp_config = self.config.get("data_plane", {}) + if not dp_config.get("track_tcp", True): + bpf_source = bpf_source.replace( + "#define TRACK_TCP 1", "#define TRACK_TCP 0" + ) + if not dp_config.get("track_udp", True): + bpf_source = bpf_source.replace( + "#define TRACK_UDP 1", "#define TRACK_UDP 0" + ) + # Compile the BPF program try: - bpf = BPF(text=bpf_source, cflags=['-Wno-macro-redefined']) + bpf = BPF(text=bpf_source, cflags=["-Wno-macro-redefined"]) logger.info("eBPF program compiled successfully") return bpf except Exception as e: logger.error(f"Failed to compile eBPF program: {e}") raise - + def _attach_ebpf_program(self): logger.info(f"Attaching eBPF program to interface {self.interface}") - + # Try XDP first (best performance), fall back to TC try: # Attach XDP program fn = self._bpf.load_func("xdp_connection_tracker", BPF.XDP) self._bpf.attach_xdp(self.interface, fn, 0) logger.info(f"Attached XDP program to {self.interface}") - self._attachment_mode = 'xdp' + self._attachment_mode = "xdp" except Exception as e: logger.warning(f"XDP attachment failed, trying TC: {e}") - + # Fall back to TC (Traffic Control) try: # Attach TC ingress fn_ingress = self._bpf.load_func("tc_ingress_tracker", BPF.SCHED_CLS) self._bpf.attach_raw_socket(fn_ingress, self.interface) - + logger.info(f"Attached TC program to {self.interface}") - self._attachment_mode = 'tc' + self._attachment_mode = "tc" except Exception as e2: logger.error(f"Failed to attach eBPF program: {e2}") raise RuntimeError(f"Could not attach eBPF program: {e2}") - + def _detach_ebpf_program(self): - if self._bpf and hasattr(self, '_attachment_mode'): + if self._bpf and hasattr(self, "_attachment_mode"): try: - if self._attachment_mode == 'xdp': + if self._attachment_mode == "xdp": self._bpf.remove_xdp(self.interface, 0) logger.info(f"Detached eBPF program from {self.interface}") except Exception as e: logger.warning(f"Error detaching eBPF program: {e}") - + def _setup_ring_buffer(self): - + def ring_buffer_callback(ctx, data, size): try: # Cast the raw data to our event structure event = ctypes.cast(data, ctypes.POINTER(ConnectionEventCType)).contents - + # Convert to Python object conn_event = ConnectionEvent.from_ctype(event, self.node_id) - + # Apply whitelist filtering if self._should_filter(conn_event): return - + # Add to buffer if not self._buffer.add(conn_event): logger.warning("Buffer overflow, events being dropped") - + except Exception as e: logger.error(f"Error processing ring buffer event: {e}") - + # Get the ring buffer and set up polling self._bpf["events"].open_ring_buffer(ring_buffer_callback) logger.info("Ring buffer callback configured") - + def _should_filter(self, event: ConnectionEvent): - whitelist = self.config.get('whitelist', {}) - + whitelist = self.config.get("whitelist", {}) + # Check source IP whitelist - if event.src_ip in whitelist.get('source_ips', []): + if event.src_ip in whitelist.get("source_ips", []): return True - + # Check destination IP whitelist - if event.dst_ip in whitelist.get('destination_ips', []): + if event.dst_ip in whitelist.get("destination_ips", []): return True - + # Check port whitelist - if event.dst_port in whitelist.get('ports', []): + if event.dst_port in whitelist.get("ports", []): return True - + # Check specific pairs pair_key = f"{event.src_ip}:{event.dst_ip}:{event.dst_port}" - if pair_key in whitelist.get('pairs', []): + if pair_key in whitelist.get("pairs", []): return True - + return False - + def _read_stats(self): try: @@ -231,17 +225,17 @@ def _read_stats(self): self._stats.events_buffered = self._buffer.size except Exception as e: logger.warning(f"Error reading stats: {e}") - + def _export_telemetry(self): events = self._buffer.drain() - + if not events: logger.debug("No events to export") return - + logger.info(f"Exporting {len(events)} events to control plane") - + try: success = self._exporter.export_events(events) if success: @@ -257,106 +251,106 @@ def _export_telemetry(self): logger.error(f"Export error: {e}") # Re-add events to buffer self._buffer.add_batch(events) - + def start(self): logger.info("Starting data plane collector...") - + # Check root privileges if os.geteuid() != 0: raise PermissionError("Root privileges required for eBPF") - + # Load and attach eBPF program self._bpf = self._load_ebpf_program() self._attach_ebpf_program() self._setup_ring_buffer() - + # Start exporter self._exporter.start() - + self._running = True self._last_export_time = time.time() - + logger.info("Data plane collector started") - + def run(self): if not self._running: self.start() - + logger.info("Entering main event loop") - + try: while not self._shutdown_event.is_set(): # Poll ring buffer for new events (100ms timeout) self._bpf.ring_buffer_poll(100) - + # Check if it's time to export current_time = time.time() if current_time - self._last_export_time >= self._export_interval: self._read_stats() self._export_telemetry() self._last_export_time = current_time - + # Log statistics logger.info(f"Stats: {self._stats}") - + except KeyboardInterrupt: logger.info("Interrupted by user") finally: self.stop() - + def stop(self): logger.info("Stopping data plane collector...") - + self._shutdown_event.set() self._running = False - + # Final export of any remaining events if self._buffer.size > 0: logger.info("Exporting remaining buffered events...") self._export_telemetry() - + # Stop exporter self._exporter.stop() - + # Detach and cleanup eBPF self._detach_ebpf_program() - + if self._bpf: self._bpf = None - + logger.info("Data plane collector stopped") - + @property def statistics(self): self._read_stats() return { - 'node_id': self.node_id, - 'interface': self.interface, - 'running': self._running, - 'ebpf_stats': self._stats.to_dict(), - 'exporter_stats': self._exporter.statistics, - 'buffer_size': self._buffer.size, - 'buffer_overflow': self._buffer.overflow_count + "node_id": self.node_id, + "interface": self.interface, + "running": self._running, + "ebpf_stats": self._stats.to_dict(), + "exporter_stats": self._exporter.statistics, + "buffer_size": self._buffer.size, + "buffer_overflow": self._buffer.overflow_count, } def load_config(config_path: str): - with open(config_path, 'r') as f: + with open(config_path, "r") as f: return yaml.safe_load(f) def setup_signal_handlers(collector): - + def signal_handler(signum, frame): logger.info(f"Received signal {signum}, initiating shutdown...") collector.stop() sys.exit(0) - + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) @@ -364,58 +358,52 @@ def signal_handler(signum, frame): def main(): parser = argparse.ArgumentParser( - description='eBPF Beaconing Detection - Data Plane Collector' + description="eBPF Beaconing Detection - Data Plane Collector" ) parser.add_argument( - '-c', '--config', - required=True, - help='Path to configuration file' + "-c", "--config", required=True, help="Path to configuration file" ) parser.add_argument( - '-i', '--interface', - help='Network interface to monitor (overrides config)' + "-i", "--interface", help="Network interface to monitor (overrides config)" ) parser.add_argument( - '-n', '--node-id', - help='Unique node identifier (auto-generated if not provided)' + "-n", + "--node-id", + help="Unique node identifier (auto-generated if not provided)", ) parser.add_argument( - '-v', '--verbose', - action='store_true', - help='Enable verbose logging' + "-v", "--verbose", action="store_true", help="Enable verbose logging" ) - + args = parser.parse_args() - + # Set log level if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + # Load configuration try: config = load_config(args.config) except Exception as e: logger.error(f"Failed to load configuration: {e}") sys.exit(1) - + # Determine interface - interface = args.interface or config.get('data_plane', {}).get('interface') + interface = args.interface or config.get("data_plane", {}).get("interface") if not interface: logger.error("Network interface not specified") sys.exit(1) - + # Create and run collector try: collector = DataPlaneCollector( - interface=interface, - config=config, - node_id=args.node_id + interface=interface, config=config, node_id=args.node_id ) - + setup_signal_handlers(collector) - + collector.run() - + except PermissionError: logger.error("Root privileges required. Run with sudo.") sys.exit(1) @@ -424,5 +412,5 @@ def main(): sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/data_plane/exporter.py b/data_plane/exporter.py index 4f53e25..decfaf8 100644 --- a/data_plane/exporter.py +++ b/data_plane/exporter.py @@ -6,17 +6,12 @@ from dataclasses import dataclass import aiohttp -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type -) +from tenacity import (retry, retry_if_exception_type, stop_after_attempt, + wait_exponential) from .telemetry import TelemetryBatch - -logger = logging.getLogger('beacon_detect.data_plane.exporter') +logger = logging.getLogger("beacon_detect.data_plane.exporter") @dataclass @@ -29,14 +24,16 @@ class ExporterConfig: batch_size: int = 1000 compression_enabled: bool = True node_id: str = None - + @property def control_plane_url(self): return f"http://{self.control_plane_host}:{self.control_plane_port}/api/v1/telemetry" - + @property def health_check_url(self): - return f"http://{self.control_plane_host}:{self.control_plane_port}/api/v1/health" + return ( + f"http://{self.control_plane_host}:{self.control_plane_port}/api/v1/health" + ) class ExportError(Exception): @@ -44,13 +41,13 @@ class ExportError(Exception): class TelemetryExporter: - + def __init__(self, config: ExporterConfig): self.config = config self.node_id = config.node_id or self._generate_node_id() - + self._session: aiohttp.ClientSession = None - + # Statistics self._batches_sent = 0 self._batches_failed = 0 @@ -58,56 +55,50 @@ def __init__(self, config: ExporterConfig): self._bytes_sent = 0 self._last_export_time: float = None self._last_error: str = None - + # Connection state self._control_plane_healthy = False self._last_health_check: float = None - + logger.info(f"TelemetryExporter initialized with node_id={self.node_id}") - + def _generate_node_id(self): hostname = socket.gethostname() unique_suffix = uuid.uuid4().hex[:8] return f"{hostname}-{unique_suffix}" - + async def start(self): timeout = aiohttp.ClientTimeout( - total=self.config.request_timeout, - connect=self.config.connection_timeout + total=self.config.request_timeout, connect=self.config.connection_timeout ) - + connector = aiohttp.TCPConnector( - limit=10, - limit_per_host=5, - keepalive_timeout=60 + limit=10, limit_per_host=5, keepalive_timeout=60 ) - + self._session = aiohttp.ClientSession( timeout=timeout, connector=connector, - headers={ - 'Content-Type': 'application/json', - 'X-Node-ID': self.node_id - } + headers={"Content-Type": "application/json", "X-Node-ID": self.node_id}, ) - + # Initial health check await self._check_health() - + logger.info("TelemetryExporter started") - + async def stop(self): if self._session: await self._session.close() self._session = None - + logger.info("TelemetryExporter stopped") - + async def _check_health(self): if not self._session: return False - + try: async with self._session.get(self.config.health_check_url) as response: if response.status == 200: @@ -117,61 +108,60 @@ async def _check_health(self): return True else: self._control_plane_healthy = False - logger.warning(f"Control plane health check failed: {response.status}") + logger.warning( + f"Control plane health check failed: {response.status}" + ) return False except Exception as e: self._control_plane_healthy = False self._last_error = str(e) logger.warning(f"Control plane health check error: {e}") return False - + async def export_events(self, events: list): if not events: logger.debug("No events to export") return True - + # Create batch batch = TelemetryBatch( - batch_id=str(uuid.uuid4()), - node_id=self.node_id, - events=events + batch_id=str(uuid.uuid4()), node_id=self.node_id, events=events ) - + return await self.export_batch(batch) - + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type(ExportError), - reraise=True + reraise=True, ) async def export_batch(self, batch): if not self._session: raise ExportError("Exporter not started") - + # Serialize batch try: payload = batch.to_json() - payload_bytes = payload.encode('utf-8') + payload_bytes = payload.encode("utf-8") except Exception as e: logger.error(f"Failed to serialize batch: {e}") self._batches_failed += 1 raise ExportError(f"Serialization failed: {e}") - + # Compress if enabled and beneficial headers = {} if self.config.compression_enabled and len(payload_bytes) > 1024: import gzip + payload_bytes = gzip.compress(payload_bytes) - headers['Content-Encoding'] = 'gzip' - + headers["Content-Encoding"] = "gzip" + # Send to control plane try: async with self._session.post( - self.config.control_plane_url, - data=payload_bytes, - headers=headers + self.config.control_plane_url, data=payload_bytes, headers=headers ) as response: if response.status == 200: self._batches_sent += 1 @@ -179,7 +169,7 @@ async def export_batch(self, batch): self._bytes_sent += len(payload_bytes) self._last_export_time = time.time() self._control_plane_healthy = True - + logger.info( f"Exported batch {batch.batch_id}: " f"{batch.event_count} events, {len(payload_bytes)} bytes" @@ -187,7 +177,7 @@ async def export_batch(self, batch): return True elif response.status == 429: # Rate limited - wait and retry - retry_after = int(response.headers.get('Retry-After', 5)) + retry_after = int(response.headers.get("Retry-After", 5)) logger.warning(f"Rate limited, waiting {retry_after}s") await asyncio.sleep(retry_after) raise ExportError(f"Rate limited: {response.status}") @@ -196,26 +186,22 @@ async def export_batch(self, batch): logger.error(f"Export failed: {response.status} - {body}") self._batches_failed += 1 raise ExportError(f"HTTP {response.status}: {body}") - + except aiohttp.ClientError as e: self._batches_failed += 1 self._control_plane_healthy = False self._last_error = str(e) logger.error(f"Connection error during export: {e}") raise ExportError(f"Connection error: {e}") - - async def export_events_chunked( - self, - events: list, - chunk_size: int = None - ): - + + async def export_events_chunked(self, events: list, chunk_size: int = None): + chunk_size = chunk_size or self.config.batch_size successful = 0 failed = 0 - + for i in range(0, len(events), chunk_size): - chunk = events[i:i + chunk_size] + chunk = events[i : i + chunk_size] try: if await self.export_events(chunk): successful += len(chunk) @@ -223,107 +209,105 @@ async def export_events_chunked( failed += len(chunk) except ExportError: failed += len(chunk) - + return successful, failed - + @property def is_healthy(self): return self._control_plane_healthy - + @property def statistics(self): return { - 'node_id': self.node_id, - 'batches_sent': self._batches_sent, - 'batches_failed': self._batches_failed, - 'events_sent': self._events_sent, - 'bytes_sent': self._bytes_sent, - 'last_export_time': self._last_export_time, - 'last_error': self._last_error, - 'control_plane_healthy': self._control_plane_healthy, - 'last_health_check': self._last_health_check + "node_id": self.node_id, + "batches_sent": self._batches_sent, + "batches_failed": self._batches_failed, + "events_sent": self._events_sent, + "bytes_sent": self._bytes_sent, + "last_export_time": self._last_export_time, + "last_error": self._last_error, + "control_plane_healthy": self._control_plane_healthy, + "last_health_check": self._last_health_check, } class SyncTelemetryExporter: - + def __init__(self, config: ExporterConfig): self.config = config self._async_exporter: TelemetryExporter = None self._loop: asyncio.AbstractEventLoop = None self._thread = None - + def start(self): import threading - + def run_event_loop(): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - + self._async_exporter = TelemetryExporter(self.config) self._loop.run_until_complete(self._async_exporter.start()) - + # Keep the loop running self._loop.run_forever() - + self._thread = threading.Thread(target=run_event_loop, daemon=True) self._thread.start() - + # Wait for initialization time.sleep(0.5) logger.info("SyncTelemetryExporter started") - + def stop(self): if self._loop and self._async_exporter: # Schedule stop coroutine and wait for it to complete try: future = asyncio.run_coroutine_threadsafe( - self._async_exporter.stop(), - self._loop + self._async_exporter.stop(), self._loop ) # Wait for the stop coroutine to complete with timeout future.result(timeout=5) except Exception as e: logger.warning(f"Error stopping async exporter: {e}") - + # Stop the event loop try: self._loop.call_soon_threadsafe(self._loop.stop) except Exception as e: logger.warning(f"Error stopping event loop: {e}") - + if self._thread: try: self._thread.join(timeout=5) except Exception as e: logger.warning(f"Error joining thread: {e}") - + logger.info("SyncTelemetryExporter stopped") - + def export_events(self, events: list): if not self._loop or not self._async_exporter: logger.error("Exporter not started") return False - + future = asyncio.run_coroutine_threadsafe( - self._async_exporter.export_events(events), - self._loop + self._async_exporter.export_events(events), self._loop ) - + try: return future.result(timeout=self.config.request_timeout + 5) except Exception as e: logger.error(f"Export failed: {e}") return False - + @property def statistics(self): if self._async_exporter: return self._async_exporter.statistics return {} - + @property def is_healthy(self): if self._async_exporter: diff --git a/data_plane/telemetry.py b/data_plane/telemetry.py index 3ef6521..4887bc6 100644 --- a/data_plane/telemetry.py +++ b/data_plane/telemetry.py @@ -2,10 +2,10 @@ import json import socket import struct -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from datetime import datetime, timezone -from typing import List from enum import IntEnum +from typing import List class Protocol(IntEnum): @@ -31,7 +31,7 @@ class TCPFlags(IntEnum): class ConnectionEventCType(ctypes.Structure): - + # C-compatible structure matching the eBPF connection_event struct. _fields_ = [ ("timestamp_ns", ctypes.c_uint64), @@ -59,19 +59,21 @@ class ConnectionEvent: protocol: int tcp_flags: int = 0 direction: int = 0 - + # Derived fields (computed, not transmitted from kernel) timestamp_utc: str = None node_id: str = None - + def __post_init__(self): if self.timestamp_utc is None: # Convert kernel timestamp to UTC datetime string # Note: bpf_ktime_get_ns() returns time since boot, not epoch # We'll use current time for UTC representation - self.timestamp_utc = datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z') - + self.timestamp_utc = ( + datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + ) + @classmethod def from_ctype(cls, event: ConnectionEventCType, node_id: str = None): """ @@ -87,26 +89,30 @@ def from_ctype(cls, event: ConnectionEventCType, node_id: str = None): protocol=event.protocol, tcp_flags=event.tcp_flags, direction=event.direction, - node_id=node_id + node_id=node_id, ) - + @staticmethod def _int_to_ip(ip_int: int): # The IP is in network byte order from the kernel # struct.pack('!I', ...) expects host order, so we use ntohl to convert - return socket.inet_ntoa(struct.pack('I', socket.ntohl(ip_int))) - + return socket.inet_ntoa(struct.pack("I", socket.ntohl(ip_int))) + @property def protocol_name(self): - return Protocol(self.protocol).name if self.protocol in [6, 17] else f"UNKNOWN({self.protocol})" - + return ( + Protocol(self.protocol).name + if self.protocol in [6, 17] + else f"UNKNOWN({self.protocol})" + ) + @property def direction_name(self): return Direction(self.direction).name - + @property def tcp_flags_list(self): @@ -114,87 +120,89 @@ def tcp_flags_list(self): return [] flags = [] if self.tcp_flags & TCPFlags.FIN: - flags.append('FIN') + flags.append("FIN") if self.tcp_flags & TCPFlags.SYN: - flags.append('SYN') + flags.append("SYN") if self.tcp_flags & TCPFlags.RST: - flags.append('RST') + flags.append("RST") if self.tcp_flags & TCPFlags.PSH: - flags.append('PSH') + flags.append("PSH") if self.tcp_flags & TCPFlags.ACK: - flags.append('ACK') + flags.append("ACK") if self.tcp_flags & TCPFlags.URG: - flags.append('URG') + flags.append("URG") return flags - + @property def connection_key(self): - + # Generate a unique key for this connection pair. - # Used for grouping events by connection. + # Used for grouping events by connection. return f"{self.src_ip}:{self.src_port}->{self.dst_ip}:{self.dst_port}/{self.protocol_name}" - + @property def bidirectional_key(self): - + # Generate a key that is the same regardless of direction. # Used for matching request/response pairs. # Sort the endpoints to get consistent key ep1 = (self.src_ip, self.src_port) ep2 = (self.dst_ip, self.dst_port) - + if ep1 < ep2: return f"{self.src_ip}:{self.src_port}<->{self.dst_ip}:{self.dst_port}/{self.protocol_name}" else: return f"{self.dst_ip}:{self.dst_port}<->{self.src_ip}:{self.src_port}/{self.protocol_name}" - + def to_dict(self): - + return { - 'timestamp_ns': self.timestamp_ns, - 'timestamp_utc': self.timestamp_utc, - 'src_ip': self.src_ip, - 'dst_ip': self.dst_ip, - 'src_port': self.src_port, - 'dst_port': self.dst_port, - 'packet_size': self.packet_size, - 'protocol': self.protocol, - 'protocol_name': self.protocol_name, - 'tcp_flags': self.tcp_flags, - 'tcp_flags_list': self.tcp_flags_list, - 'direction': self.direction, - 'direction_name': self.direction_name, - 'node_id': self.node_id, - 'connection_key': self.connection_key + "timestamp_ns": self.timestamp_ns, + "timestamp_utc": self.timestamp_utc, + "src_ip": self.src_ip, + "dst_ip": self.dst_ip, + "src_port": self.src_port, + "dst_port": self.dst_port, + "packet_size": self.packet_size, + "protocol": self.protocol, + "protocol_name": self.protocol_name, + "tcp_flags": self.tcp_flags, + "tcp_flags_list": self.tcp_flags_list, + "direction": self.direction, + "direction_name": self.direction_name, + "node_id": self.node_id, + "connection_key": self.connection_key, } - + def to_json(self): return json.dumps(self.to_dict()) - + @classmethod def from_dict(cls, data): return cls( - timestamp_ns=data['timestamp_ns'], - src_ip=data['src_ip'], - dst_ip=data['dst_ip'], - src_port=data['src_port'], - dst_port=data['dst_port'], - packet_size=data['packet_size'], - protocol=data['protocol'], - tcp_flags=data.get('tcp_flags', 0), - direction=data.get('direction', 0), - timestamp_utc=data.get('timestamp_utc'), - node_id=data.get('node_id') + timestamp_ns=data["timestamp_ns"], + src_ip=data["src_ip"], + dst_ip=data["dst_ip"], + src_port=data["src_port"], + dst_port=data["dst_port"], + packet_size=data["packet_size"], + protocol=data["protocol"], + tcp_flags=data.get("tcp_flags", 0), + direction=data.get("direction", 0), + timestamp_utc=data.get("timestamp_utc"), + node_id=data.get("node_id"), ) - + def __repr__(self): flags_str = f" [{','.join(self.tcp_flags_list)}]" if self.tcp_flags_list else "" - return (f"ConnectionEvent({self.connection_key}{flags_str} " - f"size={self.packet_size} dir={self.direction_name})") + return ( + f"ConnectionEvent({self.connection_key}{flags_str} " + f"size={self.packet_size} dir={self.direction_name})" + ) @dataclass @@ -203,56 +211,65 @@ class TelemetryBatch: batch_id: str node_id: str events: List[ConnectionEvent] - created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')) - + created_at: str = field( + default_factory=lambda: datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z") + ) + # Statistics event_count: int = field(init=False) tcp_count: int = field(init=False) udp_count: int = field(init=False) unique_connections: int = field(init=False) - + def __post_init__(self): self.event_count = len(self.events) self.tcp_count = sum(1 for e in self.events if e.protocol == Protocol.TCP) self.udp_count = sum(1 for e in self.events if e.protocol == Protocol.UDP) self.unique_connections = len(set(e.connection_key for e in self.events)) - + def to_dict(self): return { - 'batch_id': self.batch_id, - 'node_id': self.node_id, - 'created_at': self.created_at, - 'event_count': self.event_count, - 'tcp_count': self.tcp_count, - 'udp_count': self.udp_count, - 'unique_connections': self.unique_connections, - 'events': [e.to_dict() for e in self.events] + "batch_id": self.batch_id, + "node_id": self.node_id, + "created_at": self.created_at, + "event_count": self.event_count, + "tcp_count": self.tcp_count, + "udp_count": self.udp_count, + "unique_connections": self.unique_connections, + "events": [e.to_dict() for e in self.events], } - + def to_json(self): return json.dumps(self.to_dict()) - + @classmethod def from_dict(cls, data): - events = [ConnectionEvent.from_dict(e) for e in data['events']] + events = [ConnectionEvent.from_dict(e) for e in data["events"]] batch = cls( - batch_id=data['batch_id'], - node_id=data['node_id'], + batch_id=data["batch_id"], + node_id=data["node_id"], events=events, - created_at=data.get('created_at', datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')) + created_at=data.get( + "created_at", + datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + ), ) return batch - + @classmethod def from_json(cls, json_str: str): return cls.from_dict(json.loads(json_str)) - + def __repr__(self): - return (f"TelemetryBatch(id={self.batch_id}, node={self.node_id}, " - f"events={self.event_count}, unique_conns={self.unique_connections})") + return ( + f"TelemetryBatch(id={self.batch_id}, node={self.node_id}, " + f"events={self.event_count}, unique_conns={self.unique_connections})" + ) @dataclass @@ -266,33 +283,35 @@ class DataPlaneStats: events_dropped: int = 0 dedup_hits: int = 0 parse_errors: int = 0 - + # User-space statistics batches_sent: int = 0 batches_failed: int = 0 events_buffered: int = 0 - + def to_dict(self): return asdict(self) - + def __repr__(self): - return (f"DataPlaneStats(total={self.packets_total}, tcp={self.packets_tcp}, " - f"udp={self.packets_udp}, submitted={self.events_submitted}, " - f"dropped={self.events_dropped})") + return ( + f"DataPlaneStats(total={self.packets_total}, tcp={self.packets_tcp}, " + f"udp={self.packets_udp}, submitted={self.events_submitted}, " + f"dropped={self.events_dropped})" + ) class TelemetryBuffer: - + def __init__(self, max_size: int = 100000): - + import threading - + self.max_size = max_size self._buffer: List[ConnectionEvent] = [] self._lock = threading.Lock() self._overflow_count = 0 - + def add(self, event): with self._lock: @@ -301,7 +320,7 @@ def add(self, event): return False self._buffer.append(event) return True - + def add_batch(self, events): with self._lock: @@ -311,28 +330,28 @@ def add_batch(self, events): overflow = len(events) - len(to_add) self._overflow_count += overflow return len(to_add) - + def drain(self): - - # Remove and return all events from the buffer. - + + # Remove and return all events from the buffer. + with self._lock: events = self._buffer self._buffer = [] return events - + @property def size(self): with self._lock: return len(self._buffer) - + @property def overflow_count(self): with self._lock: return self._overflow_count - + def reset_overflow_count(self): with self._lock: diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 14d1d0f..b2152f8 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -1,22 +1,26 @@ -import pytest import sys import time from pathlib import Path from unittest.mock import Mock +import pytest + # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) -from control_plane.storage import ConnectionStorage, ConnectionRecord, ConnectionPair -from control_plane.detector import BeaconDetector, DetectorConfig, DetectionResult, BeaconConfidence -from control_plane.analyzer import ConnectionAnalyzer, AnalyzerConfig, AnalysisRun -from control_plane.alerter import AlertManager, Alert, AlertSeverity +from control_plane.alerter import Alert, AlertManager, AlertSeverity +from control_plane.analyzer import (AnalysisRun, AnalyzerConfig, + ConnectionAnalyzer) +from control_plane.detector import (BeaconConfidence, BeaconDetector, + DetectionResult, DetectorConfig) +from control_plane.storage import (ConnectionPair, ConnectionRecord, + ConnectionStorage) -class TestConnectionStorage: +class TestConnectionStorage: def test_add_record(self): - storage = ConnectionStorage(retention_seconds=3600) + storage = ConnectionStorage(retention_seconds=3600) record = ConnectionRecord( timestamp_ns=1000000000, timestamp_utc="2024-01-01T00:00:00Z", @@ -30,46 +34,46 @@ def test_add_record(self): tcp_flags=0x10, direction=1, node_id="test-node", - connection_key="192.168.1.100:54321->10.0.0.1:443/TCP" + connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", ) - + storage.add_record(record) - + stats = storage.statistics - assert stats['records_added'] == 1 - assert stats['pair_count'] == 1 - + assert stats["records_added"] == 1 + assert stats["pair_count"] == 1 + def test_add_batch(self): storage = ConnectionStorage(retention_seconds=3600) - + records = [ { - 'timestamp_ns': 1000000000 + i * 1000000, - 'timestamp_utc': f"2024-01-01T00:00:0{i}Z", - 'src_ip': "192.168.1.100", - 'dst_ip': "10.0.0.1", - 'src_port': 54321 + i, - 'dst_port': 443, - 'packet_size': 1500, - 'protocol': 6, - 'protocol_name': "TCP", - 'tcp_flags': 0x10, - 'direction': 1, - 'node_id': "test-node", - 'connection_key': f"192.168.1.100:{54321+i}->10.0.0.1:443/TCP" + "timestamp_ns": 1000000000 + i * 1000000, + "timestamp_utc": f"2024-01-01T00:00:0{i}Z", + "src_ip": "192.168.1.100", + "dst_ip": "10.0.0.1", + "src_port": 54321 + i, + "dst_port": 443, + "packet_size": 1500, + "protocol": 6, + "protocol_name": "TCP", + "tcp_flags": 0x10, + "direction": 1, + "node_id": "test-node", + "connection_key": f"192.168.1.100:{54321+i}->10.0.0.1:443/TCP", } for i in range(10) ] - + storage.add_batch(records) - + stats = storage.statistics - assert stats['records_added'] == 10 - + assert stats["records_added"] == 10 + def test_get_pairs_by_src(self): storage = ConnectionStorage() - + for i in range(5): record = ConnectionRecord( timestamp_ns=1000000000 + i * 1000000, @@ -84,20 +88,20 @@ def test_get_pairs_by_src(self): tcp_flags=0x10, direction=1, node_id="test-node", - connection_key=f"192.168.1.100:54321->10.0.0.{i}:443/TCP" + connection_key=f"192.168.1.100:54321->10.0.0.{i}:443/TCP", ) storage.add_record(record) - + pairs = storage.get_pairs_by_src("192.168.1.100") assert len(pairs) == 5 - + pairs = storage.get_pairs_by_src("192.168.1.200") assert len(pairs) == 0 - + def test_get_analyzable_pairs(self): storage = ConnectionStorage() - + for i in range(15): record = ConnectionRecord( timestamp_ns=1000000000 + i * 60000000000, # 60s apart @@ -112,10 +116,10 @@ def test_get_analyzable_pairs(self): tcp_flags=0x10, direction=1, node_id="test-node", - connection_key="192.168.1.100:54321->10.0.0.1:443/TCP" + connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", ) storage.add_record(record) - + for i in range(3): record = ConnectionRecord( timestamp_ns=1000000000 + i * 60000000000, @@ -130,139 +134,128 @@ def test_get_analyzable_pairs(self): tcp_flags=0x10, direction=1, node_id="test-node", - connection_key="192.168.1.101:54322->10.0.0.2:80/TCP" + connection_key="192.168.1.101:54322->10.0.0.2:80/TCP", ) storage.add_record(record) - + pairs = storage.get_analyzable_pairs(min_connections=10, min_duration=60) assert len(pairs) == 1 assert pairs[0].src_ip == "192.168.1.100" class TestConnectionPair: - + def test_get_intervals(self): pair = ConnectionPair( - src_ip="192.168.1.100", - dst_ip="10.0.0.1", - dst_port=443, - protocol="TCP" + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" ) - + for i in range(5): pair.timestamps.append(1000.0 + i * 60.0) - + intervals = pair.get_intervals() - + assert len(intervals) == 4 assert all(i == 60.0 for i in intervals) - + def test_duration(self): pair = ConnectionPair( - src_ip="192.168.1.100", - dst_ip="10.0.0.1", - dst_port=443, - protocol="TCP" + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" ) - + pair.first_seen = 1000.0 pair.last_seen = 2000.0 - + assert pair.duration_seconds == 1000.0 - + def test_prune_old(self): pair = ConnectionPair( - src_ip="192.168.1.100", - dst_ip="10.0.0.1", - dst_port=443, - protocol="TCP" + src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" ) - - base_time = time.time() - 7200 + + base_time = time.time() - 7200 for i in range(10): - pair.timestamps.append(base_time + i * 600) + pair.timestamps.append(base_time + i * 600) pair.first_seen = pair.timestamps[0] pair.last_seen = pair.timestamps[-1] - + cutoff = time.time() - 3600 pair.prune_old(cutoff) - + assert pair.connection_count < 10 class TestAnalyzerConfig: - + def test_default_config(self): config = AnalyzerConfig() - + assert config.analysis_interval == 60 assert config.min_connections == 10 assert config.min_duration == 300.0 assert config.alert_cooldown == 300 - + def test_custom_config(self): config = AnalyzerConfig( - analysis_interval=120, - min_connections=20, - alert_cooldown=600 + analysis_interval=120, min_connections=20, alert_cooldown=600 ) - + assert config.analysis_interval == 120 assert config.min_connections == 20 assert config.alert_cooldown == 600 class TestAnalysisRun: - + def test_analysis_run_creation(self): run = AnalysisRun("test-run-1") - + assert run.run_id == "test-run-1" assert run.pairs_analyzed == 0 assert run.beacons_detected == 0 assert run.end_time is None - + def test_analysis_run_completion(self): run = AnalysisRun("test-run-1") run.pairs_analyzed = 100 run.beacons_detected = 2 run.alerts_generated = 2 - + run.complete() - + assert run.end_time is not None assert run.duration_seconds >= 0 - + def test_to_dict(self): run = AnalysisRun("test-run-1") run.pairs_analyzed = 50 run.complete() - + d = run.to_dict() - - assert d['run_id'] == "test-run-1" - assert d['pairs_analyzed'] == 50 - assert 'start_time' in d - assert 'end_time' in d + + assert d["run_id"] == "test-run-1" + assert d["pairs_analyzed"] == 50 + assert "start_time" in d + assert "end_time" in d class TestConnectionAnalyzer: - + def setup_method(self): self.storage = ConnectionStorage() self.detector = BeaconDetector(DetectorConfig(min_connections=5)) self.alert_manager = Mock(spec=AlertManager) self.alert_manager.send_alert = Mock() - + self.analyzer = ConnectionAnalyzer( storage=self.storage, detector=self.detector, @@ -271,17 +264,17 @@ def setup_method(self): analysis_interval=60, min_connections=5, min_duration=60, - alert_cooldown=60 - ) + alert_cooldown=60, + ), ) - + def test_run_analysis_empty_storage(self): run = self.analyzer.run_analysis() - + assert run.pairs_analyzed == 0 assert run.beacons_detected == 0 - + def test_run_analysis_with_data(self): base_time = time.time() @@ -299,54 +292,54 @@ def test_run_analysis_with_data(self): tcp_flags=0x10, direction=1, node_id="test-node", - connection_key="192.168.1.100:54321->10.0.0.1:443/TCP" + connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", ) record.timestamp_epoch = base_time + i * 60 self.storage.add_record(record) - + run = self.analyzer.run_analysis() - + assert run.pairs_analyzed >= 1 - + def test_alert_cooldown(self): pair_key = "192.168.1.100->10.0.0.1:443/TCP" self.analyzer._alert_cooldowns[pair_key] = time.time() - + mock_result = Mock(spec=DetectionResult) mock_result.pair_key = pair_key mock_result.combined_score = 0.9 mock_result.confidence = BeaconConfidence.HIGH - + self.analyzer._known_beacons[pair_key] = mock_result - + should_alert = self.analyzer._should_alert(mock_result) assert not should_alert - + def test_get_known_beacons(self): mock_result = Mock(spec=DetectionResult) mock_result.pair_key = "test-pair" - + self.analyzer._known_beacons["test-pair"] = mock_result - + beacons = self.analyzer.get_known_beacons() - + assert len(beacons) == 1 - + def test_statistics(self): stats = self.analyzer.statistics - - assert 'running' in stats - assert 'analysis_interval' in stats - assert 'total_runs' in stats - assert 'current_known_beacons' in stats + + assert "running" in stats + assert "analysis_interval" in stats + assert "total_runs" in stats + assert "current_known_beacons" in stats class TestAlertManager: - + def test_alert_creation(self): alert = Alert( @@ -354,13 +347,13 @@ def test_alert_creation(self): title="Test Alert", description="This is a test alert", severity=AlertSeverity.HIGH, - source="test" + source="test", ) - + assert alert.alert_id == "test-alert-1" assert alert.severity == AlertSeverity.HIGH assert alert.timestamp is not None - + def test_alert_to_dict(self): alert = Alert( @@ -369,35 +362,35 @@ def test_alert_to_dict(self): description="This is a test alert", severity=AlertSeverity.CRITICAL, source="test", - details={'key': 'value'} + details={"key": "value"}, ) - + d = alert.to_dict() - - assert d['alert_id'] == "test-alert-1" - assert d['severity'] == 'critical' - assert d['details'] == {'key': 'value'} - + + assert d["alert_id"] == "test-alert-1" + assert d["severity"] == "critical" + assert d["details"] == {"key": "value"} + def test_alert_to_syslog(self): alert = Alert( alert_id="test-alert-1", title="Beacon Detected", description="Beaconing detected from 192.168.1.100", severity=AlertSeverity.HIGH, - source="beacon_detector" + source="beacon_detector", ) - + msg = alert.to_syslog_message() - + assert "[HIGH]" in msg assert "Beacon Detected" in msg class TestAlertSeverity: - + def test_syslog_priority_mapping(self): import logging - + assert AlertSeverity.INFO.syslog_priority == logging.INFO assert AlertSeverity.LOW.syslog_priority == logging.WARNING assert AlertSeverity.MEDIUM.syslog_priority == logging.WARNING @@ -405,5 +398,5 @@ def test_syslog_priority_mapping(self): assert AlertSeverity.CRITICAL.syslog_priority == logging.CRITICAL -if __name__ == '__main__': - pytest.main([__file__, '-v']) +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_detector.py b/tests/test_detector.py index f687c48..fd7ee51 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -2,63 +2,62 @@ Tests for Beacon Detection Algorithms Run with: pytest tests/test_detector.py -v """ + import random -import pytest import sys from pathlib import Path +import pytest + # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) +from control_plane.detector import (BeaconConfidence, BeaconDetector, + DetectionResult, DetectorConfig, + IntervalStats, PeriodicityResult) from control_plane.storage import ConnectionPair -from control_plane.detector import ( - BeaconDetector, - DetectorConfig, - DetectionResult, - BeaconConfidence, - IntervalStats, - PeriodicityResult -) class TestIntervalStats: - + def test_regular_intervals(self): detector = BeaconDetector() intervals = [60.0] * 20 # 20 intervals of 60 seconds each - + stats = detector._calculate_interval_stats(intervals) - + assert stats.count == 20 assert stats.mean == 60.0 assert stats.std_dev == 0.0 assert stats.cv == 0.0 assert stats.median == 60.0 assert stats.jitter == 0.0 - + def test_irregular_intervals(self): detector = BeaconDetector() intervals = [10.0, 50.0, 30.0, 70.0, 40.0] # Highly variable - + stats = detector._calculate_interval_stats(intervals) - + assert stats.count == 5 assert stats.mean == 40.0 assert stats.cv >= 0.5 # High coefficient of variation assert stats.min_interval == 10.0 assert stats.max_interval == 70.0 - + def test_small_jitter(self): detector = BeaconDetector() base = 60.0 jitter_range = 2.0 - intervals = [base + random.uniform(-jitter_range, jitter_range) for _ in range(50)] - + intervals = [ + base + random.uniform(-jitter_range, jitter_range) for _ in range(50) + ] + stats = detector._calculate_interval_stats(intervals) - + assert stats.count == 50 assert 58.0 < stats.mean < 62.0 assert stats.cv < 0.1 # Low CV due to small jitter @@ -66,19 +65,19 @@ def test_small_jitter(self): class TestCVScore: - + def test_zero_cv_max_score(self): detector = BeaconDetector() score = detector._calculate_cv_score(0.0) assert score == 1.0 - + def test_high_cv_low_score(self): detector = BeaconDetector() score = detector._calculate_cv_score(1.0) assert score < 0.1 - + def test_threshold_cv_medium_score(self): config = DetectorConfig(cv_threshold=0.15) @@ -89,19 +88,18 @@ def test_threshold_cv_medium_score(self): class TestPeriodicityAnalysis: - def test_perfectly_periodic(self): detector = BeaconDetector() # Create perfectly periodic intervals intervals = [60.0] * 30 - + result = detector._analyze_periodicity(intervals) - + # Perfect periodicity should have low score (no variation to detect) # But our implementation should handle this edge case assert isinstance(result, PeriodicityResult) - + def test_periodic_with_noise(self): detector = BeaconDetector() @@ -109,40 +107,40 @@ def test_periodic_with_noise(self): base = 60.0 noise = 3.0 intervals = [base + random.gauss(0, noise) for _ in range(50)] - + result = detector._analyze_periodicity(intervals) - + assert isinstance(result, PeriodicityResult) # Should have some periodicity detected - + def test_random_no_periodicity(self): detector = BeaconDetector() # Create random intervals intervals = [random.uniform(10, 300) for _ in range(50)] - + result = detector._analyze_periodicity(intervals) - + assert isinstance(result, PeriodicityResult) # Random data should have low periodicity score assert result.periodicity_score < 0.5 class TestJitterScore: - + def test_zero_jitter_max_score(self): detector = BeaconDetector() score = detector._calculate_jitter_score(0.0) assert score == 1.0 - + def test_threshold_jitter(self): config = DetectorConfig(jitter_threshold=5.0) detector = BeaconDetector(config) score = detector._calculate_jitter_score(5.0) assert 0.4 < score < 0.6 - + def test_high_jitter_low_score(self): config = DetectorConfig(jitter_threshold=5.0) @@ -152,33 +150,30 @@ def test_high_jitter_low_score(self): class TestBeaconDetection: - + def create_connection_pair( self, intervals: list, src_ip: str = "192.168.1.100", dst_ip: str = "10.0.0.1", - dst_port: int = 443 + dst_port: int = 443, ): pair = ConnectionPair( - src_ip=src_ip, - dst_ip=dst_ip, - dst_port=dst_port, - protocol="TCP" + src_ip=src_ip, dst_ip=dst_ip, dst_port=dst_port, protocol="TCP" ) - + # Generate timestamps from intervals timestamp = 1000000.0 pair.timestamps.append(timestamp) for interval in intervals: timestamp += interval pair.timestamps.append(timestamp) - + pair.first_seen = pair.timestamps[0] pair.last_seen = pair.timestamps[-1] - + return pair - + def test_detect_regular_beacon(self): config = DetectorConfig( @@ -186,138 +181,147 @@ def test_detect_regular_beacon(self): cv_threshold=0.15, periodicity_threshold=0.5, jitter_threshold=5.0, - alert_threshold=0.6 + alert_threshold=0.6, ) detector = BeaconDetector(config) - + # Create regular beacon pattern (60s intervals with small jitter) intervals = [60.0 + random.uniform(-1, 1) for _ in range(30)] pair = self.create_connection_pair(intervals) - + result = detector.analyze(pair) - + assert result is not None assert result.cv_score > 0.7 # High score for regular intervals assert result.combined_score > 0.5 # Note: May or may not trigger is_beacon depending on exact jitter - + def test_detect_random_traffic(self): config = DetectorConfig( - min_connections=10, - cv_threshold=0.15, - alert_threshold=0.7 + min_connections=10, cv_threshold=0.15, alert_threshold=0.7 ) detector = BeaconDetector(config) - + # Create random traffic pattern intervals = [random.uniform(5, 300) for _ in range(30)] pair = self.create_connection_pair(intervals) - + result = detector.analyze(pair) - + assert result is not None assert result.cv_score < 0.5 # Low score for irregular intervals assert not result.is_beacon - + def test_insufficient_data(self): config = DetectorConfig(min_connections=20) detector = BeaconDetector(config) - + # Only 10 connections (need 20) intervals = [60.0] * 9 pair = self.create_connection_pair(intervals) - + result = detector.analyze(pair) - + assert result is None - + def test_confidence_levels(self): detector = BeaconDetector() - + # Test various score combinations - assert detector._determine_confidence(0.1, 0.1, 0.1, 0.1) == BeaconConfidence.NONE - assert detector._determine_confidence(0.4, 0.4, 0.4, 0.4) == BeaconConfidence.LOW - assert detector._determine_confidence(0.6, 0.6, 0.6, 0.6) == BeaconConfidence.MEDIUM - assert detector._determine_confidence(0.8, 0.8, 0.8, 0.8) == BeaconConfidence.HIGH - + assert ( + detector._determine_confidence(0.1, 0.1, 0.1, 0.1) == BeaconConfidence.NONE + ) + assert ( + detector._determine_confidence(0.4, 0.4, 0.4, 0.4) == BeaconConfidence.LOW + ) + assert ( + detector._determine_confidence(0.6, 0.6, 0.6, 0.6) + == BeaconConfidence.MEDIUM + ) + assert ( + detector._determine_confidence(0.8, 0.8, 0.8, 0.8) == BeaconConfidence.HIGH + ) + def test_batch_analyze(self): detector = BeaconDetector(DetectorConfig(min_connections=5)) - + # Create multiple pairs pairs = [] - + # Regular beacon - pairs.append(self.create_connection_pair( - [60.0 + random.uniform(-1, 1) for _ in range(20)], - src_ip="192.168.1.100", - dst_ip="10.0.0.1" - )) - + pairs.append( + self.create_connection_pair( + [60.0 + random.uniform(-1, 1) for _ in range(20)], + src_ip="192.168.1.100", + dst_ip="10.0.0.1", + ) + ) + # Random traffic - pairs.append(self.create_connection_pair( - [random.uniform(5, 300) for _ in range(20)], - src_ip="192.168.1.101", - dst_ip="10.0.0.2" - )) - + pairs.append( + self.create_connection_pair( + [random.uniform(5, 300) for _ in range(20)], + src_ip="192.168.1.101", + dst_ip="10.0.0.2", + ) + ) + # Another beacon with different interval - pairs.append(self.create_connection_pair( - [120.0 + random.uniform(-2, 2) for _ in range(20)], - src_ip="192.168.1.102", - dst_ip="10.0.0.3" - )) - + pairs.append( + self.create_connection_pair( + [120.0 + random.uniform(-2, 2) for _ in range(20)], + src_ip="192.168.1.102", + dst_ip="10.0.0.3", + ) + ) + results = detector.batch_analyze(pairs) - + assert len(results) == 3 # Results should be sorted by score descending assert results[0].combined_score >= results[-1].combined_score class TestDetectorConfig: - + def test_default_config(self): config = DetectorConfig() - + assert config.min_connections == 10 assert config.cv_threshold == 0.15 assert config.periodicity_threshold == 0.7 assert config.jitter_threshold == 5.0 assert config.alert_threshold == 0.7 - + def test_custom_config(self): config = DetectorConfig( - min_connections=20, - cv_threshold=0.1, - alert_threshold=0.8 + min_connections=20, cv_threshold=0.1, alert_threshold=0.8 ) - + assert config.min_connections == 20 assert config.cv_threshold == 0.1 assert config.alert_threshold == 0.8 - + def test_weight_validation(self): # Weights should sum to 1.0 config = DetectorConfig( - cv_weight=0.5, - periodicity_weight=0.3, - jitter_weight=0.2 + cv_weight=0.5, periodicity_weight=0.3, jitter_weight=0.2 ) detector = BeaconDetector(config) - + # Should not raise any warnings total = config.cv_weight + config.periodicity_weight + config.jitter_weight assert abs(total - 1.0) < 0.01 class TestDetectionResult: - + def test_to_dict(self): # Create a mock result @@ -329,16 +333,16 @@ def test_to_dict(self): median=60.0, min_interval=58.0, max_interval=62.0, - jitter=2.0 + jitter=2.0, ) - + periodicity_result = PeriodicityResult( is_periodic=True, dominant_period=60.0, periodicity_score=0.8, - frequency_peaks=[(0.0167, 0.8)] + frequency_peaks=[(0.0167, 0.8)], ) - + result = DetectionResult( pair_key="192.168.1.100->10.0.0.1:443/TCP", src_ip="192.168.1.100", @@ -356,17 +360,17 @@ def test_to_dict(self): connection_count=21, duration_seconds=1200.0, first_seen="2024-01-01T00:00:00Z", - last_seen="2024-01-01T00:20:00Z" + last_seen="2024-01-01T00:20:00Z", ) - + d = result.to_dict() - - assert d['pair_key'] == "192.168.1.100->10.0.0.1:443/TCP" - assert d['is_beacon'] == True - assert d['confidence'] == 'high' - assert 'interval_stats' in d - assert 'periodicity_result' in d + + assert d["pair_key"] == "192.168.1.100->10.0.0.1:443/TCP" + assert d["is_beacon"] == True + assert d["confidence"] == "high" + assert "interval_stats" in d + assert "periodicity_result" in d -if __name__ == '__main__': - pytest.main([__file__, '-v']) +if __name__ == "__main__": + pytest.main([__file__, "-v"])