Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
* `--parallel` flag to run independent modules concurrently for speed
* Autonomous `--agent` mode with `--watch` and `--diff-only` for stateful
differential scanning
* Distributed mode via `--distributed` to fan out scans across remote runners

---

Expand Down Expand Up @@ -108,6 +109,21 @@ trivial.

---

## Distributed scanning

Launch `pentest-runner` on remote nodes (or containers) then run the main CLI
with `--distributed`:

```bash
pentest-runner --port 9000 &
pentest-toolkit targets.txt --distributed --runners http://localhost:9000
```

Targets are automatically chunked across runners with health-checks and local
fallback if a node goes offline.

---

## Output example

```json
Expand Down
92 changes: 92 additions & 0 deletions distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

"""Distributed orchestrator for Pentest-Toolkit."""

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List

import requests

from modules.base import Finding
from utils.logger import get_logger

logger = get_logger()


class DistributedOrchestrator:
"""Dispatch targets to remote runners via HTTP."""

def __init__(
self,
runners: Iterable[str],
*,
timeout: int = 600,
runner: Callable[[str, List[str], bool], List[Finding]] | None = None,
) -> None:
self.runners = list(runners)
self.timeout = timeout
self.runner = runner

def _health(self) -> List[str]:
available: List[str] = []
for url in self.runners:
try:
resp = requests.get(f"{url.rstrip('/')}/health", timeout=5)
if resp.status_code == 200:
available.append(url.rstrip('/'))
except requests.RequestException:
logger.warning("\u26a0\ufe0f Runner %s unavailable", url)
return available

def _run_local(self, target: str, tools: List[str], pipeline_mode: bool) -> List[Finding]:
logger.info("\u23f1\ufe0f Local fallback for %s", target)
run = self.runner
if run is None:
from main import pipeline as _pipeline

def run(t: str, tl: List[str], pm: bool) -> List[Finding]:
return _pipeline(t, tl, use_pipeline=pm, show_summary=False)

return run(target, tools, pipeline_mode)

def dispatch(
self,
targets: List[str],
tools: List[str],
*,
pipeline_mode: bool = False,
) -> List[Finding]:
"""Send targets to available runners and aggregate results."""
available = self._health()
if not available:
logger.warning("No runners available, running locally")
results: List[Finding] = []
for t in targets:
results.extend(self._run_local(t, tools, pipeline_mode))
return results

results: List[Finding] = []

def send(url: str, tgt: str) -> List[Finding]:
try:
resp = requests.post(
f"{url}/scan",
json={"target": tgt, "tools": tools, "pipeline": pipeline_mode},
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json().get("findings", [])
return [Finding(tool=d.pop("tool"), data=d) for d in data]
except Exception as exc: # noqa: BLE001
logger.error("Runner %s failed for %s: %s", url, tgt, exc)
return self._run_local(tgt, tools, pipeline_mode)

with ThreadPoolExecutor() as pool:
futs = []
for idx, tgt in enumerate(targets):
url = available[idx % len(available)]
futs.append(pool.submit(send, url, tgt))
for fut in futs:
results.extend(fut.result())

return results
65 changes: 56 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from utils.notifiers import Notifier
from utils.deps import DependencyError, check_dependencies
from utils.plugins import load_plugins
from distributed import DistributedOrchestrator

load_plugins()

Expand Down Expand Up @@ -349,8 +350,25 @@ def cli() -> None:
action="store_true",
help="Only output new findings compared to previous scan",
)
parser.add_argument(
"--distributed",
action="store_true",
help="Enable distributed mode with remote runners",
)
parser.add_argument(
"--runners",
help="Comma-separated list of runner URLs (default: $PENTEST_TOOLKIT_RUNNERS)",
)
args = parser.parse_args()

runners: List[str] = []
if args.runners:
runners = [r.strip() for r in args.runners.split(',') if r.strip()]
else:
env = os.environ.get('PENTEST_TOOLKIT_RUNNERS')
if env:
runners = [r.strip() for r in env.split(',') if r.strip()]

notify_names: List[str] = []
if args.notify:
notify_names.extend(args.notify)
Expand Down Expand Up @@ -476,20 +494,49 @@ def run_agent() -> None:
logger.info("⏳ Sleeping %s seconds", args.interval)
time.sleep(args.interval)

def run_distributed() -> None:
orchestrator = DistributedOrchestrator(runners)
res = orchestrator.dispatch(targets, args.tools, pipeline_mode=args.pipeline)
prefix = "master"
write_json(res, args.out, prefix=prefix)
if args.report == "html":
write_html(res, args.out, prefix=prefix)
elif args.report == "pdf":
write_pdf(res, args.out, prefix=prefix)
elif args.report == "markdown":
write_markdown(res, args.out, prefix=prefix)
elif args.report == "summary":
write_markdown(res, args.out, prefix=prefix, summary_only=True)
for notifier in notifiers:
try:
notifier.send(res)
except Exception as exc: # noqa: BLE001
logger.error("❌ Notifier error: %s", exc)
if args.strict_notify:
raise SystemExit(1)
logger.info(
"✅ Completed %s targets – %s findings collected",
len(targets),
len(res),
)

if args.agent:
run_agent()
else:
if args.parallel and len(targets) > 1:
with ThreadPoolExecutor() as pool:
for res in pool.map(run_one, targets):
all_findings.extend(res)
if args.distributed:
run_distributed()
else:
for t in targets:
all_findings.extend(run_one(t))
if args.parallel and len(targets) > 1:
with ThreadPoolExecutor() as pool:
for res in pool.map(run_one, targets):
all_findings.extend(res)
else:
for t in targets:
all_findings.extend(run_one(t))

logger.info(
"✅ Completed %s targets – %s findings collected", len(targets), len(all_findings)
)
logger.info(
"✅ Completed %s targets – %s findings collected", len(targets), len(all_findings)
)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ full = []

[project.scripts]
pentest-toolkit = "main:cli"
pentest-runner = "runner:run_server"

[project.urls]
Homepage = "https://github.com/psychevus/pentest-toolkit"

[tool.setuptools]
packages = ["modules", "utils"]
py-modules = ["main", "lambda_function"]
py-modules = ["main", "lambda_function", "distributed", "runner"]
57 changes: 57 additions & 0 deletions runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
from __future__ import annotations

"""HTTP runner for remote execution."""

import json
from http.server import BaseHTTPRequestHandler, HTTPServer

from utils.logger import get_logger
from utils.plugins import load_plugins
from modules.base import Module
from main import pipeline

load_plugins()
logger = get_logger()


class Handler(BaseHTTPRequestHandler):
def _send(self, status: int, body: bytes = b"{}", *, content: str = "application/json") -> None:
self.send_response(status)
self.send_header("Content-Type", content)
self.send_header("Content-Length", str(len(body)))
self.end_headers()
self.wfile.write(body)

def do_GET(self) -> None: # noqa: D401
if self.path == "/health":
self._send(200, b'{"status":"ok"}')
else:
self._send(404, b"{}")

def do_POST(self) -> None: # noqa: D401
if self.path != "/scan":
self._send(404, b"{}")
return
length = int(self.headers.get("Content-Length", 0))
payload = json.loads(self.rfile.read(length) or b"{}")
target = payload.get("target")
tools = payload.get("tools", list(Module.registry.keys()))
pipe = payload.get("pipeline", False)
try:
res = pipeline(target, tools, use_pipeline=pipe, show_summary=False)
data = json.dumps({"findings": [f.asdict() for f in res]}).encode()
self._send(200, data)
except Exception as exc: # noqa: BLE001
logger.error("Runner error: %s", exc)
self._send(500, json.dumps({"error": str(exc)}).encode())


def run_server(host: str = "0.0.0.0", port: int = 8000) -> None:
server = HTTPServer((host, port), Handler)
logger.info("Runner listening on %s:%s", host, port)
server.serve_forever()


if __name__ == "__main__": # pragma: no cover
run_server()
65 changes: 65 additions & 0 deletions tests/test_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
import sys

sys.path.append(str(Path(__file__).resolve().parents[1]))

from distributed import DistributedOrchestrator
from modules.base import Finding
import main


class DummyHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: D401
if self.path == "/health":
self.send_response(200)
self.end_headers()
self.wfile.write(b"ok")
else:
self.send_response(404)
self.end_headers()

def do_POST(self): # noqa: D401
if self.path != "/scan":
self.send_response(404)
self.end_headers()
return
length = int(self.headers.get("Content-Length", 0))
payload = json.loads(self.rfile.read(length) or b"{}")
out = json.dumps({"findings": [{"tool": "dummy", "target": payload["target"]}]}).encode()
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(out)))
self.end_headers()
self.wfile.write(out)


def _start_server():
srv = HTTPServer(("localhost", 0), DummyHandler)
thread = threading.Thread(target=srv.serve_forever)
thread.daemon = True
thread.start()
return srv, thread


def test_dispatched_results(monkeypatch):
monkeypatch.setattr(main, "pipeline", lambda *a, **k: [Finding(tool="local", data={})])
srv, th = _start_server()
orch = DistributedOrchestrator([f"http://localhost:{srv.server_address[1]}"])
res = orch.dispatch(["a.com", "b.com"], ["dummy"])
srv.shutdown()
th.join()
assert len(res) == 2
assert any(f.tool == "dummy" for f in res)


def test_fallback(monkeypatch):
monkeypatch.setattr(main, "pipeline", lambda *a, **k: [Finding(tool="local", data={"t": a[0]})])
orch = DistributedOrchestrator(["http://localhost:12345"])
res = orch.dispatch(["x.com"], ["dummy"])
assert res[0].tool == "local"