Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions cyberai/agents/recon/nmap_tool.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,90 @@
import shlex
import subprocess
from typing import Dict, Any
from typing import Any, Dict, List

from cyberai.core.security.input_sanitizer import sanitize_target
from cyberai.core.cache import FileCache
from pathlib import Path

# Whitelist of nmap flags the toolkit is allowed to pass through.
# Anything outside this set is rejected — prevents abuse like
# -oN /etc/cron.d/x, --script=<malicious>, or arbitrary file writes.
# NOTE(review): -oX takes an output destination as its value, so
# whitelisting the flag name alone does not rule out file writes; the
# value itself has to be constrained wherever flags are validated.
ALLOWED_FLAGS = {
"-sV", "-sC", "-sT", "-sS", "-sU", "-sn",
"-T0", "-T1", "-T2", "-T3", "-T4", "-T5",
"-Pn", "-A", "-O",
"-p", "--top-ports", "-oX",
}

# Flags that consume the next token as a value
# (-p: port spec, --top-ports: port count, -oX: output destination).
# Every entry here must also appear in ALLOWED_FLAGS to be usable.
_VALUE_FLAGS = {"-p", "--top-ports", "-oX"}

# Dedicated 1-hour cache for nmap results, keyed by target+flags.
# Avoids re-scanning the same target repeatedly within a session.
# The cache lives under the current user's home directory
# (~/.cyberai/nmap-cache).
NMAP_CACHE_TTL = 3600  # seconds (1 hour); passed to FileCache as ttl
_nmap_cache = FileCache(
cache_dir=Path.home() / ".cyberai" / "nmap-cache",
ttl=NMAP_CACHE_TTL,
)


def _cache_key(target: str, flags: str) -> str:
return f"nmap:{target}:{flags}"


def validate_flags(flags: str) -> List[str]:
    """Parse a flag string via shlex and reject anything not whitelisted.

    Every flag token must be in ALLOWED_FLAGS. A flag listed in
    _VALUE_FLAGS must be followed by a value, and the value is validated
    as well, because an allowed flag with an unchecked value is still an
    injection vector (e.g. ``-oX /etc/cron.d/x`` writes a file):

    * ``-p``          comma-separated ports, ranges or service names,
                      each optionally prefixed with ``T:``/``U:``/``S:``/``P:``
    * ``--top-ports`` a decimal count
    * ``-oX``         only ``-`` (stdout); file paths are refused

    Args:
        flags: Whitespace-separated nmap options, shell-quoted if needed.

    Returns:
        The validated tokens, in their original order.

    Raises:
        ValueError: on the first unknown flag, on a value flag with no
            value, on a value that fails its pattern, or when shlex
            cannot parse the string (e.g. an unclosed quote).
    """
    # Function-scope import: keeps this block independent of the file's
    # top-level import list.
    import re

    port_item = (
        r"(?:[TUSP]:)?"
        r"(?:[0-9]+(?:-[0-9]*)?|-[0-9]*|[A-Za-z][A-Za-z0-9_*?-]*)"
    )
    value_patterns = {
        "-p": rf"{port_item}(?:,{port_item})*",
        "--top-ports": r"[0-9]+",
        "-oX": r"-",
    }

    tokens = iter(shlex.split(flags))
    safe: List[str] = []
    for tok in tokens:
        if tok not in ALLOWED_FLAGS:
            raise ValueError(f"Rejected nmap flag: {tok!r}")
        safe.append(tok)
        if tok not in _VALUE_FLAGS:
            continue

        # The flag needs an argument: pull the next token and vet it.
        value = next(tokens, None)
        if value is None:
            raise ValueError(f"Missing value for nmap flag: {tok!r}")
        pattern = value_patterns.get(tok)
        # Fail closed: a value flag without a known pattern accepts nothing.
        if pattern is None or re.fullmatch(pattern, value) is None:
            raise ValueError(
                f"Rejected value for nmap flag {tok!r}: {value!r}"
            )
        safe.append(value)
    return safe

def run_nmap(target: str, flags: str = "-sV -T4 --top-ports 1000") -> Dict[str, Any]:
"""
Run nmap against target, return parsed results.
Requires nmap installed on system.
"""
cmd = ["nmap", "-oX", "-"] + flags.split() + [target]
safe_target = sanitize_target(target)
try:
safe_flags = validate_flags(flags)
except ValueError as exc:
return {"target": target, "error": f"unsafe nmap flags: {exc}"}

cache_key = _cache_key(safe_target, flags)
cached = _nmap_cache.get(cache_key)
if cached is not None:
cached["cached"] = True
return cached

