From 8e522950b05e7567cb4375fcfdd3d0ab6e5d142d Mon Sep 17 00:00:00 2001 From: litemars Date: Mon, 27 Apr 2026 21:03:14 +0200 Subject: [PATCH 1/2] update --- .gitignore | 1 + CONTRIBUTING.md | 26 +-- config/config.yaml | 161 ++++++++++----- control_plane/alerter.py | 28 ++- control_plane/analyzer.py | 194 +++++++++++------- control_plane/detector.py | 403 +++++++++++++++++++++++++++--------- control_plane/server.py | 125 ++++++++--- control_plane/storage.py | 16 +- data_plane/collector.py | 33 ++- data_plane/ebpf_program.c | 42 ++-- data_plane/telemetry.py | 32 ++- requirements.txt | 3 - tests/test_analyzer.py | 272 ++++++++++++++---------- tests/test_detector.py | 421 +++++++++++++++++++++----------------- 14 files changed, 1155 insertions(+), 602 deletions(-) diff --git a/.gitignore b/.gitignore index 7d2933e..37c827d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ venv vevn env ENV +* .venv # Python diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 608b60c..3068e2a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,8 +7,8 @@ Thank you for your interest in contributing! This document provides guidelines a 1. **Fork the repository** on GitHub 2. **Clone your fork** locally: ```bash - git clone https://github.com/yourusername/BeaconDetectionSystemGit.git - cd BeaconDetectionSystemGit + git clone https://github.com/litemars/BeaconDetectionSystem.git + cd BeaconDetectionSystem ``` 3. **Set up the development environment**: @@ -34,7 +34,6 @@ Thank you for your interest in contributing! This document provides guidelines a 3. **Run tests and linting**: ```bash pytest - pylint control_plane data_plane black --check control_plane data_plane ``` @@ -45,24 +44,3 @@ Thank you for your interest in contributing! This document provides guidelines a 5. 
**Create a Pull Request** with a clear description of your changes -## Testing - -- Write tests for new features in the `tests/` directory -- Ensure all tests pass: `pytest` -- Aim for high test coverage, especially for critical components -- Use descriptive test names that explain what is being tested - -## Reporting Issues - -- Use GitHub Issues for bug reports and feature requests -- Provide clear steps to reproduce bugs -- Include relevant system information and logs - -## Documentation - -- Update relevant documentation for new features -- Keep the README.md up to date - -## Questions? - -Feel free to open an issue or discussion for questions. diff --git a/config/config.yaml b/config/config.yaml index ca521ea..c9c6c97 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -56,48 +56,46 @@ control_plane: # ============================================================================= detection: # Minimum number of connection events required before analysis - # Too low = false positives, too high = miss short beacon sessions min_connections: 10 - - # Analysis time window in seconds - # Connection pairs are analyzed over this sliding window + + # Minimum pair duration before analysis (seconds) + # A pair that has existed for less than this is not yet analyzed. + # At least 2x the expected beacon interval is recommended. + min_duration: 300 + + # Analysis time window in seconds (enforced at analysis time) + # Timestamps older than last_seen - time_window are excluded from scoring. 
time_window: 3600 # 1 hour - + # Coefficient of Variation (CV) threshold # CV = std_dev / mean of intervals - # Lower values = more regular intervals (beacon-like) - # Typical beacons: CV < 0.1 - # Normal traffic: CV > 0.3 + # Typical beacons: CV < 0.1 — Normal traffic: CV > 0.3 cv_threshold: 0.15 - + # FFT periodicity score threshold (0.0 to 1.0) - # Higher values indicate stronger periodic patterns - # Typical beacons: > 0.7 periodicity_threshold: 0.7 - - # Maximum allowed jitter in seconds - # Jitter = max deviation from median interval - # Lower jitter = more beacon-like + + # Maximum allowed jitter in seconds (max deviation from median interval) jitter_threshold: 5.0 - - # Minimum beacon interval to detect (seconds) - # Intervals shorter than this are ignored (likely not beacons) - min_beacon_interval: 10 - - # Maximum beacon interval to detect (seconds) - # Intervals longer than this are ignored - max_beacon_interval: 3600 - - # Scoring weights for combined detection - # Sum should equal 1.0 - cv_weight: 0.4 - periodicity_weight: 0.4 - jitter_weight: 0.2 - - # Final score threshold for alert (0.0 to 1.0) + + # Interval bounds + min_beacon_interval: 10 # seconds + max_beacon_interval: 3600 # seconds + + # Scoring weights — must sum to 1.0 + # cv: temporal regularity (coefficient of variation) + # periodicity: FFT-derived frequency dominance + # jitter: max single-interval deviation from median + # size: packet-size consistency (low CV of payload sizes) + cv_weight: 0.35 + periodicity_weight: 0.35 + jitter_weight: 0.15 + size_weight: 0.15 + + # Combined score threshold for beacon classification (0.0 to 1.0) alert_threshold: 0.7 - - # Cooldown period between alerts for same connection pair (seconds) + + # Cooldown between alerts for the same connection pair (seconds) alert_cooldown: 300 # ============================================================================= @@ -155,23 +153,92 @@ logging: # WHITELIST CONFIGURATION # 
============================================================================= whitelist: - # IP addresses to exclude from analysis - # Use for known legitimate periodic services (NTP, health checks, etc.) + # IP/port allowlist applied in the data plane (before telemetry export). source_ips: [] # - "10.0.0.1" - # - "192.168.1.1" - destination_ips: [] - # - "8.8.8.8" # Google DNS - # - "1.1.1.1" # Cloudflare DNS - - # Destination ports to exclude + # - "8.8.8.8" ports: [] - # - 53 # DNS - # - 123 # NTP - # - 443 # HTTPS (optional, may want to monitor this) - - # Specific source:destination pairs to exclude - # Format: "src_ip:dst_ip:dst_port" + # - 443 pairs: [] # - "10.0.0.5:192.168.1.1:80" + +# ============================================================================= +# BENIGN TRAFFIC BASELINE +# ============================================================================= +# Applied in the analysis stage (control plane), before full scoring. +# Pairs matching a pattern are suppressed and do not generate alerts. +# Add entries for known-legitimate periodic services in your environment. +# +# Security note: DNS (UDP/53) is intentionally excluded from defaults because +# DNS-over-UDP is a common C2 channel. Add it only if you have high confidence +# in your resolver infrastructure. +benign_baseline: + enabled: true + # + # Each pattern suppresses connection pairs whose destination port (and + # optionally protocol) match. Suppressed pairs are logged with the label + # but never reach the FFT scorer. + # + # Guidance per service class: + # + # NTP — UDP/123. Safe to suppress; packet size (48 B) and interval + # (~64 s) are very distinct from C2 beacons. + # + # DNS — UDP/53 and TCP/53. EXCLUDED from defaults: DNS-over-UDP is a + # well-known C2 exfiltration channel. Add only when you have + # high confidence in your resolver infrastructure and DNS logs + # are monitored by a separate analytic. 
+ # OCSP/CRL — TCP/80 toward PKI responders (Microsoft, DigiCert, etc.). + # Port alone is too broad; suppress only when specific responder + # IPs are known and stable. Uncomment below if needed (port/protocol match only). + # + # Windows Update — HTTPS (TCP/443) to update.microsoft.com, etc. + # Port too broad. Suppress only with specific IP or FQDN + # allowlisting at the network layer. + # + # Browser telemetry — TCP/443 at browser-specific intervals (~30–120 s). + # No single port identifies these; handled best by allowlisting + # the destination IP ranges rather than suppression here. + # + patterns: + # --- Always-on defaults --- + - dst_port: 123 + protocol: UDP + label: NTP + + # --- Commonly added in enterprise environments --- + # - dst_port: 53 + # protocol: UDP + # label: DNS-UDP # add only when DNS C2 risk is mitigated + # - dst_port: 53 + # protocol: TCP + # label: DNS-TCP # TCP/53 used for zone transfers and large responses + # - dst_port: 8125 + # protocol: UDP + # label: StatsD # metrics collection daemon + # - dst_port: 8126 + # protocol: TCP + # label: StatsD-mgmt + + # --- OCSP/CRL (uncomment when responder IPs are known and stable) --- + # - dst_port: 80 + # protocol: TCP + # label: OCSP-HTTP # certificate revocation checks over HTTP + + # --- Windows Update (too broad by port alone; see guidance above) --- + # - dst_port: 443 + # protocol: TCP + # label: WindowsUpdate # DANGEROUS: port 443 is C2-relevant; only add + # # with destination IP allowlisting at firewall + +# ============================================================================= +# API SECURITY +# ============================================================================= +# Set api_key to a non-empty secret to require X-API-Key authentication on +# all write endpoints (POST /api/v1/telemetry, POST /api/v1/config, DELETE *). +# Leave empty to disable authentication (default, suitable for localhost-only). +# Read-only GET endpoints are always unauthenticated. 
+# +# Example: api_key: "change-me-before-production" diff --git a/control_plane/alerter.py b/control_plane/alerter.py index 23a5a3a..33b5136 100644 --- a/control_plane/alerter.py +++ b/control_plane/alerter.py @@ -91,13 +91,39 @@ def to_json(self): return json.dumps(self.to_dict(), indent=2) def to_syslog_message(self): - return ( + """Return a single-line syslog message. + + When the alert's ``details`` dict contains an ``explanation`` with a + ``contributing_signals`` list (populated by BeaconDetector), the per-signal + scores and weights are appended for structured SIEM ingestion. + """ + msg = ( f"[{self.severity.value.upper()}] {self.title} | " f"Source: {self.source} | " f"ID: {self.alert_id} | " f"Description: {self.description}" ) + # Append per-signal scores if available + try: + signals = ( + self.details.get("explanation", {}).get("contributing_signals", []) + if isinstance(self.details, dict) + else [] + ) + if signals: + parts = [ + f"{s['name']}={s['score']:.2f}(w={s['weight']})" + for s in signals + if isinstance(s, dict) and "name" in s and "score" in s + ] + if parts: + msg += f" | Signals: {', '.join(parts)}" + except Exception: + pass # never let formatting failure suppress the syslog write + + return msg + @dataclass class AlertingConfig: diff --git a/control_plane/analyzer.py b/control_plane/analyzer.py index 2adc5c8..809e1c5 100644 --- a/control_plane/analyzer.py +++ b/control_plane/analyzer.py @@ -1,34 +1,63 @@ import logging import threading import time -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Dict, List +from typing import Dict, List, Optional from .alerter import Alert, AlertManager, AlertSeverity from .detector import BeaconDetector, DetectionResult -from .storage import ConnectionStorage +from .storage import ConnectionPair, ConnectionStorage logger = logging.getLogger("beacon_detect.control_plane.analyzer") @dataclass -class AnalyzerConfig: 
- # How often to run analysis (seconds) - analysis_interval: int = 60 +class BenignPattern: + """A known-benign periodic traffic pattern. + + A pair is suppressed when dst_port matches (and protocol matches if set). + The suppression is logged with the label so analysts can audit the decision. + """ + + dst_port: int + protocol: Optional[str] = None # "TCP", "UDP", or None (matches both) + label: str = "benign" + + def matches(self, pair: ConnectionPair) -> bool: + if pair.dst_port != self.dst_port: + return False + if self.protocol is not None and pair.protocol.upper() != self.protocol.upper(): + return False + return True - # Minimum connections required for analysis - min_connections: int = 10 - # Minimum duration for a pair to be analyzed - min_duration: float = 300.0 # 5 minutes +# Default benign baseline applied when benign_baseline.enabled is true +# and no custom patterns are configured. +DEFAULT_BENIGN_PATTERNS: List[BenignPattern] = [ + BenignPattern(dst_port=123, protocol="UDP", label="NTP"), +] - # Alert cooldown - don't re-alert same pair within this time - alert_cooldown: int = 300 # 5 minutes - # Maximum pairs to analyze per run (for performance) +@dataclass +class AnalyzerConfig: + + analysis_interval: int = 60 + min_connections: int = 10 + min_duration: float = 300.0 + alert_cooldown: int = 300 max_pairs_per_run: int = 10000 + # Benign traffic suppression + benign_baseline_enabled: bool = True + benign_patterns: List[BenignPattern] = field(default_factory=list) + + def get_effective_benign_patterns(self) -> List[BenignPattern]: + """Return user-configured patterns, falling back to defaults.""" + if not self.benign_baseline_enabled: + return [] + return self.benign_patterns if self.benign_patterns else DEFAULT_BENIGN_PATTERNS + class AnalysisRun: @@ -37,18 +66,17 @@ def __init__(self, run_id: str): self.start_time = datetime.now(timezone.utc) self.end_time = None self.pairs_analyzed = 0 + self.pairs_suppressed = 0 self.beacons_detected = 0 
self.alerts_generated = 0 self.errors = 0 self.results: List[DetectionResult] = [] def complete(self): - self.end_time = datetime.now(timezone.utc) @property def duration_seconds(self): - end = self.end_time or datetime.now(timezone.utc) return (end - self.start_time).total_seconds() @@ -59,6 +87,7 @@ def to_dict(self): "end_time": self.end_time.isoformat() + "Z" if self.end_time else None, "duration_seconds": float(round(self.duration_seconds, 2)), "pairs_analyzed": int(self.pairs_analyzed), + "pairs_suppressed": int(self.pairs_suppressed), "beacons_detected": int(self.beacons_detected), "alerts_generated": int(self.alerts_generated), "errors": int(self.errors), @@ -74,44 +103,26 @@ def __init__( alert_manager: AlertManager, config=None, ): - """ - Initialize the analyzer. - - Args: - storage: ConnectionStorage instance - detector: BeaconDetector instance - alert_manager: AlertManager instance - config: Analyzer configuration - """ self.storage = storage self.detector = detector self.alert_manager = alert_manager self.config = config or AnalyzerConfig() - # Thread safety lock for shared state self._lock = threading.RLock() - - # Track alert cooldowns: pair_key -> last_alert_time self._alert_cooldowns: Dict[str, float] = {} - - # Track known beacons for monitoring self._known_beacons: Dict[str, DetectionResult] = {} - - # Analysis run history self._run_history: List[AnalysisRun] = [] self._max_run_history = 100 - # Running state self._running = False self._analysis_thread = None self._stop_event = threading.Event() - # Statistics self._total_runs = 0 self._total_beacons_detected = 0 self._total_alerts_generated = 0 + self._total_suppressed = 0 - # Run counter for IDs self._run_counter = 0 logger.info( @@ -131,7 +142,6 @@ def start(self): target=self._analysis_loop, daemon=True ) self._analysis_thread.start() - logger.info("ConnectionAnalyzer started") def stop(self): @@ -156,7 +166,23 @@ def _analysis_loop(self): logger.info("Analysis loop stopped") - def 
run_analysis(self): + # ------------------------------------------------------------------ + # Benign traffic suppression + # ------------------------------------------------------------------ + + def _get_suppression_reason(self, pair: ConnectionPair) -> Optional[str]: + """Return a human-readable suppression label if this pair matches a + benign baseline pattern, or None if it should be analyzed normally.""" + for pattern in self.config.get_effective_benign_patterns(): + if pattern.matches(pair): + return pattern.label + return None + + # ------------------------------------------------------------------ + # Main analysis run + # ------------------------------------------------------------------ + + def run_analysis(self) -> AnalysisRun: self._run_counter += 1 run_id = f"run-{self._run_counter}-{int(time.time())}" @@ -165,37 +191,46 @@ def run_analysis(self): logger.info(f"Starting analysis run: {run_id}") try: - # logger.info(f"self.config.min_connections {self.config.min_connections} and self.config.min_duration {self.config.min_duration}") - # Get analyzable pairs from storage pairs = self.storage.get_analyzable_pairs( min_connections=self.config.min_connections, min_duration=self.config.min_duration, ) - # logger.info(f"pairs{pairs}") - # Limit pairs for performance + if len(pairs) > self.config.max_pairs_per_run: logger.warning( f"Limiting analysis to {self.config.max_pairs_per_run} " f"pairs (total: {len(pairs)})" ) - # Prioritize pairs with more connections pairs.sort(key=lambda p: p.connection_count, reverse=True) pairs = pairs[: self.config.max_pairs_per_run] - run.pairs_analyzed = len(pairs) - logger.info(f"Analyzing {len(pairs)} connection pairs") + # Stage 1: candidate filter — suppress known-benign patterns + candidates = [] + for pair in pairs: + reason = self._get_suppression_reason(pair) + if reason: + run.pairs_suppressed += 1 + logger.debug( + f"Suppressed {pair.pair_key}: matched benign pattern '{reason}'" + ) + else: + candidates.append(pair) 
+ + run.pairs_analyzed = len(candidates) + logger.info( + f"Analyzing {len(candidates)} pairs " + f"({run.pairs_suppressed} suppressed by benign baseline)" + ) - # Run detection - results = self.detector.batch_analyze(pairs) + # Stage 2: full scoring + results = self.detector.batch_analyze(candidates) run.results = results - # Process results beacons = [r for r in results if r.is_beacon] run.beacons_detected = len(beacons) logger.info(f"Detection complete: {len(beacons)} beacons found") - # Generate alerts for new/updated beacons for result in beacons: try: if self._should_alert(result): @@ -204,7 +239,6 @@ def run_analysis(self): with self._lock: self._alert_cooldowns[result.pair_key] = time.time() - # Update known beacons with self._lock: self._known_beacons[result.pair_key] = result @@ -212,7 +246,7 @@ def run_analysis(self): logger.error(f"Error generating alert for {result.pair_key}: {e}") run.errors += 1 - # Clean up cooldowns for pairs no longer seen as beacons + # Evict pairs that are no longer scored as beacons with self._lock: current_beacon_keys = {r.pair_key for r in beacons} stale_keys = [ @@ -227,45 +261,45 @@ def run_analysis(self): run.complete() - # Update statistics self._total_runs += 1 self._total_beacons_detected += run.beacons_detected self._total_alerts_generated += run.alerts_generated + self._total_suppressed += run.pairs_suppressed - # Store run history self._run_history.append(run) if len(self._run_history) > self._max_run_history: self._run_history = self._run_history[-self._max_run_history :] logger.info( - f"Analysis run complete: {run_id} - " - f"{run.pairs_analyzed} pairs, {run.beacons_detected} beacons, " - f"{run.alerts_generated} alerts, {run.duration_seconds:.2f}s" + f"Analysis run complete: {run_id} — " + f"{run.pairs_analyzed} analyzed, {run.pairs_suppressed} suppressed, " + f"{run.beacons_detected} beacons, {run.alerts_generated} alerts, " + f"{run.duration_seconds:.2f}s" ) return run - def _should_alert(self, result): + # 
------------------------------------------------------------------ + # Alert gating + # ------------------------------------------------------------------ + + def _should_alert(self, result: DetectionResult) -> bool: + pair_key = result.pair_key with self._lock: - # Check cooldown last_alert = self._alert_cooldowns.get(pair_key, 0) if time.time() - last_alert < self.config.alert_cooldown: logger.debug(f"Skipping alert for {pair_key}: cooldown active") return False - # Check if this is a new detection or significant change previous = self._known_beacons.get(pair_key) if previous is None: - # New beacon return True - # Check for significant score increase if result.combined_score - previous.combined_score > 0.1: return True - # Check for confidence upgrade conf_order = ["none", "low", "medium", "high", "critical"] if conf_order.index(result.confidence.value) > conf_order.index( previous.confidence.value @@ -274,9 +308,8 @@ def _should_alert(self, result): return False - def _generate_alert(self, result): + def _generate_alert(self, result: DetectionResult): - # Map confidence to severity severity_map = { "none": AlertSeverity.INFO, "low": AlertSeverity.LOW, @@ -286,49 +319,58 @@ def _generate_alert(self, result): } severity = severity_map.get(result.confidence.value, AlertSeverity.MEDIUM) - # Create alert + explanation = result.explanation + interval_s = explanation.get("detected_interval_seconds", "?") + fft_period = explanation.get("dominant_fft_period_seconds") + fft_str = f", FFT period {fft_period}s" if fft_period else "" + alert = Alert( alert_id=f"beacon-{result.pair_key}-{int(time.time())}", title=f"Beaconing Detected: {result.src_ip} -> {result.dst_ip}:{result.dst_port}", description=( - f"Potential beaconing behavior detected between {result.src_ip} " - f"and {result.dst_ip}:{result.dst_port}/{result.protocol}. " - f"Detection confidence: {result.confidence.value.upper()}. " - f"Combined score: {result.combined_score:.3f}. 
" - f"Observed {result.connection_count} connections over " - f"{result.duration_seconds/60:.1f} minutes." + f"Potential beaconing between {result.src_ip} and " + f"{result.dst_ip}:{result.dst_port}/{result.protocol}. " + f"Confidence: {result.confidence.value.upper()}. " + f"Score: {result.combined_score:.3f}. " + f"Interval: ~{interval_s}s{fft_str}. " + f"{result.connection_count} connections over " + f"{result.duration_seconds / 60:.1f} min." ), severity=severity, source="beacon_detector", details=result.to_dict(), ) - # Send alert through alert manager self.alert_manager.send_alert(alert) logger.warning( - f"Alert generated: {alert.alert_id} - " + f"Alert generated: {alert.alert_id} — " f"{result.pair_key} (score={result.combined_score:.3f})" ) - def get_known_beacons(self): - """Get list of known beacons (thread-safe).""" + # ------------------------------------------------------------------ + # Public accessors + # ------------------------------------------------------------------ + + def get_known_beacons(self) -> List[DetectionResult]: with self._lock: return list(self._known_beacons.values()) - def get_run_history(self, limit: int = 10): + def get_run_history(self, limit: int = 10) -> List[dict]: runs = self._run_history[-limit:] return [r.to_dict() for r in reversed(runs)] @property - def statistics(self): - """Get analyzer statistics""" + def statistics(self) -> dict: return { "running": self._running, "analysis_interval": self.config.analysis_interval, "total_runs": self._total_runs, "total_beacons_detected": self._total_beacons_detected, "total_alerts_generated": self._total_alerts_generated, + "total_suppressed": self._total_suppressed, "current_known_beacons": len(self._known_beacons), "active_cooldowns": len(self._alert_cooldowns), + "benign_baseline_enabled": self.config.benign_baseline_enabled, + "benign_pattern_count": len(self.config.get_effective_benign_patterns()), } diff --git a/control_plane/detector.py b/control_plane/detector.py index 
92139c5..ec06a7b 100644 --- a/control_plane/detector.py +++ b/control_plane/detector.py @@ -1,9 +1,10 @@ +import bisect import logging import math from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np from scipy import fft @@ -28,11 +29,11 @@ class IntervalStats: count: int mean: float std_dev: float - cv: float # Coefficient of variation + cv: float median: float min_interval: float max_interval: float - jitter: float # Max deviation from median + jitter: float # max deviation from median def to_dict(self): return { @@ -51,9 +52,9 @@ def to_dict(self): class PeriodicityResult: is_periodic: bool - dominant_period: float # In seconds + dominant_period: Optional[float] # seconds periodicity_score: float # 0.0 to 1.0 - frequency_peaks: List[Tuple[float, float]] # (frequency, magnitude) pairs + frequency_peaks: List[Tuple[float, float]] # (frequency, magnitude) def to_dict(self): return { @@ -78,10 +79,11 @@ class DetectionResult: dst_port: int protocol: str - # Detection scores (0.0 to 1.0, higher = more beacon-like) + # Detection scores (0.0–1.0, higher = more beacon-like) cv_score: float periodicity_score: float jitter_score: float + size_score: float combined_score: float # Detection outcome @@ -103,6 +105,9 @@ class DetectionResult: .replace("+00:00", "Z") ) + # Analyst-facing explanation (structured per-signal breakdown) + explanation: Dict = field(default_factory=dict) + def to_dict(self): return { "pair_key": str(self.pair_key), @@ -113,6 +118,7 @@ def to_dict(self): "cv_score": float(round(self.cv_score, 4)), "periodicity_score": float(round(self.periodicity_score, 4)), "jitter_score": float(round(self.jitter_score, 4)), + "size_score": float(round(self.size_score, 4)), "combined_score": float(round(self.combined_score, 4)), "is_beacon": bool(self.is_beacon), "confidence": str(self.confidence.value), @@ -123,6 
+129,7 @@ def to_dict(self): "first_seen": str(self.first_seen), "last_seen": str(self.last_seen), "analysis_time": str(self.analysis_time), + "explanation": self.explanation, } @@ -131,7 +138,7 @@ class DetectorConfig: # Minimum data requirements min_connections: int = 10 - time_window: int = 3600 # seconds + time_window: int = 3600 # seconds — analysis window enforced at analysis time # CV threshold (lower = more regular, beacon-like) cv_threshold: float = 0.15 @@ -146,10 +153,11 @@ class DetectorConfig: min_beacon_interval: float = 10.0 # seconds max_beacon_interval: float = 3600.0 # seconds - # Score weights (should sum to 1.0) - cv_weight: float = 0.4 - periodicity_weight: float = 0.4 - jitter_weight: float = 0.2 + # Score weights (must sum to 1.0) + cv_weight: float = 0.35 + periodicity_weight: float = 0.35 + jitter_weight: float = 0.15 + size_weight: float = 0.15 # Final threshold for beacon classification alert_threshold: float = 0.7 @@ -161,19 +169,19 @@ def __init__(self, config=None): self.config = config or DetectorConfig() - # Validate weights total_weight = ( self.config.cv_weight + self.config.periodicity_weight + self.config.jitter_weight + + self.config.size_weight ) if not 0.99 <= total_weight <= 1.01: - logger.warning(f"Score weights sum to {total_weight}, should be 1.0") + logger.warning(f"Score weights sum to {total_weight:.4f}, should be 1.0") logger.info(f"BeaconDetector initialized with config: {self.config}") def analyze(self, pair: ConnectionPair): - # Check minimum data requirements + if pair.connection_count < self.config.min_connections: logger.debug( f"Insufficient connections for {pair.pair_key}: " @@ -181,12 +189,31 @@ def analyze(self, pair: ConnectionPair): ) return None - # Get intervals - intervals = pair.get_intervals() - if len(intervals) < self.config.min_connections - 1: + # Enforce time_window: slice timestamps to the last N seconds. + # pair.timestamps is bisect-maintained sorted, so we can binary-search. 
+ if pair.last_seen is not None: + cutoff = pair.last_seen - self.config.time_window + start_idx = bisect.bisect_left(pair.timestamps, cutoff) + else: + start_idx = 0 + + windowed_ts = pair.timestamps[start_idx:] + # packet_sizes is kept index-aligned with timestamps (see storage.prune_old) + if len(pair.packet_sizes) == len(pair.timestamps): + windowed_sizes = pair.packet_sizes[start_idx:] + else: + windowed_sizes = pair.packet_sizes # safety fallback + + if len(windowed_ts) < self.config.min_connections: + logger.debug(f"Insufficient windowed connections for {pair.pair_key}") return None - # Filter intervals within bounds + # Compute intervals from the windowed timestamp slice + intervals = [ + windowed_ts[i] - windowed_ts[i - 1] for i in range(1, len(windowed_ts)) + ] + + # Filter intervals to the valid beacon range intervals = [ i for i in intervals @@ -197,34 +224,41 @@ def analyze(self, pair: ConnectionPair): logger.debug(f"Insufficient valid intervals for {pair.pair_key}") return None - # Calculate interval statistics interval_stats = self._calculate_interval_stats(intervals) - - # Calculate CV score (lower CV = higher score) cv_score = self._calculate_cv_score(interval_stats.cv) - - # Perform periodicity analysis periodicity_result = self._analyze_periodicity(intervals) + jitter_score = self._calculate_jitter_score( + interval_stats.jitter, interval_stats.median + ) + size_score = self._calculate_size_score(windowed_sizes) - # Calculate jitter score (lower jitter = higher score) - jitter_score = self._calculate_jitter_score(interval_stats.jitter) - # logger.info(f"interval {interval_stats}, cv_score {cv_score}, periodicity {periodicity_result}") - # Calculate combined score combined_score = ( self.config.cv_weight * cv_score + self.config.periodicity_weight * periodicity_result.periodicity_score + self.config.jitter_weight * jitter_score + + self.config.size_weight * size_score ) - # Determine if beacon is_beacon = combined_score >= 
self.config.alert_threshold - # Determine confidence level confidence = self._determine_confidence( - combined_score, cv_score, periodicity_result.periodicity_score, jitter_score + combined_score, + cv_score, + periodicity_result.periodicity_score, + jitter_score, + size_score, + ) + + explanation = self._build_explanation( + interval_stats=interval_stats, + periodicity_result=periodicity_result, + cv_score=cv_score, + jitter_score=jitter_score, + size_score=size_score, + windowed_sizes=windowed_sizes, + sample_count=len(intervals), ) - # Create result result = DetectionResult( pair_key=pair.pair_key, src_ip=pair.src_ip, @@ -234,6 +268,7 @@ def analyze(self, pair: ConnectionPair): cv_score=cv_score, periodicity_score=periodicity_result.periodicity_score, jitter_score=jitter_score, + size_score=size_score, combined_score=combined_score, is_beacon=is_beacon, confidence=confidence, @@ -251,6 +286,7 @@ def analyze(self, pair: ConnectionPair): if pair.last_seen else "" ), + explanation=explanation, ) if is_beacon: @@ -261,17 +297,17 @@ def analyze(self, pair: ConnectionPair): return result - def _calculate_interval_stats(self, intervals): - arr = np.array(intervals) + # ------------------------------------------------------------------ + # Scoring functions + # ------------------------------------------------------------------ + def _calculate_interval_stats(self, intervals: List[float]) -> IntervalStats: + + arr = np.array(intervals) mean = float(np.mean(arr)) std_dev = float(np.std(arr)) median = float(np.median(arr)) - - # Coefficient of variation cv = std_dev / mean if mean > 0 else float("inf") - - # Jitter = maximum deviation from median jitter = float(np.max(np.abs(arr - median))) return IntervalStats( @@ -285,23 +321,35 @@ def _calculate_interval_stats(self, intervals): jitter=jitter, ) - def _calculate_cv_score(self, cv: float): - + def _calculate_cv_score(self, cv: float) -> float: + """Sigmoid: score→1 when cv≪threshold, score→0 when cv≫threshold.""" if cv <= 
0: return 1.0 - - # Sigmoid transformation - # At cv = threshold, score ≈ 0.5 - # cv << threshold: score → 1.0 - # cv >> threshold: score → 0.0 threshold = self.config.cv_threshold - k = 10.0 / threshold # Steepness factor - - score = 1.0 / (1.0 + math.exp(k * (cv - threshold))) - return score - - def _analyze_periodicity(self, intervals: List[float]): - + k = 10.0 / threshold + return 1.0 / (1.0 + math.exp(k * (cv - threshold))) + + def _analyze_periodicity(self, intervals: List[float]) -> PeriodicityResult: + """Hybrid periodicity scorer: consistency + FFT. + + Two complementary components are computed and the higher score wins: + + Consistency (fraction-within-tolerance) + Measures what fraction of intervals falls within a tolerance band + centred on the median. Tolerance = max(cv_threshold × median, + jitter_threshold). This is reliable for tight, uniform beacons + where the FFT fails (after DC-removal the centred series is near- + zero noise, leaving no dominant peak). + + FFT (spectral dominance) + Detects structured periodicity in the *interval deviation* sequence + (e.g. sinusoidal / alternating jitter patterns added by evasive C2 + implants). Weak for purely uniform intervals; the ratio is capped + at 3.0 to prevent super-linear amplification on small samples. + + Both components share the same linear sample-count ramp-up penalty: + score → 0 at n=4, unchanged at n≥20. + """ if len(intervals) < 4: return PeriodicityResult( is_periodic=False, @@ -312,52 +360,65 @@ def _analyze_periodicity(self, intervals: List[float]): arr = np.array(intervals) n = len(arr) + median = float(np.median(arr)) - # Remove mean (DC component) - arr_centered = arr - np.mean(arr) + # ---- Consistency-based score ------------------------------------ + # Effective tolerance: relative floor prevents over-penalising + # long-period beacons with proportionally larger jitter. 
+ if median > 0: + tolerance = max( + self.config.cv_threshold * median, + self.config.jitter_threshold, + ) + consistency_score = float(np.sum(np.abs(arr - median) <= tolerance)) / n + else: + consistency_score = 0.0 - # Perform FFT + # ---- FFT-based score -------------------------------------------- + mean_interval = float(np.mean(arr)) + arr_centered = arr - mean_interval fft_result = fft.fft(arr_centered) - frequencies = fft.fftfreq(n, d=np.mean(arr)) + frequencies = fft.fftfreq(n, d=max(mean_interval, 1e-10)) - # Get magnitude spectrum (positive frequencies only) magnitude = np.abs(fft_result[: n // 2]) freq_positive = frequencies[: n // 2] - - # Normalize magnitude magnitude_norm = magnitude / (np.sum(magnitude) + 1e-10) - # Find peaks peaks = [] for i in range(1, len(magnitude) - 1): if magnitude[i] > magnitude[i - 1] and magnitude[i] > magnitude[i + 1]: - if freq_positive[i] > 0: # Skip DC + if freq_positive[i] > 0: peaks.append((freq_positive[i], magnitude_norm[i])) - # Sort peaks by magnitude peaks.sort(key=lambda x: x[1], reverse=True) top_peaks = peaks[:5] - # Calculate periodicity score - # Based on how dominant the main frequency is - if len(top_peaks) > 0: + if top_peaks: dominant_magnitude = top_peaks[0][1] dominant_freq = top_peaks[0][0] - dominant_period = 1.0 / dominant_freq if dominant_freq > 0 else None + fft_period = 1.0 / dominant_freq if dominant_freq > 0 else None - # Score is based on dominance of primary frequency - # A strong single frequency indicates regular periodicity if len(top_peaks) > 1: - # Ratio of dominant to second peak - ratio = dominant_magnitude / (top_peaks[1][1] + 1e-10) - periodicity_score = min(1.0, dominant_magnitude * ratio) + ratio = min(3.0, dominant_magnitude / (top_peaks[1][1] + 1e-10)) + fft_score = min(1.0, dominant_magnitude * ratio) else: - periodicity_score = dominant_magnitude + fft_score = dominant_magnitude else: - dominant_period = None - periodicity_score = 0.0 + fft_period = None + fft_score = 0.0 + + 
# dominant_period: prefer FFT-derived value; fall back to median + dominant_period = ( + fft_period if fft_period is not None else (median if median > 0 else None) + ) + + # ---- Sample-count penalty (applied to both components) ---------- + # Linear ramp-up: score → 0 at n=4, unchanged at n≥20. + sample_factor = min(1.0, max(0.0, (n - 4) / (20 - 4))) if n < 20 else 1.0 + + # Final: take the stronger of the two components, then penalise. + periodicity_score = max(fft_score, consistency_score) * sample_factor - # Determine if periodic based on threshold is_periodic = periodicity_score >= self.config.periodicity_threshold return PeriodicityResult( @@ -367,21 +428,47 @@ def _analyze_periodicity(self, intervals: List[float]): frequency_peaks=top_peaks, ) - def _calculate_jitter_score(self, jitter: float): + def _calculate_jitter_score( + self, jitter: float, interval_median: float = 0.0 + ) -> float: + """Piecewise: 1.0 at zero jitter, ~0.5 at threshold, exponential decay beyond. + The effective threshold is the larger of the configured absolute threshold and + 5% of the median interval. This prevents long-period beacons (e.g. 300s with + ±15s jitter) from being penalised by a threshold calibrated for 60s beacons. + """ if jitter <= 0: return 1.0 + relative_floor = interval_median * 0.05 if interval_median > 0 else 0.0 + threshold = max(self.config.jitter_threshold, relative_floor) + if jitter <= threshold: + return 1.0 - (jitter / threshold) * 0.5 + return max(0.0, 0.5 * math.exp(-(jitter - threshold) / threshold)) - threshold = self.config.jitter_threshold + def _calculate_size_score(self, packet_sizes: List[int]) -> float: + """Coefficient-of-variation of packet sizes. - # Linear scaling with cutoff - if jitter <= threshold: - score = 1.0 - (jitter / threshold) * 0.5 # Score 0.5-1.0 - else: - # Exponential decay beyond threshold - score = 0.5 * math.exp(-(jitter - threshold) / threshold) + Beacons typically use uniform packet sizes (low CV → high score). 
+ Returns 0.5 as a neutral score when fewer than 5 samples are available. + """ + if len(packet_sizes) < 5: + return 0.5 - return max(0.0, min(1.0, score)) + arr = np.array(packet_sizes, dtype=float) + mean = float(np.mean(arr)) + if mean <= 0: + return 0.5 + + cv = float(np.std(arr)) / mean + # Reuse the same sigmoid as the interval CV scorer, with a slightly + # looser threshold since packet sizes vary more than intervals. + threshold = 0.25 + k = 10.0 / threshold + return 1.0 / (1.0 + math.exp(k * (cv - threshold))) + + # ------------------------------------------------------------------ + # Confidence determination + # ------------------------------------------------------------------ def _determine_confidence( self, @@ -389,7 +476,8 @@ def _determine_confidence( cv_score: float, periodicity_score: float, jitter_score: float, - ): + size_score: float = 0.5, + ) -> BeaconConfidence: if combined_score < 0.3: return BeaconConfidence.NONE @@ -400,28 +488,153 @@ def _determine_confidence( elif combined_score < 0.85: return BeaconConfidence.HIGH else: - # Critical requires all three indicators to be strong - if cv_score > 0.7 and periodicity_score > 0.7 and jitter_score > 0.7: - return BeaconConfidence.CRITICAL - return BeaconConfidence.HIGH + # CRITICAL requires all four signals to be individually strong + all_strong = ( + cv_score > 0.7 + and periodicity_score > 0.7 + and jitter_score > 0.7 + and size_score > 0.7 + ) + return BeaconConfidence.CRITICAL if all_strong else BeaconConfidence.HIGH - def batch_analyze(self, pairs: List[ConnectionPair]): - results = [] + # ------------------------------------------------------------------ + # Analyst-facing explanation + # ------------------------------------------------------------------ + def _build_explanation( + self, + interval_stats: IntervalStats, + periodicity_result: PeriodicityResult, + cv_score: float, + jitter_score: float, + size_score: float, + windowed_sizes: List[int], + sample_count: int, + ) -> Dict: + + 
size_cv = None + if len(windowed_sizes) >= 5: + arr = np.array(windowed_sizes, dtype=float) + mean = float(np.mean(arr)) + size_cv = round(float(np.std(arr)) / mean, 4) if mean > 0 else None + + signals = [ + { + "name": "cv", + "score": round(cv_score, 4), + "weight": self.config.cv_weight, + "contribution": round(self.config.cv_weight * cv_score, 4), + "raw_value": round(interval_stats.cv, 4), + "threshold": self.config.cv_threshold, + }, + { + "name": "periodicity", + "score": round(periodicity_result.periodicity_score, 4), + "weight": self.config.periodicity_weight, + "contribution": round( + self.config.periodicity_weight + * periodicity_result.periodicity_score, + 4, + ), + "raw_value": round(periodicity_result.periodicity_score, 4), + "threshold": self.config.periodicity_threshold, + }, + { + "name": "jitter", + "score": round(jitter_score, 4), + "weight": self.config.jitter_weight, + "contribution": round(self.config.jitter_weight * jitter_score, 4), + "raw_value": round(interval_stats.jitter, 3), + "threshold": round( + max(self.config.jitter_threshold, interval_stats.median * 0.05), 3 + ), + }, + { + "name": "packet_size_consistency", + "score": round(size_score, 4), + "weight": self.config.size_weight, + "contribution": round(self.config.size_weight * size_score, 4), + "raw_value": size_cv, + "threshold": 0.25, + }, + ] + + return { + "detected_interval_seconds": round(interval_stats.median, 2), + "interval_mean_seconds": round(interval_stats.mean, 2), + "interval_std_seconds": round(interval_stats.std_dev, 2), + "dominant_fft_period_seconds": ( + round(periodicity_result.dominant_period, 2) + if periodicity_result.dominant_period + else None + ), + "jitter_seconds": round(interval_stats.jitter, 2), + "packet_size_cv": size_cv, + "sample_count": sample_count, + "observation_window_seconds": self.config.time_window, + "contributing_signals": signals, + "signals_above_threshold": [ + s["name"] for s in signals if s["score"] >= 0.7 + ], + } + + # 
------------------------------------------------------------------ + # Batch helpers + # ------------------------------------------------------------------ + + def _prefilter(self, pair: ConnectionPair) -> bool: + """Fast pre-checks applied before full FFT scoring. + + Returns False when the pair can be discarded without any heavy + computation — skipping numpy allocation, bisect, and FFT entirely. + + Checks (in order of cheapness): + 1. Raw connection count below minimum. + 2. Total observation span shorter than one beacon interval (no valid + intervals can exist in the beacon range). + 3. Rough mean interval far above max_beacon_interval (too slow to be + a beacon within the configured window). + """ + if pair.connection_count < self.config.min_connections: + return False + + if pair.first_seen is not None and pair.last_seen is not None: + span = pair.last_seen - pair.first_seen + # No interval can be >= min_beacon_interval if the whole span is shorter + if span < self.config.min_beacon_interval: + logger.debug( + f"Pre-filter: {pair.pair_key} span={span:.1f}s < " + f"min_beacon_interval={self.config.min_beacon_interval}s" + ) + return False + # Rough mean interval: if clearly above max, skip + if pair.connection_count > 1: + rough_interval = span / (pair.connection_count - 1) + if rough_interval > self.config.max_beacon_interval * 2: + logger.debug( + f"Pre-filter: {pair.pair_key} rough_interval={rough_interval:.1f}s " + f"> 2×max_beacon_interval={self.config.max_beacon_interval}s" + ) + return False + + return True + + def batch_analyze(self, pairs: List[ConnectionPair]) -> List[DetectionResult]: + + results = [] for pair in pairs: try: + if not self._prefilter(pair): + continue result = self.analyze(pair) if result is not None: results.append(result) except Exception as e: logger.error(f"Error analyzing {pair.pair_key}: {e}") - # Sort by combined score descending results.sort(key=lambda r: r.combined_score, reverse=True) - return results - def 
get_beacons(self, pairs: List[ConnectionPair]): + def get_beacons(self, pairs: List[ConnectionPair]) -> List[DetectionResult]: - all_results = self.batch_analyze(pairs) - return [r for r in all_results if r.is_beacon] + return [r for r in self.batch_analyze(pairs) if r.is_beacon] diff --git a/control_plane/server.py b/control_plane/server.py index 5344653..79e49b6 100644 --- a/control_plane/server.py +++ b/control_plane/server.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import argparse import asyncio +import gzip +import hmac import json import logging import logging.handlers @@ -9,11 +11,18 @@ from datetime import datetime, timezone from pathlib import Path +try: + import psutil + + _PSUTIL_AVAILABLE = True +except ImportError: + _PSUTIL_AVAILABLE = False + import yaml from aiohttp import web from .alerter import AlertingConfig, AlertManager, AlertSeverity -from .analyzer import AnalyzerConfig, ConnectionAnalyzer +from .analyzer import AnalyzerConfig, BenignPattern, ConnectionAnalyzer from .detector import BeaconDetector, DetectorConfig from .storage import ConnectionStorage @@ -66,6 +75,8 @@ def _build_runtime_config(self, config): return { "detection": { "min_connections": det.get("min_connections", 10), + "min_duration": det.get("min_duration", 300.0), + "time_window": det.get("time_window", 3600), "cv_threshold": det.get("cv_threshold", 0.15), "alert_threshold": det.get("alert_threshold", 0.7), "jitter_threshold": det.get("jitter_threshold", 5.0), @@ -73,9 +84,10 @@ def _build_runtime_config(self, config): "alert_cooldown": det.get("alert_cooldown", 300), }, "weights": { - "cv": det.get("cv_weight", 0.4), - "periodicity": det.get("periodicity_weight", 0.4), - "jitter": det.get("jitter_weight", 0.2), + "cv": det.get("cv_weight", 0.35), + "periodicity": det.get("periodicity_weight", 0.35), + "jitter": det.get("jitter_weight", 0.15), + "size": det.get("size_weight", 0.15), }, "alerting": { "syslog_enabled": alert.get("syslog", {}).get("enabled", True), @@ -112,9 
+124,10 @@ def _init_detector(self, config): jitter_threshold=det_config.get("jitter_threshold", 5.0), min_beacon_interval=det_config.get("min_beacon_interval", 10.0), max_beacon_interval=det_config.get("max_beacon_interval", 3600.0), - cv_weight=det_config.get("cv_weight", 0.4), - periodicity_weight=det_config.get("periodicity_weight", 0.4), - jitter_weight=det_config.get("jitter_weight", 0.2), + cv_weight=det_config.get("cv_weight", 0.35), + periodicity_weight=det_config.get("periodicity_weight", 0.35), + jitter_weight=det_config.get("jitter_weight", 0.15), + size_weight=det_config.get("size_weight", 0.15), alert_threshold=det_config.get("alert_threshold", 0.7), ) self.detector = BeaconDetector(detector_config) @@ -145,11 +158,26 @@ def _init_alerter(self, config): def _init_analyzer(self, config): det_config = config.get("detection", {}) + baseline_config = config.get("benign_baseline", {}) + + # Parse benign patterns from config + patterns = [] + for p in baseline_config.get("patterns", []): + patterns.append( + BenignPattern( + dst_port=int(p["dst_port"]), + protocol=p.get("protocol"), + label=p.get("label", "benign"), + ) + ) + analyzer_config = AnalyzerConfig( - analysis_interval=60, # Run every minute + analysis_interval=det_config.get("analysis_interval", 60), min_connections=det_config.get("min_connections", 10), - min_duration=30.0, + min_duration=float(det_config.get("min_duration", 300.0)), alert_cooldown=det_config.get("alert_cooldown", 300), + benign_baseline_enabled=baseline_config.get("enabled", True), + benign_patterns=patterns, ) self.analyzer = ConnectionAnalyzer( storage=self.storage, @@ -217,6 +245,20 @@ async def _handle_status(self, request: web.Request) -> web.Response: if self._start_time: uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds() + system_stats: dict = {"cpu_percent": None, "memory_mb": None, "open_fds": None} + if _PSUTIL_AVAILABLE: + try: + proc = psutil.Process() + system_stats["cpu_percent"] = 
psutil.cpu_percent(interval=None) + system_stats["memory_mb"] = round( + proc.memory_info().rss / (1024 * 1024), 1 + ) + system_stats["open_fds"] = ( + proc.num_fds() if hasattr(proc, "num_fds") else None + ) + except Exception: + pass + return web.json_response( { "status": "running", @@ -230,6 +272,7 @@ async def _handle_status(self, request: web.Request) -> web.Response: "storage": self.storage.statistics, "analyzer": self.analyzer.statistics, "alerter": self.alert_manager.statistics, + "system": system_stats, } ) @@ -248,14 +291,29 @@ async def _handle_statistics(self, request: web.Request) -> web.Response: } ) + def _check_api_key(self, request: web.Request) -> bool: + """Return True if request carries a valid API key, or auth is disabled.""" + required = self.config.get("control_plane", {}).get("api_key", "") + if not required: + return True + provided = ( + request.headers.get("X-API-Key", "") + or request.headers.get("Authorization", "").removeprefix("Bearer ").strip() + ) + return hmac.compare_digest(required, provided) + async def _handle_telemetry(self, request: web.Request) -> web.Response: self._requests_received += 1 + if not self._check_api_key(request): + return web.json_response({"error": "Unauthorized"}, status=401) + try: - # Get request body + # aiohttp ≥3.9 automatically decompresses Content-Encoding: gzip + # request bodies before they reach the handler, so no manual + # decompression is needed here. 
body = await request.read() - data = json.loads(body.decode("utf-8")) # Validate batch structure @@ -359,27 +417,15 @@ async def _handle_connections(self, request: web.Request) -> web.Response: } ) - async def _handle_manual_analyze(self, request: web.Request) -> web.Response: - - try: - run = self.analyzer.run_analysis() - return web.json_response( - { - "status": "completed", - "run": run.to_dict(), - "beacons_found": [r.to_dict() for r in run.results if r.is_beacon], - } - ) - except Exception as e: - logger.error(f"Manual analysis failed: {e}") - return web.json_response({"error": str(e)}, status=500) - async def _handle_get_config(self, request: web.Request) -> web.Response: return web.json_response(self._runtime_config) async def _handle_set_config(self, request: web.Request) -> web.Response: + if not self._check_api_key(request): + return web.json_response({"error": "Unauthorized"}, status=401) + try: data = await request.json() @@ -410,7 +456,6 @@ async def _handle_set_config(self, request: web.Request) -> web.Response: if "weights" in data: self._runtime_config["weights"].update(data["weights"]) - # Update detector weights if hasattr(self.detector, "config"): w = data["weights"] if "cv" in w: @@ -419,6 +464,8 @@ async def _handle_set_config(self, request: web.Request) -> web.Response: self.detector.config.periodicity_weight = w["periodicity"] if "jitter" in w: self.detector.config.jitter_weight = w["jitter"] + if "size" in w: + self.detector.config.size_weight = w["size"] if "alerting" in data: self._runtime_config["alerting"].update(data["alerting"]) @@ -461,6 +508,9 @@ async def _handle_set_config(self, request: web.Request) -> web.Response: async def _handle_clear_alerts(self, request: web.Request) -> web.Response: + if not self._check_api_key(request): + return web.json_response({"error": "Unauthorized"}, status=401) + try: if hasattr(self.alert_manager, "_recent_alerts"): self.alert_manager._recent_alerts = [] @@ -470,6 +520,9 @@ async def 
_handle_clear_alerts(self, request: web.Request) -> web.Response: async def _handle_clear_beacons(self, request: web.Request) -> web.Response: + if not self._check_api_key(request): + return web.json_response({"error": "Unauthorized"}, status=401) + try: if hasattr(self.analyzer, "_known_beacons"): self.analyzer._known_beacons = {} @@ -477,6 +530,24 @@ async def _handle_clear_beacons(self, request: web.Request) -> web.Response: except Exception as e: return web.json_response({"error": str(e)}, status=500) + async def _handle_manual_analyze(self, request: web.Request) -> web.Response: + + if not self._check_api_key(request): + return web.json_response({"error": "Unauthorized"}, status=401) + + try: + run = self.analyzer.run_analysis() + return web.json_response( + { + "status": "completed", + "run": run.to_dict(), + "beacons_found": [r.to_dict() for r in run.results if r.is_beacon], + } + ) + except Exception as e: + logger.error(f"Manual analysis failed: {e}") + return web.json_response({"error": str(e)}, status=500) + async def start(self): logger.info(f"Starting control plane server on {self.host}:{self.port}") diff --git a/control_plane/storage.py b/control_plane/storage.py index 8675c68..c171ca7 100644 --- a/control_plane/storage.py +++ b/control_plane/storage.py @@ -130,19 +130,17 @@ def get_intervals(self): def prune_old(self, cutoff_time): - # Find index of first timestamp >= cutoff_time idx = bisect.bisect_left(self.timestamps, cutoff_time) if idx > 0: self.timestamps = self.timestamps[idx:] - # Can't easily prune packet_sizes without matching indices - # Keep recent ones proportionally - if self.packet_sizes: - keep_ratio = len(self.timestamps) / (len(self.timestamps) + idx) - keep_count = max(1, int(len(self.packet_sizes) * keep_ratio)) - self.packet_sizes = self.packet_sizes[-keep_count:] - - # Update first_seen + # packet_sizes is kept index-aligned with timestamps via add_connection(), + # so we can slice at the same index without any ratio estimation. 
+ if len(self.packet_sizes) >= idx: + self.packet_sizes = self.packet_sizes[idx:] + else: + self.packet_sizes = [] + if self.timestamps: self.first_seen = self.timestamps[0] else: diff --git a/data_plane/collector.py b/data_plane/collector.py index b104351..b2f8c7a 100644 --- a/data_plane/collector.py +++ b/data_plane/collector.py @@ -64,6 +64,10 @@ def __init__(self, interface: str, config: Dict, node_id: str = None): # eBPF components self._bpf = None + # Offset to convert bpf_ktime_get_ns() (boot-relative) to wall-clock ns. + # Computed once at startup; refreshed on eBPF program load. + self._ktime_offset_ns: int = self._compute_ktime_offset() + # Telemetry components dp_config = config.get("data_plane", {}) self._buffer = TelemetryBuffer( @@ -102,6 +106,27 @@ def _generate_node_id(self): unique_suffix = uuid.uuid4().hex[:8] return f"dp-{hostname}-{unique_suffix}" + @staticmethod + def _compute_ktime_offset() -> int: + """Return (wall_clock_ns - bpf_ktime_ns) so that adding the offset to any + bpf_ktime_get_ns() value yields nanoseconds since the Unix epoch. + + Uses /proc/uptime (available on all Linux kernels) to get uptime in + seconds, then derives boot epoch from current wall-clock time. + """ + try: + with open("/proc/uptime", "r") as f: + uptime_s = float(f.read().split()[0]) + wall_ns = int(time.time() * 1_000_000_000) + ktime_ns = int(uptime_s * 1_000_000_000) + return wall_ns - ktime_ns + except Exception as e: + logger.warning( + f"Could not compute ktime offset from /proc/uptime: {e}. " + "Timestamps will fall back to wall-clock time of poll batch." 
+ ) + return 0 + def _load_ebpf_program(self): logger.info(f"Loading eBPF program from {self.EBPF_PROGRAM_PATH}") @@ -173,8 +198,10 @@ def ring_buffer_callback(ctx, data, size): # Cast the raw data to our event structure event = ctypes.cast(data, ctypes.POINTER(ConnectionEventCType)).contents - # Convert to Python object - conn_event = ConnectionEvent.from_ctype(event, self.node_id) + # Convert to Python object, applying the ktime→wall-clock offset + conn_event = ConnectionEvent.from_ctype( + event, self.node_id, self._ktime_offset_ns + ) # Apply whitelist filtering if self._should_filter(conn_event): @@ -267,6 +294,8 @@ def start(self): # Load and attach eBPF program self._bpf = self._load_ebpf_program() self._attach_ebpf_program() + # Refresh the ktime offset now that we know the eBPF clock is running + self._ktime_offset_ns = self._compute_ktime_offset() self._setup_ring_buffer() # Start exporter diff --git a/data_plane/ebpf_program.c b/data_plane/ebpf_program.c index b541195..dc7a980 100644 --- a/data_plane/ebpf_program.c +++ b/data_plane/ebpf_program.c @@ -23,9 +23,22 @@ struct connection_event { __u8 padding; /* Alignment padding */ }; -BPF_RINGBUF_OUTPUT(events, 1 << 16); /* 64KB ring buffer per CPU */ +BPF_RINGBUF_OUTPUT(events, 1 << 16); /* 64KB ring buffer */ -BPF_HASH(recent_connections, __u64, __u64, 65536); +/* + * Deduplication map: keyed on the full 5-tuple to avoid XOR collisions. + * Value is the last-seen timestamp in nanoseconds. 
+ */ +struct conn_key { + __u32 src_ip; + __u32 dst_ip; + __u16 src_port; + __u16 dst_port; + __u8 protocol; + __u8 pad[3]; /* explicit padding for struct alignment */ +}; + +BPF_HASH(recent_connections, struct conn_key, __u64, 65536); BPF_ARRAY(stats, __u64, 8); @@ -49,22 +62,25 @@ static __always_inline void update_stat(__u32 index) { } } -static __always_inline int is_duplicate(__u32 src_ip, __u32 dst_ip, +static __always_inline int is_duplicate(__u32 src_ip, __u32 dst_ip, __u16 src_port, __u16 dst_port, - __u64 now_ns) { - /* Create connection key by XORing addresses and ports */ - __u64 conn_key = ((__u64)src_ip << 32) | dst_ip; - conn_key ^= ((__u64)src_port << 16) | dst_port; - - __u64 *last_seen = recent_connections.lookup(&conn_key); + __u8 protocol, __u64 now_ns) { + struct conn_key key = {}; + key.src_ip = src_ip; + key.dst_ip = dst_ip; + key.src_port = src_port; + key.dst_port = dst_port; + key.protocol = protocol; + + __u64 *last_seen = recent_connections.lookup(&key); if (last_seen) { if ((now_ns - *last_seen) < DEDUP_WINDOW_NS) { update_stat(STAT_DEDUP_HITS); - return 1; + return 1; } } - - recent_connections.update(&conn_key, &now_ns); + + recent_connections.update(&key, &now_ns); return 0; } @@ -140,7 +156,7 @@ static __always_inline int process_ipv4(void *data, void *data_end, //timestamp __u64 now_ns = bpf_ktime_get_ns(); - if (is_duplicate(ip->saddr, ip->daddr, src_port, dst_port, now_ns)) { + if (is_duplicate(ip->saddr, ip->daddr, src_port, dst_port, protocol, now_ns)) { return 0; } diff --git a/data_plane/telemetry.py b/data_plane/telemetry.py index 4887bc6..38e1d70 100644 --- a/data_plane/telemetry.py +++ b/data_plane/telemetry.py @@ -67,20 +67,37 @@ class ConnectionEvent: def __post_init__(self): if self.timestamp_utc is None: - # Convert kernel timestamp to UTC datetime string - # Note: bpf_ktime_get_ns() returns time since boot, not epoch - # We'll use current time for UTC representation + # Fall back to current wall-clock time when no 
ktime offset is available. + # Callers that have a ktime offset should pass timestamp_utc explicitly + # via from_ctype(event, node_id, ktime_offset_ns=...). self.timestamp_utc = ( datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") ) @classmethod - def from_ctype(cls, event: ConnectionEventCType, node_id: str = None): - """ - Create a ConnectionEvent from the C-type structure. + def from_ctype( + cls, event: ConnectionEventCType, node_id: str = None, ktime_offset_ns: int = 0 + ): + """Create a ConnectionEvent from the C-type structure. + + ktime_offset_ns is the value of (wall_clock_ns - bpf_ktime_ns) computed + once at collector startup from /proc/uptime. Adding it to the kernel + timestamp converts boot-relative nanoseconds to UTC epoch nanoseconds. """ + if ktime_offset_ns: + epoch_ns = event.timestamp_ns + ktime_offset_ns + epoch_s = epoch_ns / 1e9 + ts_utc = ( + datetime.fromtimestamp(epoch_s, tz=timezone.utc) + .isoformat() + .replace("+00:00", "Z") + ) + else: + epoch_ns = event.timestamp_ns + ts_utc = None # __post_init__ will set wall-clock fallback + return cls( - timestamp_ns=event.timestamp_ns, + timestamp_ns=epoch_ns, src_ip=cls._int_to_ip(event.src_ip), dst_ip=cls._int_to_ip(event.dst_ip), src_port=event.src_port, @@ -89,6 +106,7 @@ def from_ctype(cls, event: ConnectionEventCType, node_id: str = None): protocol=event.protocol, tcp_flags=event.tcp_flags, direction=event.direction, + timestamp_utc=ts_utc, node_id=node_id, ) diff --git a/requirements.txt b/requirements.txt index ec0c5ad..da8b1f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,10 +7,7 @@ PyYAML>=6.0 aiohttp>=3.8.0 requests>=2.28.0 python-json-logger>=2.0.0 -pydantic>=2.0.0 -click>=8.0.0 psutil>=5.9.0 pytest>=7.0.0 pytest-asyncio>=0.20.0 tenacity>=8.0.0 -typing-extensions>=4.0.0 diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index c65a1ab..977179c 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -5,11 +5,15 @@ import pytest -# Add 
project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) from control_plane.alerter import Alert, AlertManager, AlertSeverity -from control_plane.analyzer import AnalysisRun, AnalyzerConfig, ConnectionAnalyzer +from control_plane.analyzer import ( + AnalysisRun, + AnalyzerConfig, + BenignPattern, + ConnectionAnalyzer, +) from control_plane.detector import ( BeaconConfidence, BeaconDetector, @@ -19,9 +23,14 @@ from control_plane.storage import ConnectionPair, ConnectionRecord, ConnectionStorage +# --------------------------------------------------------------------------- +# ConnectionStorage +# --------------------------------------------------------------------------- + + class TestConnectionStorage: - def test_add_record(self): + def test_add_record(self): storage = ConnectionStorage(retention_seconds=3600) record = ConnectionRecord( timestamp_ns=1000000000, @@ -38,16 +47,13 @@ def test_add_record(self): node_id="test-node", connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", ) - storage.add_record(record) - stats = storage.statistics assert stats["records_added"] == 1 assert stats["pair_count"] == 1 def test_add_batch(self): storage = ConnectionStorage(retention_seconds=3600) - records = [ { "timestamp_ns": 1000000000 + i * 1000000, @@ -66,19 +72,14 @@ def test_add_batch(self): } for i in range(10) ] - storage.add_batch(records) - - stats = storage.statistics - assert stats["records_added"] == 10 + assert storage.statistics["records_added"] == 10 def test_get_pairs_by_src(self): - storage = ConnectionStorage() - for i in range(5): record = ConnectionRecord( - timestamp_ns=1000000000 + i * 1000000, + timestamp_ns=1000000000 + i, timestamp_utc="2024-01-01T00:00:00Z", src_ip="192.168.1.100", dst_ip=f"10.0.0.{i}", @@ -93,20 +94,14 @@ def test_get_pairs_by_src(self): connection_key=f"192.168.1.100:54321->10.0.0.{i}:443/TCP", ) storage.add_record(record) - - pairs = storage.get_pairs_by_src("192.168.1.100") - assert len(pairs) == 5 - - pairs = 
storage.get_pairs_by_src("192.168.1.200") - assert len(pairs) == 0 + assert len(storage.get_pairs_by_src("192.168.1.100")) == 5 + assert len(storage.get_pairs_by_src("192.168.1.200")) == 0 def test_get_analyzable_pairs(self): - storage = ConnectionStorage() - for i in range(15): record = ConnectionRecord( - timestamp_ns=1000000000 + i * 60000000000, # 60s apart + timestamp_ns=1000000000 + i * 60000000000, timestamp_utc=f"2024-01-01T00:{i:02d}:00Z", src_ip="192.168.1.100", dst_ip="10.0.0.1", @@ -121,10 +116,9 @@ def test_get_analyzable_pairs(self): connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", ) storage.add_record(record) - for i in range(3): record = ConnectionRecord( - timestamp_ns=1000000000 + i * 60000000000, + timestamp_ns=1000000000 + i, timestamp_utc=f"2024-01-01T00:{i:02d}:00Z", src_ip="192.168.1.101", dst_ip="10.0.0.2", @@ -139,124 +133,168 @@ def test_get_analyzable_pairs(self): connection_key="192.168.1.101:54322->10.0.0.2:80/TCP", ) storage.add_record(record) - pairs = storage.get_analyzable_pairs(min_connections=10, min_duration=60) assert len(pairs) == 1 assert pairs[0].src_ip == "192.168.1.100" +# --------------------------------------------------------------------------- +# ConnectionPair +# --------------------------------------------------------------------------- + + class TestConnectionPair: def test_get_intervals(self): - pair = ConnectionPair( - src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + src_ip="1.2.3.4", dst_ip="5.6.7.8", dst_port=443, protocol="TCP" ) - for i in range(5): pair.timestamps.append(1000.0 + i * 60.0) - + pair.packet_sizes.append(128) intervals = pair.get_intervals() - assert len(intervals) == 4 assert all(i == 60.0 for i in intervals) def test_duration(self): - pair = ConnectionPair( - src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + src_ip="1.2.3.4", dst_ip="5.6.7.8", dst_port=443, protocol="TCP" ) - pair.first_seen = 1000.0 pair.last_seen = 2000.0 - assert 
pair.duration_seconds == 1000.0 - def test_prune_old(self): - + def test_prune_old_index_aligned(self): + """After pruning, timestamps and packet_sizes must have the same length.""" pair = ConnectionPair( - src_ip="192.168.1.100", dst_ip="10.0.0.1", dst_port=443, protocol="TCP" + src_ip="1.2.3.4", dst_ip="5.6.7.8", dst_port=443, protocol="TCP" ) - - base_time = time.time() - 7200 + base = time.time() - 7200 for i in range(10): - pair.timestamps.append(base_time + i * 600) + pair.timestamps.append(base + i * 600) + pair.packet_sizes.append(100 + i) pair.first_seen = pair.timestamps[0] pair.last_seen = pair.timestamps[-1] cutoff = time.time() - 3600 pair.prune_old(cutoff) + assert len(pair.timestamps) == len(pair.packet_sizes) assert pair.connection_count < 10 +# --------------------------------------------------------------------------- +# BenignPattern +# --------------------------------------------------------------------------- + + +class TestBenignPattern: + + def test_matches_by_port_and_protocol(self): + pattern = BenignPattern(dst_port=123, protocol="UDP", label="NTP") + ntp_pair = ConnectionPair( + src_ip="10.0.0.1", dst_ip="1.2.3.4", dst_port=123, protocol="UDP" + ) + tcp_pair = ConnectionPair( + src_ip="10.0.0.1", dst_ip="1.2.3.4", dst_port=123, protocol="TCP" + ) + other_pair = ConnectionPair( + src_ip="10.0.0.1", dst_ip="1.2.3.4", dst_port=443, protocol="UDP" + ) + assert pattern.matches(ntp_pair) is True + assert pattern.matches(tcp_pair) is False + assert pattern.matches(other_pair) is False + + def test_matches_any_protocol_when_unset(self): + pattern = BenignPattern(dst_port=53, label="DNS") + tcp_dns = ConnectionPair( + src_ip="10.0.0.1", dst_ip="8.8.8.8", dst_port=53, protocol="TCP" + ) + udp_dns = ConnectionPair( + src_ip="10.0.0.1", dst_ip="8.8.8.8", dst_port=53, protocol="UDP" + ) + assert pattern.matches(tcp_dns) is True + assert pattern.matches(udp_dns) is True + + +# --------------------------------------------------------------------------- 
+# AnalyzerConfig +# --------------------------------------------------------------------------- + + class TestAnalyzerConfig: def test_default_config(self): - config = AnalyzerConfig() - assert config.analysis_interval == 60 assert config.min_connections == 10 assert config.min_duration == 300.0 assert config.alert_cooldown == 300 + assert config.benign_baseline_enabled is True - def test_custom_config(self): + def test_effective_patterns_use_defaults_when_empty(self): + config = AnalyzerConfig(benign_baseline_enabled=True, benign_patterns=[]) + patterns = config.get_effective_benign_patterns() + assert len(patterns) > 0 # default NTP pattern - config = AnalyzerConfig( - analysis_interval=120, min_connections=20, alert_cooldown=600 - ) + def test_effective_patterns_empty_when_disabled(self): + config = AnalyzerConfig(benign_baseline_enabled=False) + assert config.get_effective_benign_patterns() == [] - assert config.analysis_interval == 120 - assert config.min_connections == 20 - assert config.alert_cooldown == 600 + def test_custom_patterns_override_defaults(self): + custom = [BenignPattern(dst_port=8125, label="StatsD")] + config = AnalyzerConfig(benign_baseline_enabled=True, benign_patterns=custom) + patterns = config.get_effective_benign_patterns() + assert len(patterns) == 1 + assert patterns[0].label == "StatsD" -class TestAnalysisRun: +# --------------------------------------------------------------------------- +# AnalysisRun +# --------------------------------------------------------------------------- - def test_analysis_run_creation(self): - run = AnalysisRun("test-run-1") +class TestAnalysisRun: + def test_creation(self): + run = AnalysisRun("test-run-1") assert run.run_id == "test-run-1" assert run.pairs_analyzed == 0 - assert run.beacons_detected == 0 + assert run.pairs_suppressed == 0 assert run.end_time is None - def test_analysis_run_completion(self): - + def test_completion(self): run = AnalysisRun("test-run-1") run.pairs_analyzed = 100 + 
run.pairs_suppressed = 5 run.beacons_detected = 2 - run.alerts_generated = 2 - run.complete() - assert run.end_time is not None assert run.duration_seconds >= 0 - def test_to_dict(self): - + def test_to_dict_includes_suppressed(self): run = AnalysisRun("test-run-1") - run.pairs_analyzed = 50 + run.pairs_suppressed = 3 run.complete() - d = run.to_dict() + assert "pairs_suppressed" in d + assert d["pairs_suppressed"] == 3 - assert d["run_id"] == "test-run-1" - assert d["pairs_analyzed"] == 50 - assert "start_time" in d - assert "end_time" in d + +# --------------------------------------------------------------------------- +# ConnectionAnalyzer +# --------------------------------------------------------------------------- class TestConnectionAnalyzer: def setup_method(self): - self.storage = ConnectionStorage() self.detector = BeaconDetector(DetectorConfig(min_connections=5)) self.alert_manager = Mock(spec=AlertManager) self.alert_manager.send_alert = Mock() + self.alert_manager.config = Mock() + self.alert_manager.config.enabled = False self.analyzer = ConnectionAnalyzer( storage=self.storage, @@ -267,18 +305,16 @@ def setup_method(self): min_connections=5, min_duration=60, alert_cooldown=60, + benign_baseline_enabled=False, # disabled for unit tests ), ) def test_run_analysis_empty_storage(self): - run = self.analyzer.run_analysis() - assert run.pairs_analyzed == 0 assert run.beacons_detected == 0 def test_run_analysis_with_data(self): - base_time = time.time() for i in range(20): record = ConnectionRecord( @@ -296,94 +332,113 @@ def test_run_analysis_with_data(self): node_id="test-node", connection_key="192.168.1.100:54321->10.0.0.1:443/TCP", ) - record.timestamp_epoch = base_time + i * 60 self.storage.add_record(record) - run = self.analyzer.run_analysis() - assert run.pairs_analyzed >= 1 - def test_alert_cooldown(self): + def test_benign_suppression_skips_ntp(self): + """NTP pairs (UDP/123) must be suppressed and not reach the detector.""" + storage = 
ConnectionStorage() + analyzer = ConnectionAnalyzer( + storage=storage, + detector=self.detector, + alert_manager=self.alert_manager, + config=AnalyzerConfig( + min_connections=5, + min_duration=60, + benign_baseline_enabled=True, + benign_patterns=[ + BenignPattern(dst_port=123, protocol="UDP", label="NTP") + ], + ), + ) + base = time.time() + for i in range(20): + record = ConnectionRecord( + timestamp_ns=int((base + i * 64) * 1e9), + timestamp_utc=f"2024-01-01T00:{i:02d}:00Z", + src_ip="10.0.0.5", + dst_ip="192.0.2.1", + src_port=12345, + dst_port=123, + packet_size=48, + protocol=17, + protocol_name="UDP", + tcp_flags=0, + direction=1, + node_id="test-node", + connection_key="10.0.0.5:12345->192.0.2.1:123/UDP", + ) + record.timestamp_epoch = base + i * 64 + storage.add_record(record) + run = analyzer.run_analysis() + assert run.pairs_suppressed == 1 + assert run.pairs_analyzed == 0 + def test_alert_cooldown(self): pair_key = "192.168.1.100->10.0.0.1:443/TCP" self.analyzer._alert_cooldowns[pair_key] = time.time() - mock_result = Mock(spec=DetectionResult) mock_result.pair_key = pair_key mock_result.combined_score = 0.9 mock_result.confidence = BeaconConfidence.HIGH - self.analyzer._known_beacons[pair_key] = mock_result - - should_alert = self.analyzer._should_alert(mock_result) - assert not should_alert + assert not self.analyzer._should_alert(mock_result) def test_get_known_beacons(self): - mock_result = Mock(spec=DetectionResult) mock_result.pair_key = "test-pair" - self.analyzer._known_beacons["test-pair"] = mock_result + assert len(self.analyzer.get_known_beacons()) == 1 - beacons = self.analyzer.get_known_beacons() - - assert len(beacons) == 1 - - def test_statistics(self): - + def test_statistics_includes_suppression_fields(self): stats = self.analyzer.statistics + assert "total_suppressed" in stats + assert "benign_baseline_enabled" in stats + assert "benign_pattern_count" in stats - assert "running" in stats - assert "analysis_interval" in stats - assert 
"total_runs" in stats - assert "current_known_beacons" in stats + +# --------------------------------------------------------------------------- +# Alert infrastructure +# --------------------------------------------------------------------------- class TestAlertManager: def test_alert_creation(self): - alert = Alert( - alert_id="test-alert-1", + alert_id="test-1", title="Test Alert", - description="This is a test alert", + description="desc", severity=AlertSeverity.HIGH, source="test", ) - - assert alert.alert_id == "test-alert-1" + assert alert.alert_id == "test-1" assert alert.severity == AlertSeverity.HIGH - assert alert.timestamp is not None def test_alert_to_dict(self): - alert = Alert( - alert_id="test-alert-1", - title="Test Alert", - description="This is a test alert", + alert_id="test-1", + title="Test", + description="desc", severity=AlertSeverity.CRITICAL, source="test", details={"key": "value"}, ) - d = alert.to_dict() - - assert d["alert_id"] == "test-alert-1" assert d["severity"] == "critical" assert d["details"] == {"key": "value"} def test_alert_to_syslog(self): alert = Alert( - alert_id="test-alert-1", + alert_id="test-1", title="Beacon Detected", - description="Beaconing detected from 192.168.1.100", + description="Beaconing from 192.168.1.100", severity=AlertSeverity.HIGH, source="beacon_detector", ) - msg = alert.to_syslog_message() - assert "[HIGH]" in msg assert "Beacon Detected" in msg @@ -395,7 +450,6 @@ def test_syslog_priority_mapping(self): assert AlertSeverity.INFO.syslog_priority == logging.INFO assert AlertSeverity.LOW.syslog_priority == logging.WARNING - assert AlertSeverity.MEDIUM.syslog_priority == logging.WARNING assert AlertSeverity.HIGH.syslog_priority == logging.ERROR assert AlertSeverity.CRITICAL.syslog_priority == logging.CRITICAL diff --git a/tests/test_detector.py b/tests/test_detector.py index c9d57bd..d5169f7 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -1,15 +1,15 @@ -""" -Tests for Beacon 
Detection Algorithms +"""Tests for Beacon Detection Algorithms. + Run with: pytest tests/test_detector.py -v """ import random import sys +import time from pathlib import Path import pytest -# Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) from control_plane.detector import ( @@ -23,164 +23,179 @@ from control_plane.storage import ConnectionPair -class TestIntervalStats: +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_pair( + intervals, + src_ip="192.168.1.100", + dst_ip="10.0.0.1", + dst_port=443, + protocol="TCP", + packet_sizes=None, +): + pair = ConnectionPair( + src_ip=src_ip, dst_ip=dst_ip, dst_port=dst_port, protocol=protocol + ) + ts = 1_000_000.0 + pair.timestamps.append(ts) + default_size = 128 + pair.packet_sizes.append(packet_sizes[0] if packet_sizes else default_size) + for i, interval in enumerate(intervals): + ts += interval + pair.timestamps.append(ts) + size = ( + packet_sizes[i + 1] + if packet_sizes and i + 1 < len(packet_sizes) + else default_size + ) + pair.packet_sizes.append(size) + pair.first_seen = pair.timestamps[0] + pair.last_seen = pair.timestamps[-1] + return pair - def test_regular_intervals(self): - detector = BeaconDetector() - intervals = [60.0] * 20 # 20 intervals of 60 seconds each +# --------------------------------------------------------------------------- +# Interval statistics +# --------------------------------------------------------------------------- - stats = detector._calculate_interval_stats(intervals) +class TestIntervalStats: + + def test_regular_intervals(self): + detector = BeaconDetector() + stats = detector._calculate_interval_stats([60.0] * 20) assert stats.count == 20 assert stats.mean == 60.0 assert stats.std_dev == 0.0 assert stats.cv == 0.0 - assert stats.median == 60.0 assert stats.jitter == 0.0 def test_irregular_intervals(self): - detector 
= BeaconDetector() - intervals = [10.0, 50.0, 30.0, 70.0, 40.0] # Highly variable - - stats = detector._calculate_interval_stats(intervals) - + stats = detector._calculate_interval_stats([10.0, 50.0, 30.0, 70.0, 40.0]) assert stats.count == 5 assert stats.mean == 40.0 - assert stats.cv >= 0.5 # High coefficient of variation + assert stats.cv >= 0.5 assert stats.min_interval == 10.0 assert stats.max_interval == 70.0 def test_small_jitter(self): - + random.seed(0) detector = BeaconDetector() - base = 60.0 - jitter_range = 2.0 - intervals = [ - base + random.uniform(-jitter_range, jitter_range) for _ in range(50) - ] - + intervals = [60.0 + random.uniform(-2, 2) for _ in range(50)] stats = detector._calculate_interval_stats(intervals) - - assert stats.count == 50 - assert 58.0 < stats.mean < 62.0 - assert stats.cv < 0.1 # Low CV due to small jitter + assert stats.cv < 0.1 assert stats.jitter < 5.0 +# --------------------------------------------------------------------------- +# CV score +# --------------------------------------------------------------------------- + + class TestCVScore: def test_zero_cv_max_score(self): - - detector = BeaconDetector() - score = detector._calculate_cv_score(0.0) - assert score == 1.0 + assert BeaconDetector()._calculate_cv_score(0.0) == 1.0 def test_high_cv_low_score(self): - - detector = BeaconDetector() - score = detector._calculate_cv_score(1.0) - assert score < 0.1 + assert BeaconDetector()._calculate_cv_score(1.0) < 0.1 def test_threshold_cv_medium_score(self): - - config = DetectorConfig(cv_threshold=0.15) - detector = BeaconDetector(config) + detector = BeaconDetector(DetectorConfig(cv_threshold=0.15)) score = detector._calculate_cv_score(0.15) assert 0.4 < score < 0.6 -class TestPeriodicityAnalysis: +# --------------------------------------------------------------------------- +# Periodicity analysis +# --------------------------------------------------------------------------- - def test_perfectly_periodic(self): - - detector 
= BeaconDetector() - # Create perfectly periodic intervals - intervals = [60.0] * 30 - result = detector._analyze_periodicity(intervals) +class TestPeriodicityAnalysis: - # Perfect periodicity should have low score (no variation to detect) - # But our implementation should handle this edge case + def test_perfectly_periodic(self): + result = BeaconDetector()._analyze_periodicity([60.0] * 30) assert isinstance(result, PeriodicityResult) def test_periodic_with_noise(self): - - detector = BeaconDetector() - # Create periodic intervals with small noise - base = 60.0 - noise = 3.0 - intervals = [base + random.gauss(0, noise) for _ in range(50)] - - result = detector._analyze_periodicity(intervals) - + random.seed(1) + intervals = [60.0 + random.gauss(0, 3) for _ in range(50)] + result = BeaconDetector()._analyze_periodicity(intervals) assert isinstance(result, PeriodicityResult) - # Should have some periodicity detected def test_random_no_periodicity(self): + random.seed(2) + intervals = [random.uniform(10, 300) for _ in range(50)] + result = BeaconDetector()._analyze_periodicity(intervals) + assert result.periodicity_score < 0.5 + def test_sample_penalty_small_n(self): + """With n<20 the score should be lower than with n>=20 for identical signal.""" + random.seed(3) + intervals_large = [60.0 + random.gauss(0, 1) for _ in range(40)] + intervals_small = intervals_large[:10] detector = BeaconDetector() - # Create random intervals - intervals = [random.uniform(10, 300) for _ in range(50)] + score_large = detector._analyze_periodicity(intervals_large).periodicity_score + score_small = detector._analyze_periodicity(intervals_small).periodicity_score + assert score_large >= score_small - result = detector._analyze_periodicity(intervals) - assert isinstance(result, PeriodicityResult) - # Random data should have low periodicity score - assert result.periodicity_score < 0.5 +# --------------------------------------------------------------------------- +# Jitter score +# 
--------------------------------------------------------------------------- class TestJitterScore: - def test_zero_jitter_max_score(self): + def test_zero_jitter(self): + assert BeaconDetector()._calculate_jitter_score(0.0) == 1.0 - detector = BeaconDetector() - score = detector._calculate_jitter_score(0.0) - assert score == 1.0 - - def test_threshold_jitter(self): - - config = DetectorConfig(jitter_threshold=5.0) - detector = BeaconDetector(config) + def test_at_threshold(self): + detector = BeaconDetector(DetectorConfig(jitter_threshold=5.0)) score = detector._calculate_jitter_score(5.0) assert 0.4 < score < 0.6 - def test_high_jitter_low_score(self): + def test_high_jitter(self): + detector = BeaconDetector(DetectorConfig(jitter_threshold=5.0)) + assert detector._calculate_jitter_score(50.0) < 0.2 - config = DetectorConfig(jitter_threshold=5.0) - detector = BeaconDetector(config) - score = detector._calculate_jitter_score(50.0) - assert score < 0.2 +# --------------------------------------------------------------------------- +# Packet-size consistency score +# --------------------------------------------------------------------------- -class TestBeaconDetection: - def create_connection_pair( - self, - intervals: list, - src_ip: str = "192.168.1.100", - dst_ip: str = "10.0.0.1", - dst_port: int = 443, - ): - pair = ConnectionPair( - src_ip=src_ip, dst_ip=dst_ip, dst_port=dst_port, protocol="TCP" - ) +class TestSizeScore: - # Generate timestamps from intervals - timestamp = 1000000.0 - pair.timestamps.append(timestamp) - for interval in intervals: - timestamp += interval - pair.timestamps.append(timestamp) + def test_uniform_sizes_high_score(self): + detector = BeaconDetector() + sizes = [128] * 30 + assert detector._calculate_size_score(sizes) > 0.9 - pair.first_seen = pair.timestamps[0] - pair.last_seen = pair.timestamps[-1] + def test_variable_sizes_low_score(self): + random.seed(4) + detector = BeaconDetector() + sizes = [random.randint(64, 1400) for _ in 
range(30)] + assert detector._calculate_size_score(sizes) < 0.5 - return pair + def test_insufficient_samples_neutral(self): + detector = BeaconDetector() + assert detector._calculate_size_score([128, 130]) == 0.5 - def test_detect_regular_beacon(self): +# --------------------------------------------------------------------------- +# Full beacon detection +# --------------------------------------------------------------------------- + + +class TestBeaconDetection: + + def test_detect_regular_beacon(self): + random.seed(5) config = DetectorConfig( min_connections=10, cv_threshold=0.15, @@ -189,52 +204,89 @@ def test_detect_regular_beacon(self): alert_threshold=0.6, ) detector = BeaconDetector(config) - - # Create regular beacon pattern (60s intervals with small jitter) intervals = [60.0 + random.uniform(-1, 1) for _ in range(30)] - pair = self.create_connection_pair(intervals) - + pair = make_pair(intervals) result = detector.analyze(pair) - assert result is not None - assert result.cv_score > 0.7 # High score for regular intervals + assert result.cv_score > 0.7 assert result.combined_score > 0.5 - # Note: May or may not trigger is_beacon depending on exact jitter def test_detect_random_traffic(self): - config = DetectorConfig( - min_connections=10, cv_threshold=0.15, alert_threshold=0.7 + random.seed(6) + detector = BeaconDetector( + DetectorConfig(min_connections=10, alert_threshold=0.7) ) - detector = BeaconDetector(config) - - # Create random traffic pattern - intervals = [random.uniform(5, 300) for _ in range(30)] - pair = self.create_connection_pair(intervals) - + pair = make_pair([random.uniform(5, 300) for _ in range(30)]) result = detector.analyze(pair) - assert result is not None - assert result.cv_score < 0.5 # Low score for irregular intervals + assert result.cv_score < 0.5 assert not result.is_beacon - def test_insufficient_data(self): + def test_insufficient_data_returns_none(self): + detector = BeaconDetector(DetectorConfig(min_connections=20)) + pair 
= make_pair([60.0] * 9) + assert detector.analyze(pair) is None - config = DetectorConfig(min_connections=20) + def test_result_has_size_score(self): + random.seed(7) + config = DetectorConfig(min_connections=10, alert_threshold=0.5) detector = BeaconDetector(config) + sizes = [128 + random.randint(-2, 2) for _ in range(31)] + pair = make_pair( + [60.0 + random.uniform(-1, 1) for _ in range(30)], packet_sizes=sizes + ) + result = detector.analyze(pair) + assert result is not None + assert 0.0 <= result.size_score <= 1.0 - # Only 10 connections (need 20) - intervals = [60.0] * 9 - pair = self.create_connection_pair(intervals) - + def test_result_has_explanation(self): + random.seed(8) + config = DetectorConfig(min_connections=10, alert_threshold=0.5) + detector = BeaconDetector(config) + pair = make_pair([60.0 + random.uniform(-1, 1) for _ in range(30)]) result = detector.analyze(pair) + assert result is not None + exp = result.explanation + assert "detected_interval_seconds" in exp + assert "contributing_signals" in exp + assert len(exp["contributing_signals"]) == 4 + signal_names = {s["name"] for s in exp["contributing_signals"]} + assert "cv" in signal_names + assert "periodicity" in signal_names + assert "jitter" in signal_names + assert "packet_size_consistency" in signal_names + + def test_time_window_slicing(self): + """Only intervals within time_window seconds of last_seen are scored. + + Setup: 61 events at exactly 60s apart spanning 3600s total. + time_window=600 → cutoff = last_seen - 600 = T + 3000. + Events T+3000 … T+3600 fall inside the window: 11 timestamps → 10 intervals. + sample_count must equal 10 exactly. 
+ """ + T = 1_000_000.0 # fixed epoch base — no wall-clock dependency + config = DetectorConfig( + min_connections=10, time_window=600, alert_threshold=0.5 + ) + detector = BeaconDetector(config) + + pair = ConnectionPair( + src_ip="1.2.3.4", dst_ip="5.6.7.8", dst_port=443, protocol="TCP" + ) + # 61 events: T, T+60, T+120, ..., T+3600 + for i in range(61): + pair.timestamps.append(T + i * 60.0) + pair.packet_sizes.append(128) + pair.first_seen = pair.timestamps[0] + pair.last_seen = pair.timestamps[-1] # T + 3600 - assert result is None + result = detector.analyze(pair) + assert result is not None, "Expected a result for a perfectly periodic pair" + # Cutoff = T+3600 - 600 = T+3000; bisect finds index 50 → 11 timestamps → 10 intervals + assert result.explanation["sample_count"] == 10 def test_confidence_levels(self): - detector = BeaconDetector() - - # Test various score combinations assert ( detector._determine_confidence(0.1, 0.1, 0.1, 0.1) == BeaconConfidence.NONE ) @@ -248,88 +300,81 @@ def test_confidence_levels(self): assert ( detector._determine_confidence(0.8, 0.8, 0.8, 0.8) == BeaconConfidence.HIGH ) + assert ( + detector._determine_confidence(0.9, 0.9, 0.9, 0.9, 0.9) + == BeaconConfidence.CRITICAL + ) - def test_batch_analyze(self): - + def test_batch_analyze_sorted_by_score(self): + random.seed(9) detector = BeaconDetector(DetectorConfig(min_connections=5)) - - # Create multiple pairs - pairs = [] - - # Regular beacon - pairs.append( - self.create_connection_pair( + pairs = [ + make_pair( [60.0 + random.uniform(-1, 1) for _ in range(20)], - src_ip="192.168.1.100", - dst_ip="10.0.0.1", - ) - ) - - # Random traffic - pairs.append( - self.create_connection_pair( + src_ip="1.1.1.1", + dst_ip="2.2.2.1", + ), + make_pair( [random.uniform(5, 300) for _ in range(20)], - src_ip="192.168.1.101", - dst_ip="10.0.0.2", - ) - ) - - # Another beacon with different interval - pairs.append( - self.create_connection_pair( + src_ip="1.1.1.2", + dst_ip="2.2.2.2", + ), + 
make_pair( [120.0 + random.uniform(-2, 2) for _ in range(20)], - src_ip="192.168.1.102", - dst_ip="10.0.0.3", - ) - ) - + src_ip="1.1.1.3", + dst_ip="2.2.2.3", + ), + ] results = detector.batch_analyze(pairs) - assert len(results) == 3 - # Results should be sorted by score descending assert results[0].combined_score >= results[-1].combined_score +# --------------------------------------------------------------------------- +# DetectorConfig +# --------------------------------------------------------------------------- + + class TestDetectorConfig: def test_default_config(self): - config = DetectorConfig() - assert config.min_connections == 10 assert config.cv_threshold == 0.15 - assert config.periodicity_threshold == 0.7 - assert config.jitter_threshold == 5.0 assert config.alert_threshold == 0.7 - - def test_custom_config(self): - - config = DetectorConfig( - min_connections=20, cv_threshold=0.1, alert_threshold=0.8 + # New defaults for 4-signal weights + assert ( + abs( + config.cv_weight + + config.periodicity_weight + + config.jitter_weight + + config.size_weight + - 1.0 + ) + < 0.01 ) - assert config.min_connections == 20 - assert config.cv_threshold == 0.1 - assert config.alert_threshold == 0.8 - - def test_weight_validation(self): - - # Weights should sum to 1.0 + def test_weights_sum_to_one(self): config = DetectorConfig( - cv_weight=0.5, periodicity_weight=0.3, jitter_weight=0.2 + cv_weight=0.3, periodicity_weight=0.3, jitter_weight=0.2, size_weight=0.2 + ) + total = ( + config.cv_weight + + config.periodicity_weight + + config.jitter_weight + + config.size_weight ) - detector = BeaconDetector(config) - - # Should not raise any warnings - total = config.cv_weight + config.periodicity_weight + config.jitter_weight assert abs(total - 1.0) < 0.01 -class TestDetectionResult: +# --------------------------------------------------------------------------- +# DetectionResult serialisation +# 
--------------------------------------------------------------------------- + - def test_to_dict(self): +class TestDetectionResult: - # Create a mock result + def test_to_dict_includes_new_fields(self): interval_stats = IntervalStats( count=20, mean=60.0, @@ -340,14 +385,12 @@ def test_to_dict(self): max_interval=62.0, jitter=2.0, ) - periodicity_result = PeriodicityResult( is_periodic=True, dominant_period=60.0, periodicity_score=0.8, frequency_peaks=[(0.0167, 0.8)], ) - result = DetectionResult( pair_key="192.168.1.100->10.0.0.1:443/TCP", src_ip="192.168.1.100", @@ -357,7 +400,8 @@ def test_to_dict(self): cv_score=0.9, periodicity_score=0.8, jitter_score=0.85, - combined_score=0.85, + size_score=0.75, + combined_score=0.83, is_beacon=True, confidence=BeaconConfidence.HIGH, interval_stats=interval_stats, @@ -366,15 +410,14 @@ def test_to_dict(self): duration_seconds=1200.0, first_seen="2024-01-01T00:00:00Z", last_seen="2024-01-01T00:20:00Z", + explanation={"detected_interval_seconds": 60.0}, ) - d = result.to_dict() - - assert d["pair_key"] == "192.168.1.100->10.0.0.1:443/TCP" - assert d["is_beacon"] == True + assert d["is_beacon"] is True assert d["confidence"] == "high" - assert "interval_stats" in d - assert "periodicity_result" in d + assert "size_score" in d + assert "explanation" in d + assert d["explanation"]["detected_interval_seconds"] == 60.0 if __name__ == "__main__": From be3b58de7dfdd83aac1ce02bf600dced7fe5cbdd Mon Sep 17 00:00:00 2001 From: litemars Date: Mon, 27 Apr 2026 21:06:04 +0200 Subject: [PATCH 2/2] fixing bugs --- .github/workflows/lint-and-test.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint-and-test.yml b/.github/workflows/lint-and-test.yml index 430c710..8d62c7a 100644 --- a/.github/workflows/lint-and-test.yml +++ b/.github/workflows/lint-and-test.yml @@ -2,16 +2,16 @@ name: Code Quality & Linting on: push: - branches: [ main, develop ] + branches: [ main ] pull_request: - branches: [ 
main, develop ] + branches: [ main ] jobs: lint: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10'] + python-version: ['3.9'] steps: - uses: actions/checkout@v4 @@ -35,7 +35,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10'] + python-version: ['3.9'] steps: - uses: actions/checkout@v4