cmd = ["nmap", "-oX", "-"] + safe_flags + [safe_target]
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=120
)
return {
parsed = {
"target": target,
"raw": result.stdout,
"stderr": result.stderr,
"returncode": result.returncode,
"ports": _parse_ports(result.stdout),
"cached": False,
}
if result.returncode == 0:
_nmap_cache.set(cache_key, parsed)
return parsed
except subprocess.TimeoutExpired:
return {"target": target, "error": "nmap timeout after 120s"}
except FileNotFoundError:
Expand Down
82 changes: 0 additions & 82 deletions cyberai/agents/recon/nmap_wrapper.py

This file was deleted.

92 changes: 92 additions & 0 deletions tests/unit/test_nmap_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Unit tests for nmap flag whitelist and result caching (day 10)."""
from __future__ import annotations

from unittest.mock import patch, MagicMock

import pytest

from cyberai.agents.recon import nmap_tool
from cyberai.agents.recon.nmap_tool import validate_flags, run_nmap


# ── flag whitelist ────────────────────────────────────────────────────

def test_allowed_flags_pass():
    """A string made only of whitelisted flags comes back tokenised, in order."""
    tokens = validate_flags("-sV -T4 --top-ports 1000")
    assert tokens == ["-sV", "-T4", "--top-ports", "1000"]


def test_value_flag_keeps_its_argument():
    """The token after a value flag (-p) stays directly behind that flag."""
    expected = ["-p", "80,443", "-sV"]
    assert validate_flags("-p 80,443 -sV") == expected


@pytest.mark.parametrize(
    "bad",
    [
        "-sV; rm -rf /",
        "-oN /etc/cron.d/x",
        "--script=http-vuln",
        "-sV && curl evil.com",
        "--unsafe-flag",
    ],
)
def test_unknown_flags_rejected(bad):
    """Any flag string containing a non-whitelisted token raises ValueError."""
    with pytest.raises(ValueError):
        validate_flags(bad)


def test_run_nmap_rejects_unsafe_flags_gracefully():
    """run_nmap reports unsafe flags through an error dict instead of raising."""
    outcome = run_nmap("scanme.test", flags="-sV; rm -rf /")
    assert "error" in outcome
    message = outcome["error"].lower()
    assert "unsafe" in message


# ── caching ───────────────────────────────────────────────────────────

@pytest.fixture(autouse=True)
def _clean_cache(tmp_path, monkeypatch):
    """Give every test its own empty nmap cache in a temporary directory.

    The module-level cache lives under the user's home directory, so
    clearing it from the test suite would wipe real cached scans and
    leave test entries behind. Swapping in a FileCache rooted at
    ``tmp_path`` keeps tests isolated from the developer's data and from
    each other; monkeypatch restores the original cache afterwards.
    """
    isolated = nmap_tool.FileCache(
        cache_dir=tmp_path / "nmap-cache",
        ttl=nmap_tool.NMAP_CACHE_TTL,
    )
    # run_nmap looks up _nmap_cache at call time, so patching the module
    # attribute is enough to redirect reads and writes.
    monkeypatch.setattr(nmap_tool, "_nmap_cache", isolated)
    yield


def _fake_proc(stdout: str = "", rc: int = 0) -> MagicMock:
proc = MagicMock()
proc.stdout = stdout
proc.stderr = ""
proc.returncode = rc
return proc


def test_cache_miss_then_hit():
    """An identical second scan is answered from the cache, not by nmap."""
    proc = _fake_proc(stdout="<nmaprun></nmaprun>", rc=0)
    with patch.object(nmap_tool.subprocess, "run", return_value=proc) as run_mock:
        first, second = [
            run_nmap("scanme.test", flags="-sV") for _ in range(2)
        ]

    assert first["cached"] is False
    assert second["cached"] is True
    # Only the first call reached subprocess.run; the second was a cache hit.
    assert run_mock.call_count == 1


def test_failed_scan_not_cached():
    """A scan that exits non-zero is never stored in the cache."""
    proc = _fake_proc(stdout="", rc=1)
    with patch.object(nmap_tool.subprocess, "run", return_value=proc) as run_mock:
        for _ in range(2):
            run_nmap("scanme.test", flags="-sV")

    # Nothing was cached, so each call had to invoke subprocess.run.
    assert run_mock.call_count == 2


def test_different_flags_different_cache():
    """Scans of one target with different flag strings use separate cache entries."""
    proc = _fake_proc(stdout="<nmaprun></nmaprun>", rc=0)
    with patch.object(nmap_tool.subprocess, "run", return_value=proc) as run_mock:
        for flag_string in ("-sV", "-sV -Pn"):
            run_nmap("scanme.test", flags=flag_string)

    # Distinct flag strings map to distinct keys, so both scans really ran.
    assert run_mock.call_count == 2
Loading