diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b2dc0b6..8072305 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,9 @@ jobs: - name: Install dependencies run: | pip install --upgrade pip - pip install -e ".[dev]" + # install both dev and api extras so that the full test + # matrix (unit + integration + API) can execute. + pip install -e ".[api,dev]" - name: Ruff run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b211ae7..ec53d3f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ fail_fast: true repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.15.12 hooks: - id: ruff args: [--fix] @@ -40,7 +40,6 @@ repos: - pyright==1.1.405 - pydantic - python-dotenv - - flask - typer - rich - httpx @@ -48,4 +47,8 @@ repos: - beautifulsoup4 - pyyaml - cryptography + - fastapi + - uvicorn + - python-multipart + - starlette - orjson diff --git a/README.md b/README.md index 214c3e3..e66dd30 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Includes safe, controlled penetration testing capabilities for XSS detection, SQ ### Professional Integration -- **REST API (In Development)**: Planned full-featured API for seamless integration with existing security workflows and tools (currently not publicly available) +- **REST API (Beta)**: A powerful REST interface for seamless integration with existing security workflows. Built with FastAPI, it includes interactive documentation and allows for remote orchestration of scans. 
- **Command Line Interface**: Powerful CLI with rich formatting and automation support for security professionals - **Flexible Export Options**: Generate comprehensive reports in JSON, CSV, and structured formats - **Configuration Management**: Centralized, persistent configuration system for enterprise deployment @@ -82,12 +82,27 @@ Enhance your methodology with systematic reconnaissance tools that uncover hidde ### Installation ```bash -# Install from source +# 1. clone the repository git clone https://github.com/HC-ONLINE/CiberWebScan.git cd CiberWebScan + +# 2. create a virtual environment (optional but recommended) +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate + +# 3. install the package and dependencies + +# CLI only pip install -e . -# Verify installation +# CLI + API +pip install -e ".[api]" + +# Full Developer Setup +# if you are running the developer tests you will also want the dev dependencies, which include testing frameworks and tools +pip install -e ".[api,dev]" + +# verify that the tool is available ciberwebscan --help ``` @@ -123,20 +138,37 @@ ciberwebscan attack --url https://testsite.example.com --xss ciberwebscan attack --url https://testsite.example.com --enumeration ``` -### REST API Integration (In Development) +### REST API Integration + +> **API Preview**: The REST interface is functional but considered "unstable." Endpoint signatures and JSON schemas may change as we refine the 2.0.0 specification. + +To start the server: + +```bash +ciberwebscan api run +``` + +### Interactive Documentation + +Once the server is running, you can explore and test all available endpoints through the built-in interactive UI: + +- **Swagger UI**: http://localhost:8000/docs +- **ReDoc**: http://localhost:8000/redoc -> **Note**: The REST API is currently under development and not publicly available. This section shows planned usage examples for future releases.
+### Programmatic Access Example + +You can also integrate CiberWebScan into your own scripts using the requests library: ```python import requests -# Security analysis via API (planned) -response = requests.post("http://localhost:5000/api/analyze", json={ - "url": "https://target.example.com", - "checks": ["fingerprint", "ssl", "headers", "cve"] +# Security analysis via REST API +response = requests.post("http://localhost:8000/api/analyze", json={ + "url": "https://target.example.com" }) -analysis_results = response.json() +results = response.json() +# Returns: {"success": true, "data": {"technologies": [...], "vulnerabilities": [...]}, ...} ``` --- @@ -192,6 +224,18 @@ ciberwebscan --help --- +## Community & Support + +Contributions are what make the open-source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. + +- **Found a bug?** Open an [Issue](https://github.com/HC-ONLINE/CiberWebScan/issues) describing the problem. +- **Want a new feature?** Feel free to submit a [Pull Request](https://github.com/HC-ONLINE/CiberWebScan/pulls) with your proposal. +- **Enjoying the tool?** Give us a ⭐ on GitHub to show your support! + +Before contributing, please read our [Contributing Guide](docs/CONTRIBUTING.md) to maintain code quality and consistency. + +--- + ## Documentation - **[Installation Guide](docs/INSTALLATION.md)** - Complete setup and installation instructions diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..06ea2ac --- /dev/null +++ b/docs/API.md @@ -0,0 +1,811 @@ +# API Documentation + +Complete REST API reference for CiberWebScan (Beta).
+ +## Quick Start + +### Starting the API + +```bash +ciberwebscan api run +# Server runs on http://localhost:8000 +# API docs: http://localhost:8000/docs +# ReDoc docs: http://localhost:8000/redoc +``` + +### Authentication + +The API requires an API key via the `X-API-Key` header: + +```bash +curl -X POST "http://localhost:8000/api/analyze" \ + -H "X-API-Key: your-api-key-here" \ + -H "Content-Type: application/json" \ + -d '{"url": "https://example.com"}' +``` + +### Response Format + +All responses follow a consistent structure: + +```json +{ + "success": true, + "data": { ... }, + "error": null, + "timestamp": "2026-01-01T00:00:00Z", + "download_token": "token-uuid", + "download_url": "/api/download/token-uuid" +} +``` + +## Endpoints + +### Health & Status + +#### GET /health + +Basic health check endpoint (no authentication required). + +**Response:** + +```json +{ + "status": "healthy", + "version": "2.0.0", + "message": "CiberWebScan API is running", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +#### GET /health/ready + +Readiness check for container orchestration (no authentication required). + +**Response:** + +```json +{ + "status": "ready", + "version": "2.0.0", + "message": "CiberWebScan API is ready to accept requests", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +### Security Analysis + +#### POST /api/analyze + +Perform comprehensive security analysis on a URL. 
+ +**Includes:** + +- SSL/TLS certificate analysis +- Technology fingerprinting +- Security headers evaluation +- CVE lookup and vulnerability intelligence + +**Request:** + +```json +{ + "url": "https://example.com", + "ssl": true, + "fingerprint": true, + "analyze_headers": true, + "cve": true, + "ssl_verify": true, + "timeout": 30.0, + "ssl_timeout": 10.0, + "deep_scan": false, + "cve_sources": ["nvd"], + "cve_limit": 100, + "cve_severity": null, + "headers": {}, + "cookies": {} +} +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ----------------- | ------- | -------- | ----------------------------------------------------- | +| `url` | string | required | Target URL | +| `ssl` | boolean | true | Perform SSL/TLS analysis | +| `fingerprint` | boolean | true | Detect technologies | +| `analyze_headers` | boolean | true | Analyze security headers | +| `cve` | boolean | true | Lookup CVEs | +| `ssl_verify` | boolean | true | Verify SSL certificates | +| `timeout` | float | 30.0 | Request timeout (1.0-300.0) | +| `ssl_timeout` | float | 10.0 | SSL analysis timeout (1.0-120.0) | +| `deep_scan` | boolean | false | Deeper fingerprinting | +| `cve_sources` | array | [] | CVE sources: nvd, vulners, circl | +| `cve_limit` | integer | 100 | Max CVEs to retrieve (1-1000) | +| `cve_severity` | string | null | Filter by severity: critical, high, medium, low, info | +| `headers` | object | {} | Custom HTTP headers | +| `cookies` | object | {} | Cookies | +| `proxy` | string | null | HTTP proxy URL | +| `user_agent` | string | null | Custom User-Agent | +| `check_robots` | boolean | false | Respect robots.txt | +| `enrich_exploits` | boolean | false | Enrich CVEs with exploit info | +| `export` | string | null | Export file path | +| `export_format` | string | json | json, jsonl, or csv | + +**Response:** + +```json +{ + "success": true, + "data": { + "url": "https://example.com", + "timestamp": "2026-01-01T00:00:00Z", + "ssl_analysis": { + 
"protocol_version": "TLSv1.3", + "certificate": { + "subject": "CN=example.com", + "issuer": "...", + "valid_from": "2026-01-01T00:00:00Z", + "valid_until": "2026-01-01T00:00:00Z" + }, + "ciphers": [...], + "vulnerabilities": [] + }, + "fingerprint": { + "technologies": [ + { + "name": "nginx", + "version": "1.24.0", + "category": "Web Servers" + } + ] + }, + "headers_analysis": { + "headers": {...}, + "security_issues": [...] + }, + "cve_results": [ + { + "id": "CVE-2026-1234", + "severity": "HIGH", + "description": "...", + "cvss_score": 7.5 + } + ] + }, + "download_token": "uuid-token", + "download_url": "/api/download/uuid-token", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +### Scraping + +#### POST /api/scrape + +Scrape a single URL with flexible extraction options. + +**Supports:** + +- Static scraping (BeautifulSoup) +- Dynamic scraping (Playwright for JavaScript-rendered content) +- CSS selector extraction +- Structured data extraction with schemas +- Pagination + +**Request:** + +```json +{ + "url": "https://example.com", + "dynamic": false, + "wait_for": null, + "selector": ".product", + "attributes": ["id", "title", "price"], + "extraction_schema": null, + "timeout": 30.0, + "headers": {}, + "cookies": {} +} +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| --------------------- | ------- | -------- | ------------------------------------------- | +| `url` | string | required | Target URL | +| `dynamic` | boolean | false | Use Playwright for JS rendering | +| `wait_for` | string | null | CSS selector to wait for (dynamic mode) | +| `selector` | string | null | CSS selector for focused extraction | +| `attributes` | array | [] | Attributes to extract from matched elements | +| `extraction_schema` | object | null | Structured extraction schema | +| `pagination_selector` | string | null | Selector for pagination links | +| `pagination_limit` | integer | 1 | Max pages to traverse (1-1000) | +| `timeout` | float | 30.0 | Request 
timeout in seconds (1.0-120.0) | +| `headers` | object | {} | Custom HTTP headers | +| `cookies` | object | {} | Cookies to include | +| `proxy` | string | null | HTTP/HTTPS proxy URL | +| `user_agent` | string | null | Custom User-Agent | +| `check_robots` | boolean | true | Respect robots.txt | +| `export` | string | null | Export file path | +| `export_format` | string | json | json, jsonl, or csv | + +**Response:** + +```json +{ + "success": true, + "data": { + "url": "https://example.com", + "timestamp": "2026-01-01T00:00:00Z", + "title": "Page Title", + "links": [...], + "images": [...], + "forms": [...], + "scripts": [...], + "extracted_data": [...] + }, + "download_token": "uuid-token", + "download_url": "/api/download/uuid-token", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +#### POST /api/scrape/batch + +Scrape multiple URLs in batch. + +**Request:** + +```json +{ + "urls": ["https://example.com/page1", "https://example.com/page2"], + "dynamic": false, + "selector": ".item", + "timeout": 30.0, + "headers": {}, + "cookies": {} +} +``` + +**Response:** + +```json +{ + "success": true, + "data": { + "job_id": "uuid-string", + "results": [ + { + "url": "https://example.com/page1", + "success": true, + "extracted_data": [...] + } + ], + "failed_urls": [ + { + "url": "https://example.com/page3", + "error": "Timeout" + } + ], + "total_success": 2, + "total_failed": 1 + }, + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +### Attack Simulation + +#### POST /api/attack + +Perform controlled security attack simulations. + +**IMPORTANT**: Only test systems you own or have explicit written authorization to test. Unauthorized security testing is illegal. 
+ +**Supports:** + +- XSS (Cross-Site Scripting) detection +- SQL Injection testing +- Path Traversal vulnerability testing +- Directory Enumeration +- Custom payloads and wordlists + +**Request:** + +```json +{ + "url": "https://testsite.example.com", + "xss": true, + "sqli": true, + "traversal": true, + "enumeration": true, + "all_attacks": false, + "intensity": "medium", + "max_payloads": 50, + "timeout": 10.0, + "delay_between_requests": 0.1, + "concurrent_requests": 1, + "skip_dangerous_payloads": true, + "user_consent": true +} +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ------------------------- | ----------- | ------------ | -------------------------------- | +| `url` | string | required | Target URL | +| `xss` | boolean | null | Test XSS vulnerabilities | +| `sqli` | boolean | null | Test SQL injection | +| `traversal` | boolean | null | Test path traversal | +| `enumeration` | boolean | null | Directory enumeration | +| `all_attacks` | boolean | false | Enable all attack types | +| `intensity` | string | medium | low, medium, or high | +| `max_payloads` | integer | null | Max payloads per attack (1-1000) | +| `custom_payloads_file` | string | null | Path to custom payloads | +| `custom_wordlist` | string | null | Custom wordlist for enumeration | +| `timeout` | float | 10.0 | Request timeout (1.0-300.0) | +| `delay_between_requests` | float | 0.1 | Delay between requests (seconds) | +| `concurrent_requests` | integer | 1 | Concurrent requests (1-10) | +| `skip_dangerous_payloads` | boolean | true | Skip dangerous payloads | +| `scope_urls` | array | [] | URLs to scope attack to | +| `export` | string | null | Export file path | +| `export_format` | string | json | json, jsonl, or csv | +| `headers` | object | {} | Custom HTTP headers | +| `cookies` | object | {} | Cookies | +| `proxy` | string | null | HTTP proxy URL | +| `user_agent` | string | null | Custom User-Agent | +| `verbose` | boolean | false | Verbose output | +| 
**`user_consent`** | **boolean** | **required** | **Must be true to proceed** | + +**Response:** + +```json +{ + "success": true, + "data": { + "url": "https://testsite.example.com", + "timestamp": "2026-01-01T00:00:00Z", + "attack_results": { + "xss": { + "vulnerabilities_found": 2, + "payloads_tested": 50, + "results": [ + { + "type": "XSS", + "parameter": "search", + "payload": "", + "status": "vulnerable", + "severity": "HIGH" + } + ] + }, + "sqli": { + "vulnerabilities_found": 0, + "payloads_tested": 50, + "results": [] + }, + "traversal": { + "vulnerabilities_found": 1, + "payloads_tested": 30, + "results": [...] + }, + "enumeration": { + "paths_found": 15, + "common_paths_tested": 1000, + "results": [...] + } + }, + "total_vulnerabilities": 3, + "risk_score": 72 + }, + "download_token": "uuid-token", + "download_url": "/api/download/uuid-token", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +**Error Response (Missing Consent):** + +```json +{ + "success": false, + "error": "Attack simulation requires user_consent=true. Only test systems you own or have explicit permission to test.", + "error_code": "VALIDATION_ERROR", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +### Downloads + +#### GET /api/download/{token} + +Download previously exported results using a time-limited token. 
+ +**URL Parameters:** + +| Parameter | Type | Description | +| --------- | ------ | ----------------------------------- | +| `token` | string | Download token from a POST response | + +**Headers:** + +``` +X-API-Key: your-api-key-here +``` + +**Response:** File stream (application/octet-stream) + +**Status Codes:** + +| Code | Description | +| ---- | ----------------------------------------------------- | +| 200 | File downloaded successfully | +| 400 | Invalid token | +| 401 | Unauthorized (different user or max retries exceeded) | +| 404 | Token not found | +| 410 | Token expired | +| 429 | Too many retry attempts | +| 503 | Download service unavailable | + +**Example:** + +```bash +curl -H "X-API-Key: your-key" \ + "http://localhost:8000/api/download/uuid-token" \ + -o results.json +``` + +### Authentication + +#### GET /api/auth/me + +Get information about the current authenticated user. + +**Response:** + +```json +{ + "identifier": "user-id", + "auth_method": "api_key", + "scopes": ["full_access"], + "authenticated": true +} +``` + +#### POST /api/auth/generate-key + +Generate a new API key (requires admin/full_access scope). + +**Response:** + +```json +{ + "api_key": "new-key-uuid", + "message": "Store this key securely. It cannot be retrieved again." +} +``` + +### Configuration + +#### GET /api/config + +Retrieve all configuration settings. + +**Response:** + +```json +{ + "success": true, + "data": { + "api": {...}, + "scraping": {...}, + "analysis": {...}, + "attacks": {...} + }, + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +#### GET /api/config/sections/{section} + +Retrieve a specific configuration section. 
+ +**URL Parameters:** + +| Parameter | Type | Description | +| --------- | ------ | ----------------------------------------------- | +| `section` | string | Section name (api, scraping, analysis, attacks) | + +**Response:** + +```json +{ + "success": true, + "data": { + "key1": "value1", + "key2": "value2" + }, + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +#### PUT /api/config + +Update configuration settings. + +**Request:** + +```json +{ + "updates": { + "scraping.timeout": 45.0, + "analysis.cve_limit": 150 + } +} +``` + +#### POST /api/config/export + +Export configuration to file. + +**Request:** + +```json +{ + "format": "yaml" +} +``` + +#### POST /api/config/load + +Load configuration from file. + +**Request:** + +```json +{ + "file_path": "/path/to/config.yaml" +} +``` + +#### POST /api/config/reset + +Reset to default configuration. + +**Request:** + +```json +{ + "section": "scraping" +} +``` + +#### POST /api/config/save + +Save configuration to file. + +**Request:** + +```json +{ + "file_path": "/path/to/config.yaml", + "format": "yaml" +} +``` + +**Response:** + +```json +{ + "success": true, + "data": { + "file_path": "/path/to/config.yaml", + "format": "yaml", + "message": "Configuration saved successfully" + }, + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +## Error Handling + +All error responses include standardized error information: + +### Validation Error + +```json +{ + "success": false, + "error": "Validation error", + "error_code": "VALIDATION_ERROR", + "details": [ + { + "field": "url", + "message": "Invalid URL format", + "value": "not-a-url" + } + ], + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +### General Error + +```json +{ + "success": false, + "error": "Internal server error", + "error_code": "INTERNAL_ERROR", + "timestamp": "2026-01-01T00:00:00Z" +} +``` + +### Common HTTP Status Codes + +| Status | Meaning | +| ------ | -------------------------------------- | +| 200 | Success | +| 400 | Bad request (validation error) | 
+| 401 | Unauthorized (missing/invalid API key) | +| 403 | Forbidden (insufficient permissions) | +| 404 | Not found | +| 429 | Too many requests (rate limited) | +| 500 | Internal server error | +| 503 | Service unavailable | + +## Configuration + +### Environment Variables + +Configure the API via environment variables: + +```bash +# Server +CIBERWEBSCAN_API_HOST=0.0.0.0 +CIBERWEBSCAN_API_PORT=8000 + +# Authentication +CIBERWEBSCAN_API_AUTH_API_KEYS="key1,key2,key3" + +# Rate Limiting +CIBERWEBSCAN_API_RATE_LIMIT_ENABLED=true +CIBERWEBSCAN_API_RATE_LIMIT_REQUESTS_PER_MINUTE=60 + +# CORS +CIBERWEBSCAN_API_CORS_ORIGINS="http://localhost:3000,https://example.com" + +# Downloads +CIBERWEBSCAN_DOWNLOAD_ENABLED=true +CIBERWEBSCAN_DOWNLOAD_TOKEN_EXPIRY_MINUTES=60 +CIBERWEBSCAN_DOWNLOAD_MAX_RETRY_ATTEMPTS=3 +``` + +### Rate Limiting + +The API includes rate limiting to prevent abuse: + +- **Default**: 60 requests per minute +- **Configurable** via `CIBERWEBSCAN_API_RATE_LIMIT_REQUESTS_PER_MINUTE` +- **429 status code** returned when exceeded +- **Retry-After header** included in rate limit responses + +### CORS + +Cross-Origin Resource Sharing is configured for: + +- Origins: Configurable via environment variable +- Credentials: Enabled +- Methods: All (GET, POST, PUT, DELETE, etc.) 
+- Headers: All + +## API Documentation + +### Interactive API Docs + +Once the API server is running, visit: + +- **Swagger UI**: http://localhost:8000/docs +- **ReDoc**: http://localhost:8000/redoc +- **OpenAPI JSON**: http://localhost:8000/openapi.json + +### Example cURL Requests + +**Health Check:** + +```bash +curl http://localhost:8000/health +``` + +**Analyze URL:** + +```bash +curl -X POST "http://localhost:8000/api/analyze" \ + -H "X-API-Key: your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "url": "https://example.com", + "ssl": true, + "fingerprint": true, + "analyze_headers": true + }' +``` + +**Scrape URL:** + +```bash +curl -X POST "http://localhost:8000/api/scrape" \ + -H "X-API-Key: your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "url": "https://example.com", + "selector": ".product" + }' +``` + +**Attack Test:** + +```bash +curl -X POST "http://localhost:8000/api/attack" \ + -H "X-API-Key: your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "url": "https://testsite.example.com", + "xss": true, + "user_consent": true + }' +``` + +## Best Practices + +### Security + +- **Never commit API keys** to version control +- **Store keys securely** using environment variables or secrets manager +- **Rotate keys regularly** for accounts with multiple keys +- **Use HTTPS only** in production +- **Use IP whitelisting** if behind a firewall + +### Performance + +- **Use timeouts** appropriately (30-60 seconds for most operations) +- **Implement exponential backoff** for retries +- **Respect rate limits** and implement queuing on your side +- **Use batch endpoints** for multiple URLs when available +- **Cache results** when appropriate + +### Error Handling + +- **Check HTTP status codes** before parsing response +- **Log all errors** for debugging +- **Implement retry logic** with exponential backoff +- **Handle timeouts gracefully** +- **Validate all responses** before processing + +### Compliance + +- **Only test 
authorized targets** - obtain written permission +- **Follow responsible disclosure** for vulnerabilities +- **Comply with local laws** regarding security testing +- **Use appropriate delays** between requests +- **Respect robots.txt** when configured + +## Support & Troubleshooting + +### Common Issues + +**401 Unauthorized**: Verify your API key is correct and included in the `X-API-Key` header + +**429 Too Many Requests**: Rate limit exceeded. Implement exponential backoff + +**503 Service Unavailable**: Download service disabled. Check configuration + +**Connection Refused**: Ensure API server is running on the correct host/port + +### Getting Help + +- Check [API Documentation](API.md) +- Review [Configuration Guide](CONFIGURATION.md) +- See [CLI Reference](CLI.md) +- Report issues on [GitHub](https://github.com/HC-ONLINE/CiberWebScan/issues) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 424791d..34fb853 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -5,6 +5,50 @@ All notable changes to CiberWebScan will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## 2.1.0 - 2026-04-29 + +### Added + +- Complete REST API implementation with FastAPI framework (now in Beta) +- Authentication module with API key-based security and management endpoints +- Full API endpoint suite: + - `/api/analyze` - Security analysis with detailed options + - `/api/attack` - URL attack testing with configurable payloads + - `/api/scrape` - Single and batch URL scraping endpoints + - `/api/download` - Token-based file download system with automatic cleanup + - `/api/config` - Configuration management endpoints + - `/health` and `/health/ready` - Health check endpoints +- Request logging and rate limiting middleware +- Download token generation and validation system with configurable expiration +- Enhanced request/response models with export options and validation +- DownloadCleanupScheduler for managing expired download tokens +- Comprehensive API endpoint unit tests and integration tests +- API command in CLI for server management +- Enhanced documentation with API usage guides and beta status indicators +- FastAPI and python-multipart dependencies for API functionality +- Improved CI/CD workflow for full test coverage with both dev and api dependencies + +### Changed + +- Updated APIResponse model with download token and URL fields +- Improved error handling across API endpoints +- Enhanced request models with new validation fields +- Updated health check endpoints to use HealthCheckResponse model +- Refactored configuration system to support API settings +- Optimized pre-commit configuration to include FastAPI checks + +### Fixed + +- Fixed timestamp field to use timezone-aware datetime in API responses +- Improved async/sync context handling in services +- Better input validation across API endpoints + +### Known Issues + +- REST API is in Beta - expect potential breaking changes +- Some advanced API features may still be under development +- Download tokens are stored in memory (production should use persistent storage) + ## 2.0.0 - 
2026-02-12 ### Added diff --git a/docs/CLI.md b/docs/CLI.md index 058c3c6..9d7ac90 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -298,6 +298,39 @@ ciberwebscan attack sqli --consent [OPTIONS] ciberwebscan attack sqli https://example.com/product?id=1 --consent ``` +### API Command + +Manage and run the CiberWebScan REST API server. + +#### api run + +Start the REST API server using Uvicorn. + +```bash +ciberwebscan api run [OPTIONS] +``` + +**Options:** + +- --host : Bind socket to this host (default: 0.0.0.0) + +- --port : Bind socket to this port (default: 8000) + +- --reload: Enable auto-reload (development mode) + +**Examples:** + +```bash +# Start the API server on default port 8000 +ciberwebscan api run + +# Start on a custom port and host +ciberwebscan api run --host 127.0.0.1 --port 9000 + +# Run in development mode with auto-reload +ciberwebscan api run --reload +``` + ### Config Command Manage application configuration. diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 6cc140e..aa668db 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -9,11 +9,14 @@ CiberWebScan uses a flexible configuration system that allows customization of v 2. [Configuration File](#configuration-file) 3. [Configuration Sections](#configuration-sections) - [HTTP Client](#http-client) + - [User Agent](#user-agent) - [Scraping](#scraping) - [Analysis](#analysis) - [Attack](#attack) - [Export](#export) - [Cache](#cache) + - [API](#api) + - [Logging](#logging) 4. [CLI Commands](#cli-configuration-commands) 5. [Validation & Troubleshooting](#validation) 6. [Development Roadmap](#development-notes) @@ -444,6 +447,40 @@ Configure caching behavior. | `cache.ttl` | `3600` | Cache TTL (seconds) | | `cache.max_size_mb` | `100` | Max cache size (MB) | +### API + +Configure the FastAPI server and authentication settings. 
+ +```json +{ + "api": { + "host": "0.0.0.0", + "port": 8000, + "auth": { + "api_keys": ["your-api-key-1", "your-api-key-2"] + }, + "rate_limit": { + "enabled": true, + "requests_per_minute": 60 + }, + "cors_origins": ["*"] + } +} +``` + +#### Default values (quick reference) + +| Key | Default | Description | +| ------------------------------------ | --------: | ----------------------------------------- | +| `api.host` | `0.0.0.0` | API server bind address | +| `api.port` | `8000` | API server port (1-65535) | +| `api.auth.api_keys` | `[]` | List of valid API keys for authentication | +| `api.rate_limit.enabled` | `true` | Rate limiting enabled by default | +| `api.rate_limit.requests_per_minute` | `60` | Max requests per minute (1-10000) | +| `api.cors_origins` | `["*"]` | CORS allowed origins | + +> Note: Use specific domains in cors_origins for production environments to improve security. + ### Logging Configure logging behavior. diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md index c24e12b..a83b4cf 100644 --- a/docs/INSTALLATION.md +++ b/docs/INSTALLATION.md @@ -30,7 +30,13 @@ This guide covers the installation and setup of CiberWebScan. pip install -e . ``` -4. (Optional) Install development dependencies: +4. (Optional) API Setup: + + ```bash + pip install -e ".[api]" + ``` + +5. (Optional) Install development dependencies: ```bash pip install -e ".[dev]" diff --git a/pyproject.toml b/pyproject.toml index 0e010ca..d665849 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,10 +25,12 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] requires-python = ">=3.10" + +# core/CLI dependencies – these are pulled in when a user +# runs ``pip install ciberwebscan``. the command-line interface +# and the underlying scanning engine require these libraries. 
dependencies = [ "beautifulsoup4>=4.13.4", - "fastapi>=0.115.0", - "uvicorn>=0.32.0", "httpx[http2]>=0.27.0", "playwright>=1.58.0", "lxml>=5.3.0", @@ -40,10 +42,24 @@ dependencies = [ "pyyaml>=6.0.2", "pydantic>=2.0.0", "orjson>=3.9.0", - "python-multipart>=0.0.9", ] [project.optional-dependencies] +# users who only need the command-line interface can install the +# package normally; none of the API-specific requirements will be +# pulled in. to enable the REST server build the extra explicitly: +# +# pip install "ciberwebscan[api]" +# +# the API implementation is located in ``src/ciberwebscan/api`` and +# depends on FastAPI/uvicorn, which are therefore kept out of the +# base dependency set. +api = [ + "fastapi>=0.115.0", + "uvicorn>=0.32.0", + "python-multipart>=0.0.9", +] + dev = [ "pytest>=8.3.5", "pytest-asyncio>=1.3.0", @@ -134,6 +150,6 @@ indent-style = "space" [tool.pyright] include = ["src/ciberwebscan"] -exclude = ["**/__pycache__"] +exclude = ["**/__pycache__", "**/tests/**", "**/node_modules/**", "**/.venv/**", "**/.*", "**/build/**", "**/dist/**","**/*.egg-info"] pythonVersion = "3.10" typeCheckingMode = "basic" diff --git a/src/ciberwebscan/api/__init__.py b/src/ciberwebscan/api/__init__.py new file mode 100644 index 0000000..57a208f --- /dev/null +++ b/src/ciberwebscan/api/__init__.py @@ -0,0 +1 @@ +"""CiberWebScan API""" diff --git a/src/ciberwebscan/api/app.py b/src/ciberwebscan/api/app.py new file mode 100644 index 0000000..0ce6aa8 --- /dev/null +++ b/src/ciberwebscan/api/app.py @@ -0,0 +1,150 @@ +""" +Main FastAPI application for CiberWebScan. + +Provides REST API endpoints for scraping, analysis, and attack simulation. 
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Run startup and shutdown work around the application's lifetime.

    On startup the cleanup scheduler is started when ``download.enabled``
    is set in the loaded configuration. On shutdown that same scheduler
    instance is stopped again.

    Args:
        app: The FastAPI application being served. Not used here; it is
            part of the lifespan callable signature.

    Yields:
        None. Control is handed back to the framework while the
        application is serving requests.
    """
    logger.info("Starting CiberWebScan API")

    app_config = get_config()
    scheduler = None
    if app_config.download.enabled:
        scheduler = get_scheduler()
        scheduler.start()

    try:
        yield
    finally:
        # Stop inside ``finally`` so the scheduler is not left running when
        # the context is exited through an exception or a cancellation.
        # The instance that was started is reused instead of being looked
        # up a second time.
        if scheduler is not None:
            scheduler.stop()
        logger.info("Shutting down CiberWebScan API")


def create_app() -> FastAPI:
    """Build the FastAPI application and wire up all of its parts.

    The function reads the global configuration, then registers, in order:
    CORS middleware (origins taken from ``api.cors_origins``), request
    logging middleware, rate limiting middleware (only when
    ``api.rate_limit.enabled`` is set), two exception handlers and the
    route modules.

    Returns:
        The configured FastAPI instance.
    """
    api_config = get_config().api

    app = FastAPI(
        title="CiberWebScan API",
        description=__description__,
        version=__version__,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
        lifespan=lifespan,
    )

    # CORS settings come from the global configuration.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=api_config.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Project middleware.
    add_request_logging_middleware(app)
    rate_limit = api_config.rate_limit
    if rate_limit.enabled:
        add_rate_limiting_middleware(
            app, requests_per_minute=rate_limit.requests_per_minute
        )

    @app.exception_handler(ValueError)
    async def value_error_handler(request: Request, exc: ValueError) -> JSONResponse:
        """Turn a ValueError raised while handling a request into HTTP 400."""
        logger.warning(f"Validation error on {request.url}: {exc}")
        payload = ErrorResponse(
            error=str(exc),
            error_code="VALIDATION_ERROR",
        )
        return JSONResponse(status_code=400, content=jsonable_encoder(payload))

    @app.exception_handler(Exception)
    async def general_exception_handler(
        request: Request, exc: Exception
    ) -> JSONResponse:
        """Turn any other exception into a generic HTTP 500 response."""
        logger.error(f"Unexpected error on {request.url}: {exc}", exc_info=True)
        payload = ErrorResponse(
            error="Internal server error",
            error_code="INTERNAL_ERROR",
            details={"request_path": str(request.url.path)},
        )
        return JSONResponse(status_code=500, content=jsonable_encoder(payload))

    # Health is mounted at the root, auth under "<prefix>/auth", and every
    # other route module directly under the shared prefix.
    app.include_router(health.router, tags=["health"])
    app.include_router(auth.router, prefix=prefix + "/auth", tags=["authentication"])
    prefixed_routes = (
        (config, "configuration"),
        (scrape, "scraping"),
        (analyze, "analysis"),
        (attack, "attacks"),
        (download, "download"),
    )
    for route_module, tag in prefixed_routes:
        app.include_router(route_module.router, prefix=prefix, tags=[tag])

    return app
b/src/ciberwebscan/api/auth.py @@ -0,0 +1,222 @@ +""" +Authentication module for CiberWebScan API. + +Provides API Key authentication. +""" + +from __future__ import annotations + +import logging +import secrets +from typing import Annotated + +from fastapi import Depends, HTTPException, Request, Security, status +from fastapi.security import APIKeyHeader +from pydantic import BaseModel + +from ciberwebscan.config.loader import get_config + +logger = logging.getLogger(__name__) + +# ============================================================================= +# Configuration +# ============================================================================= + + +class AuthConfig(BaseModel): + """Authentication configuration.""" + + api_key_enabled: bool = True + api_keys: list[str] = [] + + +def get_auth_config() -> AuthConfig: + """ + Load authentication configuration from global config. + """ + config = get_config() + auth_cfg = config.api.auth + + return AuthConfig( + api_key_enabled=bool(auth_cfg.api_keys), + api_keys=auth_cfg.api_keys, + ) + + +# ============================================================================= +# Security Schemes +# ============================================================================= + +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +# ============================================================================= +# Authentication Dependencies +# ============================================================================= + + +class AuthenticatedUser(BaseModel): + """Authenticated user/client information.""" + + identifier: str + auth_method: str + scopes: list[str] = [] + + +def _secure_compare_key(provided_key: str, stored_keys: list[str]) -> str | None: + """ + Compare API key using constant-time comparison to prevent timing attacks. + Returns the matched key identifier if valid, None otherwise. 
+ """ + for stored_key in stored_keys: + if secrets.compare_digest(provided_key.encode(), stored_key.encode()): + return stored_key[:8] + return None + + +async def verify_api_key( + request: Request, + api_key: Annotated[str | None, Security(api_key_header)] = None, +) -> AuthenticatedUser | None: + """ + Verify API key from X-API-Key header. + + Uses constant-time comparison to prevent timing attacks. + + Returns: + AuthenticatedUser if valid, None if no key provided + """ + client_ip = _get_client_ip(request) + + if not api_key: + return None + + config = get_auth_config() + + if not config.api_key_enabled: + logger.warning( + "API key auth disabled but key provided", + extra={"client_ip": client_ip}, + ) + return None + + # Constant-time comparison + key_id = _secure_compare_key(api_key, config.api_keys) + + if key_id: + logger.info( + f"API key authenticated: {key_id}...", + extra={ + "event": "auth_success", + "key_id": key_id, + "client_ip": client_ip, + }, + ) + return AuthenticatedUser( + identifier=f"apikey:{key_id}", + auth_method="api_key", + scopes=["full_access"], + ) + + # Log failed attempt + logger.warning( + f"Invalid API key attempt from {client_ip}", + extra={ + "event": "auth_failed", + "reason": "invalid_key", + "client_ip": client_ip, + "key_prefix": api_key[:4] + "..." if len(api_key) > 4 else "***", + }, + ) + return None + + +def _get_client_ip(request: Request) -> str: + """Extract client IP from request, handling proxies.""" + forwarded = request.headers.get("X-Forwarded-For", "").split(",")[0].strip() + if forwarded: + return forwarded + return request.client.host if request.client else "unknown" + + +async def get_current_user( + api_key_user: Annotated[AuthenticatedUser | None, Depends(verify_api_key)], + request: Request, +) -> AuthenticatedUser: + """ + Get the current authenticated user. + + Checks API key from X-API-Key header. + Raises 401 if not valid. 
+ """ + if api_key_user: + return api_key_user + + # Log unauthorized access attempt + client_ip = _get_client_ip(request) + logger.warning( + f"Unauthorized access attempt from {client_ip}: {request.method} {request.url.path}", + extra={ + "event": "auth_required", + "client_ip": client_ip, + "method": request.method, + "path": request.url.path, + }, + ) + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required. Provide X-API-Key header.", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + +async def get_optional_user( + api_key_user: Annotated[AuthenticatedUser | None, Depends(verify_api_key)], +) -> AuthenticatedUser | None: + """ + Get the current user if authenticated, None otherwise. + + Useful for endpoints that have different behavior for authenticated users. + """ + return api_key_user + + +# ============================================================================= +# Scope/Permission Checking +# ============================================================================= + + +def require_scope(required_scope: str): + """ + Dependency factory that requires a specific scope. + + Usage: + @router.get("/admin", dependencies=[Depends(require_scope("admin"))]) + async def admin_endpoint(): + ... 
def enrich_response_with_token(
    result: ServiceResult[T],
    user_id: str,
    download_service: DownloadService,
) -> tuple[T | None, str | None]:
    """
    Pair a service result's data with a download token for its export file.

    When the result carries an ``export_path``, a download token is
    requested from ``download_service``. The export format passed along is
    taken from the path's extension when it is one of ``json``, ``jsonl``
    or ``csv`` (case-insensitive); any other extension falls back to
    ``json``. A failed token request is logged as a warning and does not
    raise.

    Args:
        result: ServiceResult from analysis/attack/scrape service
        user_id: ID of user who made the request
        download_service: DownloadService instance

    Returns:
        Tuple of (data, download_token) where token is None if there was
        no export file or the token could not be generated
    """
    export_path = result.export_path
    if export_path is None:
        return result.data, None

    # The text after the last dot is the candidate format. A path without
    # a recognised extension keeps the previous default of "json".
    extension = str(export_path).rpartition(".")[2].lower()
    file_format = extension if extension in ("json", "jsonl", "csv") else "json"

    token_result = download_service.generate_download_token(
        file_path=export_path,
        user_id=user_id,
        file_format=file_format,
    )

    # Best effort: report the failure but still return the data.
    if not token_result.success or token_result.data is None:
        download_service.logger.warning(
            f"Failed to generate download token: {token_result.error}"
        )
        return result.data, None

    return result.data, token_result.data.token
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all HTTP requests with timing information."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Forward the request and log method, path, status and duration.

        Args:
            request: Incoming request.
            call_next: Callable that passes the request further down the
                middleware chain and returns the response.

        Returns:
            The response produced by ``call_next``.

        Raises:
            Whatever ``call_next`` raises; the exception is logged with
            status 500 and propagates unchanged.
        """
        start_time = time.perf_counter()
        client_ip = request.client.host if request.client else "unknown"
        path = request.url.path
        # Assume failure until a response exists. This also keeps the log
        # statement below valid when ``call_next`` is interrupted by
        # something that is not an ``Exception`` (e.g. a cancellation).
        status_code = 500

        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration = time.perf_counter() - start_time
            logger.info(
                f"{request.method} {path} - {status_code} ({duration:.3f}s)",
                extra={
                    "method": request.method,
                    "path": path,
                    "status_code": status_code,
                    "client_ip": client_ip,
                    "duration": duration,
                },
            )


class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Simple in-memory rate limiting middleware.

    Requests are counted per client inside a fixed one-minute window; all
    counters are dropped when the window number changes.
    """

    def __init__(self, app, requests_per_minute: int):
        """Store the limit and start with empty counters.

        Args:
            app: Application (or next middleware) being wrapped.
            requests_per_minute: Requests allowed per client and window.
        """
        super().__init__(app)
        self.limit = requests_per_minute
        self.counts = Counter()
        self.window = 0

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject the request with HTTP 429 once the client is over the limit.

        Args:
            request: Incoming request.
            call_next: Callable that passes the request further down the
                middleware chain and returns the response.

        Returns:
            A 429 JSON response with a ``Retry-After`` header when the
            limit is reached, otherwise the response from ``call_next``.
        """
        # Client identification: the first X-Forwarded-For entry wins over
        # the socket peer address.
        # NOTE(review): this header is set by the caller unless a trusted
        # proxy overwrites it, so it can be used to dodge the limit —
        # confirm the deployment always sits behind such a proxy.
        forwarded_for = request.headers.get("X-Forwarded-For", "").split(",")[0].strip()
        client_ip = forwarded_for or (
            request.client.host if request.client else "unknown"
        )

        now = time.time()
        current_window = int(now // 60)

        # Window rotation: a new minute starts with empty counters.
        if current_window != self.window:
            self.counts.clear()
            self.window = current_window

        # Limit check.
        if self.counts[client_ip] >= self.limit:
            # Seconds left until the current window ends.
            retry_after = 60 - int(now % 60)
            logger.warning(f"Rate limit exceeded: {client_ip}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Rate limit exceeded",
                    "retry_after_seconds": retry_after,
                },
                headers={"Retry-After": str(retry_after)},
            )

        self.counts[client_ip] += 1
        return await call_next(request)


def add_request_logging_middleware(app: FastAPI) -> None:
    """Add request logging middleware to FastAPI app."""
    app.add_middleware(RequestLoggingMiddleware)


def add_rate_limiting_middleware(app: FastAPI, requests_per_minute: int = 60) -> None:
    """Add rate limiting middleware to FastAPI app."""
    app.add_middleware(RateLimitingMiddleware, requests_per_minute=requests_per_minute)
"JobStatus", "JobCreatedResponse", - "ScrapeResponse", "ScrapeBatchResponse", "ScrapeBatchResultResponse", - "AnalyzeResponse", - "SSLAnalysisResponse", - "FingerprintResponse", - "HeadersAnalysisResponse", - "CVESearchResponse", - "AttackResponse", "HealthCheckResponse", "DetailedHealthResponse", "SeveritySummary", "ScanSummaryResponse", "ExportResponse", - "ConfigResponse", - "ConfigUpdateResponse", + "ConfigValueResponse", + "ConfigKeysResponse", + "ConfigFileResponse", + "DownloadTokenResponse", + "DownloadInfo", ] diff --git a/src/ciberwebscan/api/models/requests.py b/src/ciberwebscan/api/models/requests.py index e29b9ff..059b3e3 100644 --- a/src/ciberwebscan/api/models/requests.py +++ b/src/ciberwebscan/api/models/requests.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Annotated, Literal +import json +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field, HttpUrl, field_validator @@ -23,15 +24,30 @@ class ScrapeRequest(BaseModel): default=False, description="Use playwright for JavaScript-rendered pages", ) - wait_selector: str | None = Field( + wait_for: str | None = Field( None, description="CSS selector to wait for (dynamic mode only)", ) - extract_links: bool = True - extract_images: bool = True - extract_forms: bool = True - extract_scripts: bool = True - include_raw_html: bool = False + selector: str | None = Field( + default=None, + description="CSS selector used for focused extraction", + ) + attributes: list[str] = Field( + default_factory=list, + description="Attributes to extract from matched elements", + ) + extraction_schema: dict[str, Any] | None = Field( + default=None, + description="Structured extraction schema", + ) + pagination_selector: str | None = Field( + default=None, + description="Selector for pagination links", + ) + pagination_limit: Annotated[int, Field(ge=1, le=1000)] = Field( + default=1, + description="Maximum number of pages to traverse", + ) timeout: Annotated[float, 
Field(ge=1.0, le=120.0)] = 30.0 headers: dict[str, str] = Field( default_factory=dict, @@ -41,6 +57,36 @@ class ScrapeRequest(BaseModel): default_factory=dict, description="Cookies to include in request", ) + proxy: str | None = Field( + default=None, + description="HTTP/HTTPS proxy URL", + ) + user_agent: str | None = Field( + default=None, + description="Custom User-Agent string", + ) + check_robots: bool = Field( + default=True, + description="Respect robots.txt when scraping", + ) + export: str | None = Field( + default=None, + description="Optional output file path for exported results", + ) + export_format: Literal["json", "jsonl", "csv"] = Field( + default="json", + description="Export format when export path is provided", + ) + + @field_validator("attributes", mode="before") + @classmethod + def parse_attributes(cls, value: list[str] | str | None) -> list[str]: + """Allow attributes as list or comma-separated string.""" + if value is None: + return [] + if isinstance(value, str): + return [item.strip() for item in value.split(",") if item.strip()] + return value class ScrapeBatchRequest(BaseModel): @@ -48,8 +94,20 @@ class ScrapeBatchRequest(BaseModel): urls: list[HttpUrl] = Field(..., min_length=1, max_length=100) dynamic: bool = False - concurrency: Annotated[int, Field(ge=1, le=10)] = 5 - include_raw_html: bool = False + selector: str | None = None + timeout: Annotated[float, Field(ge=1.0, le=120.0)] = 30.0 + headers: dict[str, str] = Field(default_factory=dict) + cookies: dict[str, str] = Field(default_factory=dict) + proxy: str | None = None + user_agent: str | None = None + export: str | None = Field( + default=None, + description="Optional output file path for exported batch results", + ) + export_format: Literal["json", "jsonl", "csv"] = Field( + default="jsonl", + description="Export format when export path is provided", + ) @field_validator("urls") @classmethod @@ -76,13 +134,57 @@ class AnalyzeRequest(BaseModel): url: HttpUrl ssl: bool = 
Field(default=True, description="Perform SSL/TLS analysis") fingerprint: bool = Field(default=True, description="Detect technologies") - headers: bool = Field(default=True, description="Analyze security headers") + analyze_headers: bool = Field( + default=True, + description="Analyze security headers", + ) cve: bool = Field(default=True, description="Lookup CVEs for detected technologies") - cve_api: Literal["nvd", "vulners", "circl", "all"] = "all" - full_report: bool = Field( + ssl_verify: bool = Field( default=True, - description="Include scrape results in report", + description="Verify SSL certificates when fetching target page", + ) + timeout: Annotated[float, Field(ge=1.0, le=300.0)] = 30.0 + ssl_timeout: Annotated[float, Field(ge=1.0, le=120.0)] = 10.0 + deep_scan: bool = Field( + default=False, + description="Enable deeper technology fingerprinting", + ) + cve_sources: list[Literal["nvd", "vulners", "circl"]] = Field( + default_factory=list, + description="Explicit CVE sources. If empty, config value is used", + ) + cve_limit: Annotated[int, Field(ge=1, le=1000)] = 100 + cve_severity: Literal["critical", "high", "medium", "low", "info"] | None = None + headers: dict[str, str] = Field( + default_factory=dict, + description="Custom HTTP headers for analysis requests", + ) + cookies: dict[str, str] = Field(default_factory=dict) + proxy: str | None = None + user_agent: str | None = None + check_robots: bool = False + enrich_exploits: bool = False + export: str | None = Field( + default=None, + description="Optional output file path for exported results", ) + export_format: Literal["json", "jsonl", "csv"] = Field( + default="json", + description="Export format when export path is provided", + ) + + @field_validator("cve_sources", mode="before") + @classmethod + def parse_cve_sources( + cls, + value: list[str] | str | None, + ) -> list[str]: + """Allow CVE sources as list or comma-separated string.""" + if value is None: + return [] + if isinstance(value, str): + 
return [item.strip() for item in value.split(",") if item.strip()] + return value # ============================================================================= @@ -94,11 +196,60 @@ class AttackRequest(BaseModel): """Request payload for attack simulation endpoint.""" url: HttpUrl - xss: bool = Field(default=True, description="Test for XSS vulnerabilities") - sqli: bool = Field(default=True, description="Test for SQL injection") - traversal: bool = Field(default=True, description="Test for path traversal") - enumeration: bool = Field(default=True, description="Directory enumeration") - max_payloads: Annotated[int, Field(ge=1, le=1000)] = 50 + xss: bool | None = Field( + default=None, + description="Test for XSS vulnerabilities (None uses config default)", + ) + sqli: bool | None = Field( + default=None, + description="Test for SQL injection (None uses config default)", + ) + traversal: bool | None = Field( + default=None, + description="Test for path traversal (None uses config default)", + ) + enumeration: bool | None = Field( + default=None, + description="Directory enumeration (None uses config default)", + ) + all_attacks: bool = Field( + default=False, + description="Enable all attack types", + ) + intensity: Literal["low", "medium", "high"] = "medium" + max_payloads: Annotated[int | None, Field(ge=1, le=1000)] = None + custom_payloads_file: str | None = Field( + default=None, + description="Path to custom payloads file", + ) + custom_wordlist: str | None = Field( + default=None, + description="Custom wordlist path for enumeration", + ) + timeout: Annotated[float, Field(ge=1.0, le=300.0)] = 10.0 + delay_between_requests: float = 0.1 + concurrent_requests: Annotated[int, Field(ge=1, le=10)] = 1 + scope_urls: list[str] = Field( + default_factory=list, + description="Optional list of URLs to scope the attack to", + ) + skip_dangerous_payloads: bool = Field( + default=True, + description="Skip payloads marked as dangerous", + ) + export: str | None = Field( + 
default=None, + description="Optional output file path for exported results", + ) + export_format: Literal["json", "jsonl", "csv"] = Field( + default="json", + description="Export format when export path is provided", + ) + headers: dict[str, str] = Field(default_factory=dict) + cookies: dict[str, str] = Field(default_factory=dict) + proxy: str | None = None + user_agent: str | None = None + verbose: bool = False user_consent: bool = Field( default=False, description="User confirms authorization to test this target", @@ -115,6 +266,25 @@ def require_consent(cls, v: bool) -> bool: ) return v + @field_validator("headers", "cookies", mode="before") + @classmethod + def parse_key_value_map( + cls, + value: dict[str, str] | str | None, + ) -> dict[str, str]: + """Allow maps as dict or JSON object string.""" + if value is None: + return {} + if isinstance(value, str): + try: + loaded = json.loads(value) + if isinstance(loaded, dict): + return {str(k): str(v) for k, v in loaded.items()} + except json.JSONDecodeError: + return {} + return {} + return value + # ============================================================================= # Export Requests @@ -136,15 +306,71 @@ class ExportRequest(BaseModel): class ConfigUpdateRequest(BaseModel): - """Request to update configuration.""" + """Request to update a configuration value.""" path: str = Field( ..., - description="Dot-separated path to config key (e.g., 'http.timeout.connect')", + min_length=1, + description="Configuration key (dot-notation)", ) - value: str | int | float | bool | list | dict = Field( + value: Any = Field(..., description="New value (str, int, float, bool, list, dict)") + save: bool = Field( + False, + description="If True, persist changes to disk immediately", + ) + + +class ConfigResetRequest(BaseModel): + """Request to reset configuration.""" + + path: Annotated[str | None, Field(min_length=1)] = Field( + None, + description="Specific key to reset, or None to reset all", + ) + save: bool = Field( 
class ConfigExportRequest(BaseModel):
    """Request to export configuration."""

    path: str = Field(..., min_length=1, description="Output file path")
    # Limited to the two formats named in the description, so any other
    # value is rejected while the request is validated.
    format: Literal["yaml", "json"] = Field(
        "yaml",
        description="Export format (yaml or json)",
    )
datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + download_token: str | None = None + download_url: str | None = None class ErrorResponse(BaseModel): @@ -46,7 +40,7 @@ class ErrorResponse(BaseModel): error: str error_code: str | None = None details: dict[str, Any] | None = None - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) class ValidationErrorDetail(BaseModel): @@ -64,7 +58,7 @@ class ValidationErrorResponse(BaseModel): error: str = "Validation error" error_code: str = "VALIDATION_ERROR" details: list[ValidationErrorDetail] - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # ============================================================================= @@ -83,7 +77,7 @@ class PaginatedResponse(BaseModel, Generic[T]): total_pages: int has_next: bool has_prev: bool - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @classmethod def create( @@ -137,18 +131,7 @@ class JobCreatedResponse(BaseModel): status: str = "pending" status_url: str message: str = "Job created successfully" - timestamp: datetime = Field(default_factory=datetime.utcnow) - - -# ============================================================================= -# Scrape Responses -# ============================================================================= - - -class ScrapeResponse(APIResponse[ScrapeResult]): - """Response for scrape endpoint.""" - - elapsed_ms: float = 0.0 + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) class ScrapeBatchResponse(BaseModel): @@ -158,13 +141,12 @@ class ScrapeBatchResponse(BaseModel): job_id: str total_urls: int status_url: str - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = 
Field(default_factory=lambda: datetime.now(timezone.utc)) class ScrapeBatchResultResponse(BaseModel): """Result of completed batch scrape.""" - success: bool = True job_id: str results: list[ScrapeResult] failed_urls: list[dict[str, str]] = Field( @@ -174,56 +156,7 @@ class ScrapeBatchResultResponse(BaseModel): total_success: int total_failed: int elapsed_seconds: float - timestamp: datetime = Field(default_factory=datetime.utcnow) - - -# ============================================================================= -# Analysis Responses -# ============================================================================= - - -class AnalyzeResponse(APIResponse[AnalysisReport]): - """Response for analysis endpoint.""" - - pass - - -class SSLAnalysisResponse(APIResponse[SSLResult]): - """Response for SSL analysis endpoint.""" - - pass - - -class FingerprintResponse(APIResponse[FingerprintResult]): - """Response for fingerprint endpoint.""" - - pass - - -class HeadersAnalysisResponse(APIResponse[HeadersResult]): - """Response for headers analysis endpoint.""" - - pass - - -class CVESearchResponse(PaginatedResponse[CVEResult]): - """Response for CVE search endpoint.""" - - pass - - -# ============================================================================= -# Attack Responses -# ============================================================================= - - -class AttackResponse(APIResponse[AttackResult]): - """Response for attack simulation endpoint.""" - - warnings: list[str] = Field( - default_factory=list, - description="Legal/ethical warnings for user", - ) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # ============================================================================= @@ -236,8 +169,9 @@ class HealthCheckResponse(BaseModel): status: str = "healthy" version: str - uptime_seconds: float - timestamp: datetime = Field(default_factory=datetime.utcnow) + message: str = "" + uptime_seconds: float = 0.0 + timestamp: 
datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) class ServiceStatus(BaseModel): @@ -256,7 +190,7 @@ class DetailedHealthResponse(BaseModel): version: str uptime_seconds: float services: list[ServiceStatus] = Field(default_factory=list) - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # ============================================================================= @@ -288,7 +222,7 @@ class ScanSummaryResponse(BaseModel): default_factory=list, description="Top 5 most critical issues found", ) - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # ============================================================================= @@ -304,7 +238,34 @@ class ExportResponse(BaseModel): format: str file_size_bytes: int expires_at: datetime - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# ============================================================================= +# Download Responses +# ============================================================================= + + +class DownloadTokenResponse(BaseModel): + """Response when a download token is generated.""" + + token: str = Field(..., description="Unique download token") + expires_at: datetime = Field(..., description="Token expiration timestamp") + download_url: str = Field(..., description="URL to download the file") + + +class DownloadInfo(BaseModel): + """Metadata about a download token and associated file.""" + + token: str + user_id: str + file_size_bytes: int + created_at: datetime + expires_at: datetime + attempts_remaining: int + file_format: str = Field( + default="json", description="Export format (json/jsonl/csv)" + ) # ============================================================================= @@ -312,20 
+273,26 @@ class ExportResponse(BaseModel): # ============================================================================= -class ConfigResponse(BaseModel): - """Response for config endpoints.""" +class ConfigValueResponse(BaseModel): + """Response containing a configuration value and metadata.""" - success: bool = True - config: dict[str, Any] - timestamp: datetime = Field(default_factory=datetime.utcnow) + key: str + value: Any + default: Any + source: str # 'file', 'env', 'default', 'runtime' + description: str = "" -class ConfigUpdateResponse(BaseModel): - """Response for config update endpoint.""" +class ConfigKeysResponse(BaseModel): + """Response containing a list of configuration keys.""" - success: bool = True - path: str - old_value: Any - new_value: Any - message: str = "Configuration updated successfully" - timestamp: datetime = Field(default_factory=datetime.utcnow) + keys: list[str] + count: int + + +class ConfigFileResponse(BaseModel): + """Response containing file operation result.""" + + file_path: str + operation: str # 'export', 'load', 'save' + format: str | None = None diff --git a/src/ciberwebscan/api/routes/__init__.py b/src/ciberwebscan/api/routes/__init__.py new file mode 100644 index 0000000..0b31d30 --- /dev/null +++ b/src/ciberwebscan/api/routes/__init__.py @@ -0,0 +1,9 @@ +""" +API routes for CiberWebScan. + +This package contains all FastAPI route handlers for the REST API. +""" + +from . import analyze, attack, auth, config, download, health, scrape + +__all__ = ["analyze", "attack", "auth", "config", "download", "health", "scrape"] diff --git a/src/ciberwebscan/api/routes/analyze.py b/src/ciberwebscan/api/routes/analyze.py new file mode 100644 index 0000000..4217418 --- /dev/null +++ b/src/ciberwebscan/api/routes/analyze.py @@ -0,0 +1,95 @@ +""" +Security analysis endpoints for CiberWebScan API. 
@router.post("/analyze", response_model=APIResponse[AnalysisReport])
async def analyze_url(
    request: AnalyzeRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[AnalysisReport]:
    """
    Perform security analysis on a URL.
    Supports SSL analysis, technology fingerprinting, header analysis, and CVE lookup.

    Args:
        request: Analysis request; its fields are copied one-to-one into
            ``AnalyzeOptions``.
        user: Authenticated caller; ``user.identifier`` is handed to the
            download-token helper.

    Returns:
        The analysis data, plus a download token and URL when the helper
        returns a token.

    Raises:
        HTTPException: 400 on a pydantic validation error; 500 when the
            service reports failure or any other exception escapes.
    """
    try:
        # Convert request to service options
        options = AnalyzeOptions(
            url=str(request.url),
            # Analysis types
            ssl=request.ssl,
            fingerprint=request.fingerprint,
            analyze_headers=request.analyze_headers,
            cve=request.cve,
            ssl_verify=request.ssl_verify,
            timeout=request.timeout,
            ssl_timeout=request.ssl_timeout,
            deep_scan=request.deep_scan,
            cve_sources=request.cve_sources,
            cve_limit=request.cve_limit,
            cve_severity=request.cve_severity,
            headers=request.headers,
            cookies=request.cookies,
            proxy=request.proxy,
            user_agent=request.user_agent,
            check_robots=request.check_robots,
            enrich_exploits=request.enrich_exploits,
            export=request.export,
            export_format=request.export_format,
        )

        # Execute analysis
        service = AnalyzeService()
        result = service.analyze(options)

        if not result.success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=result.error or "Analysis failed",
            )

        # Enrich response with download token
        download_service = DownloadService()
        data, download_token = enrich_response_with_token(
            result, user.identifier, download_service
        )
        download_url = f"/api/v1/download/{download_token}" if download_token else None

        return APIResponse[AnalysisReport](
            success=True,
            data=data,
            download_token=download_token,
            download_url=download_url,
        )

    except HTTPException:
        # Deliberately raised above: let it through untouched. Without this
        # guard the generic handler below re-wraps it and the client sees a
        # mangled ``detail`` instead of the service's own error message.
        raise
    except ValidationError as e:
        logger.warning("Validation error in analyze request: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid request: {e}"
        ) from e
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Error analyzing URL %s: %s", request.url, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Analysis failed: {str(e)}",
        ) from e
@router.post("/attack", response_model=APIResponse[AttackResult])
async def attack_target(
    request: AttackRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[AttackResult]:
    """
    Perform security attack simulations against a target URL.

    Supports XSS, SQL injection, path traversal, and directory enumeration testing.

    **IMPORTANT**: Only test systems you own or have explicit written permission to test.
    Setting `user_consent=true` confirms that you have that permission.
    """
    try:
        # all_attacks forces every individual attack flag on; otherwise each
        # flag keeps the value supplied in the request.
        attack_flags = {
            name: True if request.all_attacks else getattr(request, name)
            for name in ("xss", "sqli", "traversal", "enumeration")
        }

        options = AttackOptions(
            url=str(request.url),
            user_consent=request.user_consent,
            **attack_flags,
            intensity=request.intensity,
            max_payloads=request.max_payloads,
            timeout=request.timeout,
            delay_between_requests=request.delay_between_requests,
            concurrent_requests=request.concurrent_requests,
            custom_payloads_file=request.custom_payloads_file,
            custom_wordlist=request.custom_wordlist,
            skip_dangerous_payloads=request.skip_dangerous_payloads,
            scope_urls=request.scope_urls,
            export=request.export,
            export_format=request.export_format,
            headers=request.headers,
            cookies=request.cookies,
            proxy=request.proxy,
            user_agent=request.user_agent,
            verbose=request.verbose,
        )

        outcome = AttackService().attack(options)

        if not outcome.success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=outcome.error or "Attack execution failed",
            )

        # Attach a download token (and matching URL) when one is issued.
        payload, token = enrich_response_with_token(
            outcome, user.identifier, DownloadService()
        )
        link = None if not token else f"/api/v1/download/{token}"

        return APIResponse[AttackResult](
            success=True,
            data=payload,
            download_token=token,
            download_url=link,
        )

    except ServiceValidationError as e:
        # Service-level validation failures are reported verbatim as 400.
        logger.warning(f"Validation error in attack request for {request.url}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except ValidationError as e:
        logger.warning(f"Pydantic validation error in attack request: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid request: {e}",
        ) from e
    except HTTPException:
        # Already an HTTP error raised above; pass it through unchanged.
        raise
    except Exception as e:
        logger.error(
            f"Error running attack execution on {request.url}: {e}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Attack execution failed: {str(e)}",
        ) from e
# =============================================================================
# Endpoints
# =============================================================================


@router.get("/me", response_model=UserInfoResponse)
async def get_current_user_info(
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> UserInfoResponse:
    """
    Get information about the currently authenticated user.

    Requires authentication via API key.
    """
    # ``authenticated`` is left at the model default.
    return UserInfoResponse(
        identifier=user.identifier,
        auth_method=user.auth_method,
        scopes=user.scopes,
    )


@router.post("/generate-key", response_model=ApiKeyGenerateResponse)
async def generate_new_api_key(
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> ApiKeyGenerateResponse:
    """
    Generate a new API key.

    Note: This only generates the key value. You must manually add it to
    the CIBERWEBSCAN_API_AUTH_API_KEYS environment variable for it to work.

    Requires authentication.

    Raises:
        HTTPException: 403 when the caller has neither the ``full_access``
            nor the ``admin`` scope.
    """
    # The docstring previously named CIBERWEBSCAN_API_KEYS, contradicting the
    # variable named in the response message below; it now matches the message.
    # NOTE(review): confirm the exact variable name against the settings loader.
    privileged_scopes = ("full_access", "admin")
    if not any(scope in user.scopes for scope in privileged_scopes):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required to generate API keys",
        )

    new_key = generate_api_key()

    return ApiKeyGenerateResponse(
        api_key=new_key,
        message=(
            "Store this key securely. It cannot be retrieved again. "
            "Add it to CIBERWEBSCAN_API_AUTH_API_KEYS environment variable to activate."
        ),
    )
def _raise_service_failure(
    error: str | None,
    fallback: str,
    *,
    missing_fallback: str | None = None,
) -> None:
    """
    Raise the HTTPException for a service result that reported failure.

    This function always raises; it never returns normally.

    Args:
        error: Error text from the service result (may be None or empty).
        fallback: Detail used for the 400 response when ``error`` is empty.
        missing_fallback: When given, an ``error`` containing "not found"
            (case-insensitive) is reported as 404 instead of 400.

    Raises:
        HTTPException: 404 or 400 as described above.
    """
    if missing_fallback is not None and "not found" in (error or "").lower():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=error or missing_fallback,
        )
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=error or fallback,
    )


def _internal_error(summary: str, exc: Exception) -> HTTPException:
    """Build the 500 response whose detail is ``"<summary>: <exc>"``."""
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"{summary}: {exc}",
    )


def _to_value_response(data: Any) -> ConfigValueResponse:
    """Copy a service value record into the API model; a None record is a 500."""
    if data is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Configuration value data was None",
        )
    return ConfigValueResponse(
        key=data.key,
        value=data.value,
        default=data.default,
        source=data.source,
        description=data.description,
    )


@router.get(
    "/config",
    response_model=APIResponse[dict[str, Any]],
    summary="Get all configuration",
)
async def get_all_config(
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[dict[str, Any]]:
    """
    Retrieve the complete configuration.

    Returns all configuration sections and their values.
    """
    try:
        result = ConfigService().get_all()
        if not result.success:
            _raise_service_failure(result.error, "Failed to get configuration")
        return APIResponse[dict[str, Any]](data=result.data)

    except HTTPException:
        # Fix: this guard was missing here (every other config endpoint has
        # it), so the 400 raised above was re-wrapped as a 500.
        raise
    except Exception as e:
        logger.exception("Error getting configuration: %s", e)
        raise _internal_error("Failed to get configuration", e) from e


@router.get(
    "/config/sections/{section}",
    response_model=APIResponse[dict[str, Any]],
    summary="Get configuration section",
)
async def get_config_section(
    section: str,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[dict[str, Any]]:
    """
    Retrieve a specific configuration section.

    Args:
        section: Section name (e.g., 'scraping', 'analysis', 'api')

    Returns:
        Configuration values for the specified section.
    """
    try:
        result = ConfigService().get_section(section)
        if not result.success:
            _raise_service_failure(
                result.error,
                "Failed to get section",
                missing_fallback=f"Section not found: {section}",
            )
        return APIResponse[dict[str, Any]](data=result.data)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error getting section %s: %s", section, e)
        raise _internal_error("Failed to get section", e) from e


@router.get(
    "/config/value",
    response_model=APIResponse[ConfigValueResponse],
    summary="Get a specific configuration value",
)
async def get_config_value(
    path: str,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[ConfigValueResponse]:
    """
    Retrieve a specific configuration value with metadata.

    Args:
        path: Configuration key in dot-notation (e.g., 'scraping.timeout')

    Returns:
        Value, default, source, and description.
    """
    try:
        result = ConfigService().get(path)
        if not result.success:
            _raise_service_failure(
                result.error,
                "Failed to get value",
                missing_fallback=f"Key not found: {path}",
            )
        return APIResponse[ConfigValueResponse](data=_to_value_response(result.data))

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error getting config value %s: %s", path, e)
        raise _internal_error("Failed to get value", e) from e


@router.put(
    "/config",
    response_model=APIResponse[ConfigValueResponse],
    summary="Update a configuration value",
)
async def update_config(
    request: ConfigUpdateRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[ConfigValueResponse]:
    """
    Update a configuration value.

    Args:
        request: Contains key path, new value, and optional save flag.

    Returns:
        Updated configuration value with metadata.
    """
    try:
        service = ConfigService()
        result = service.set(request.path, request.value)
        if not result.success:
            _raise_service_failure(
                result.error,
                "Failed to update configuration",
                missing_fallback=f"Key not found: {request.path}",
            )

        # Optionally save to disk. Best effort: a failed save is only logged
        # and the in-memory update is still reported as successful.
        if request.save:
            save_result = service.save()
            if not save_result.success:
                logger.warning("Failed to save config: %s", save_result.error)

        return APIResponse[ConfigValueResponse](data=_to_value_response(result.data))

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error updating config %s: %s", request.path, e)
        raise _internal_error("Failed to update configuration", e) from e


@router.post(
    "/config/reset",
    response_model=APIResponse[dict[str, Any]],
    summary="Reset configuration to defaults",
)
async def reset_config(
    request: ConfigResetRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[dict[str, Any]]:
    """
    Reset configuration values to defaults.

    Args:
        request: Contains optional key path and save flag.
                 If path is None, resets all configuration.

    Returns:
        Confirmation of reset operation. ``saved`` is true only when a save
        was requested and it succeeded.
    """
    try:
        service = ConfigService()
        result = service.reset(request.path)
        if not result.success:
            _raise_service_failure(result.error, "Failed to reset configuration")

        # Optionally save to disk. Fix: ``saved`` used to echo request.save,
        # claiming persistence even when the save itself had failed.
        saved = False
        if request.save:
            save_result = service.save()
            saved = bool(save_result.success)
            if not saved:
                logger.warning(
                    "Failed to save config after reset: %s", save_result.error
                )

        return APIResponse[dict[str, Any]](
            data={
                "reset": True,
                "path": request.path or "all",
                "saved": saved,
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error resetting config: %s", e)
        raise _internal_error("Failed to reset configuration", e) from e


@router.get(
    "/config/keys",
    response_model=APIResponse[ConfigKeysResponse],
    summary="List configuration keys",
)
async def list_config_keys(
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    section: str | None = None,
) -> APIResponse[ConfigKeysResponse]:
    """
    List all configuration keys, optionally filtered by section.

    Args:
        section: Optional section name to filter keys.

    Returns:
        List of available configuration keys.
    """
    try:
        result = ConfigService().list_keys(section)
        if not result.success:
            _raise_service_failure(result.error, "Failed to list keys")

        keys = result.data or []
        return APIResponse[ConfigKeysResponse](
            data=ConfigKeysResponse(keys=keys, count=len(keys))
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error listing config keys: %s", e)
        raise _internal_error("Failed to list keys", e) from e


@router.post(
    "/config/export",
    response_model=APIResponse[ConfigFileResponse],
    summary="Export configuration to file",
)
async def export_config(
    request: ConfigExportRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[ConfigFileResponse]:
    """
    Export current configuration to a file.

    Args:
        request: Contains output path and format (yaml or json).

    Returns:
        Information about the exported file.
    """
    # NOTE(review): request.path comes straight from the request body; verify
    # that ConfigService confines it to an allowed directory.
    try:
        if request.format not in ("yaml", "json"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Format must be 'yaml' or 'json'",
            )

        result = ConfigService().export_config(request.path, format=request.format)
        if not result.success:
            _raise_service_failure(result.error, "Failed to export configuration")

        response_data = ConfigFileResponse(
            file_path=str(result.data),
            operation="export",
            format=request.format,
        )

        logger.info("Configuration exported to %s", result.data)
        return APIResponse[ConfigFileResponse](data=response_data)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error exporting config to %s: %s", request.path, e)
        raise _internal_error("Failed to export configuration", e) from e


@router.post(
    "/config/load",
    response_model=APIResponse[dict[str, Any]],
    summary="Load configuration from file",
)
async def load_config(
    request: ConfigLoadRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[dict[str, Any]]:
    """
    Load configuration from a file.

    Args:
        request: Contains input file path.

    Returns:
        Loaded configuration.
    """
    # NOTE(review): request.path is caller-supplied; verify that ConfigService
    # restricts which files may be read.
    try:
        result = ConfigService().load(request.path)
        if not result.success:
            _raise_service_failure(
                result.error,
                "Failed to load configuration",
                missing_fallback=f"File not found: {request.path}",
            )

        logger.info("Configuration loaded from %s", request.path)
        return APIResponse[dict[str, Any]](data=result.data)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error loading config from %s: %s", request.path, e)
        raise _internal_error("Failed to load configuration", e) from e


@router.post(
    "/config/save",
    response_model=APIResponse[ConfigFileResponse],
    summary="Save configuration to file",
)
async def save_config(
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    request: ConfigSaveRequest | None = None,
) -> APIResponse[ConfigFileResponse]:
    """
    Save current configuration to file.

    Args:
        request: Optional path to save to (uses default if not provided).

    Returns:
        Information about the saved file.
    """
    try:
        # No body means "let the service pick its own path" (None is passed).
        save_path = request.path if request else None
        result = ConfigService().save(save_path)
        if not result.success:
            _raise_service_failure(result.error, "Failed to save configuration")

        response_data = ConfigFileResponse(
            file_path=str(result.data),
            operation="save",
        )

        logger.info("Configuration saved to %s", result.data)
        return APIResponse[ConfigFileResponse](data=response_data)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error saving config: %s", e)
        raise _internal_error("Failed to save configuration", e) from e
# Service instance
_download_service = DownloadService()


def _get_download_service() -> DownloadService:
    """Get download service instance."""
    return _download_service


def _status_for_validation_error(error_msg: str) -> int:
    """
    Map a download-validation error message to an HTTP status code.

    Matching is by case-insensitive substring, checked in this order:
    "expired" -> 410, "unauthorized" -> 401, "attempts" -> 429,
    "not found" -> 404, anything else -> 400.
    """
    lowered = error_msg.lower()
    if "expired" in lowered:
        return status.HTTP_410_GONE
    if "unauthorized" in lowered:
        return status.HTTP_401_UNAUTHORIZED
    if "attempts" in lowered:
        return status.HTTP_429_TOO_MANY_REQUESTS
    if "not found" in lowered:
        # Fix: the route documents a 404 for unknown tokens, but such errors
        # used to fall through to the generic 400.
        return status.HTTP_404_NOT_FOUND
    return status.HTTP_400_BAD_REQUEST


@router.get(
    "/{token}",
    summary="Download exported results",
    description="Download previously exported analysis/attack/scrape results using a time-limited token.",
    # Fix: this table now lists the statuses the handler actually returns.
    # 429 (exhausted attempts) was previously described under 401, and 500
    # was missing.
    responses={
        200: {"description": "File downloaded successfully"},
        400: {"description": "Invalid token"},
        401: {"description": "Unauthorized - token belongs to a different user"},
        404: {"description": "Token not found"},
        410: {"description": "Token expired"},
        429: {"description": "Maximum download attempts exceeded"},
        500: {"description": "File could not be retrieved"},
        503: {"description": "Download service unavailable"},
    },
)
async def download_file(
    token: str,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    download_service: Annotated[DownloadService, Depends(_get_download_service)],
) -> StreamingResponse:
    """
    Download file using a download token.

    The token is obtained from POST endpoints (analyze, attack, scrape).
    Tokens expire after a configured time period and have a limited number of retry attempts.

    Args:
        token: Download token from response
        user: Authenticated user from API key
        download_service: Download service instance (injected via Depends)

    Returns:
        StreamingResponse with file data

    Raises:
        HTTPException: If downloads are disabled (503), the token fails
            validation (400/401/404/410/429), or the file stream cannot be
            obtained (500)
    """
    config = get_config()

    # Check if download is enabled
    if not config.download.enabled:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Download service is disabled",
        )

    # Validate token
    validation_result = download_service.validate_download_request(
        token=token,
        user_id=user.identifier,
    )

    if not validation_result.success:
        error_msg = validation_result.error or "Invalid token"
        logger.warning(f"Download validation failed for {user.identifier}: {error_msg}")
        raise HTTPException(
            status_code=_status_for_validation_error(error_msg),
            detail=error_msg,
        )

    # Get file stream
    stream_result = download_service.get_file_stream(token=token)
    if not stream_result.success:
        logger.error(f"Failed to get file stream: {stream_result.error}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve file",
        )

    # A "successful" result without data is still unusable.
    if stream_result.data is None:
        logger.error("Stream result data is None despite success status")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve file stream",
        )

    # Capture stream data to ensure it's not None for type checker
    file_chunks = stream_result.data

    # Prepare headers
    # NOTE(review): the filename is always export.json and the media type is
    # octet-stream, whatever format was exported; confirm whether the token
    # metadata should drive these.
    headers = {
        "Content-Disposition": "attachment; filename=export.json",
        "X-Content-Type-Options": "nosniff",
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
    }

    logger.info(f"Streaming download for token {token} to user {user.identifier}")

    # Wrapper generator that deletes the token once streaming ends
    def stream_and_cleanup():
        """Stream file data, then delete the token."""
        try:
            yield from file_chunks
        finally:
            # ``finally`` also runs when streaming stops early (an error or a
            # closed generator), so the token is removed in those cases too.
            cleanup_result = download_service.delete_token(token)
            if cleanup_result.success:
                logger.info(f"Token cleaned up after download: {token}")
            else:
                logger.warning(f"Failed to cleanup token: {token}")

    return StreamingResponse(
        stream_and_cleanup(),
        media_type="application/octet-stream",
        headers=headers,
    )
@router.post("/scrape", response_model=APIResponse[ScrapeResult | list[dict[str, Any]]])
async def scrape_url(
    request: ScrapeRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[ScrapeResult | list[dict[str, Any]]]:
    """
    Scrape a single URL and return structured data.
    Supports both static (BeautifulSoup) and dynamic (Playwright) scraping.

    Raises:
        HTTPException: 400 on a pydantic validation error; 500 when the
            service reports failure, returns no data, or any other
            exception escapes.
    """
    try:
        # Convert request to service options
        options = ScrapeOptions(
            url=str(request.url),
            dynamic=request.dynamic,
            wait_for=request.wait_for,
            timeout=request.timeout,
            selector=request.selector,
            attributes=request.attributes,
            schema=request.extraction_schema,
            pagination_selector=request.pagination_selector,
            pagination_limit=request.pagination_limit,
            export=request.export,
            export_format=request.export_format,
            headers=request.headers,
            cookies=request.cookies,
            proxy=request.proxy,
            user_agent=request.user_agent,
            check_robots=request.check_robots,
        )

        # Execute scraping
        service = ScrapeService()
        result = service.scrape(options)

        if not result.success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=result.error or "Scraping failed",
            )

        if result.data is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Scraping returned no data",
            )

        # Enrich response with download token
        download_service = DownloadService()
        data, download_token = enrich_response_with_token(
            result, user.identifier, download_service
        )
        download_url = f"/api/v1/download/{download_token}" if download_token else None

        return APIResponse[ScrapeResult | list[dict[str, Any]]](
            success=True,
            data=data,
            download_token=download_token,
            download_url=download_url,
        )

    except HTTPException:
        # Raised deliberately above; without this guard the generic handler
        # re-wraps it and the client sees a mangled ``detail``.
        raise
    except ValidationError as e:
        logger.warning("Validation error in scrape request: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid request: {e}",
        ) from e
    except Exception as e:
        logger.exception("Error scraping URL %s: %s", request.url, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Scraping failed: {str(e)}",
        ) from e


@router.post("/scrape/batch", response_model=APIResponse[ScrapeBatchResultResponse])
async def scrape_batch(
    request: ScrapeBatchRequest,
    user: Annotated[AuthenticatedUser, Depends(get_current_user)],
) -> APIResponse[ScrapeBatchResultResponse]:
    """
    Scrape multiple URLs in batch.

    Every requested URL missing from the service's results is reported in
    ``failed_urls`` with the generic error "Scrape failed".

    Raises:
        HTTPException: 400 when no URL is supplied or on a pydantic
            validation error; 500 when the service reports failure or any
            other exception escapes.
    """
    try:
        urls = [str(url) for url in request.urls]
        if not urls:
            # Fix: ``urls[0]`` below raised IndexError on an empty list, which
            # surfaced as an opaque 500.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="At least one URL is required",
            )

        # The first URL only seeds the options object; the full list is
        # passed to scrape_multiple separately.
        options = ScrapeOptions(
            url=urls[0],
            dynamic=request.dynamic,
            timeout=request.timeout,
            selector=request.selector,
            export=request.export,
            export_format=request.export_format,
            headers=request.headers,
            cookies=request.cookies,
            proxy=request.proxy,
            user_agent=request.user_agent,
        )

        service = ScrapeService()
        result = service.scrape_multiple(urls, options)

        if not result.success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=result.error or "Batch scraping failed",
            )

        successful_results = result.data or []
        successful_urls = {item.url for item in successful_results}
        failed_urls = [
            {"url": url, "error": "Scrape failed"}
            for url in urls
            if url not in successful_urls
        ]

        job_id = str(uuid.uuid4())
        logger.info(
            "Batch scrape completed: %s (%d success, %d failed)",
            job_id,
            len(successful_results),
            len(failed_urls),
        )

        # Create batch result data
        batch_data = ScrapeBatchResultResponse(
            job_id=job_id,
            results=successful_results,
            failed_urls=failed_urls,
            total_success=len(successful_results),
            total_failed=len(failed_urls),
            elapsed_seconds=result.duration_seconds,
        )

        # Enrich response with download token if exported
        download_service = DownloadService()
        download_token = None
        download_url = None

        if result.export_path:
            token_result = download_service.generate_download_token(
                file_path=result.export_path,
                user_id=user.identifier,
                file_format=request.export_format,
            )
            # Token generation is best effort: on failure the response simply
            # carries no token or URL.
            if token_result.success and token_result.data:
                download_token = token_result.data.token
                download_url = token_result.data.download_url

        return APIResponse[ScrapeBatchResultResponse](
            success=True,
            data=batch_data,
            download_token=download_token,
            download_url=download_url,
        )

    except HTTPException:
        # Pass through deliberately raised HTTP errors (including the 400 above).
        raise
    except ValidationError as e:
        logger.warning("Validation error in batch scrape request: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid request: {e}",
        ) from e
    except Exception as e:
        logger.exception("Error in batch scraping: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch scraping failed: {str(e)}",
        ) from e
+""" + +from __future__ import annotations + +from typing import Annotated + +import typer + +from ciberwebscan.cli.output import print_error, print_info, print_success +from ciberwebscan.config.loader import get_config + +api = typer.Typer( + name="api", + help="API server management commands.", + no_args_is_help=True, +) + + +@api.command("run") +def run_api( + host: Annotated[ + str | None, + typer.Option( + "--host", + help="Host to bind the API server to", + ), + ] = None, + port: Annotated[ + int | None, + typer.Option( + "--port", + help="Port to bind the API server to", + ), + ] = None, + reload: Annotated[ + bool, + typer.Option( + "--reload", + help="Enable auto-reload on code changes", + ), + ] = False, +) -> None: + """Start the CiberWebScan API server. + Examples: + # Run with default config settings + ciberwebscan api run + # Run on custom host and port + ciberwebscan api run --host 0.0.0.0 --port 9000 + # Disable auto-reload for production + ciberwebscan api run --no-reload + """ + try: + # Import here to avoid issues if uvicorn is not installed for CLI-only installs + import uvicorn + + app_config = get_config() + api_config = app_config.api + + # Use provided values or fall back to config + server_host = host or api_config.host + server_port = port or api_config.port + log_level = app_config.logging.level.lower() + + print_success( + f"Starting CiberWebScan API on http://{server_host}:{server_port}" + ) + print_info( + f"Documentation available at http://{server_host}:{server_port}/docs" + ) + + uvicorn.run( + "ciberwebscan.api.app:app", + host=server_host, + port=server_port, + reload=reload, + log_level=log_level, + ) + except ImportError: + print_error( + 'uvicorn is not installed. 
Install it with: pip install -e ".[api]"' + ) + raise typer.Exit(code=1) from None + except KeyboardInterrupt: + print_info("API server stopped by user") + except Exception as e: + print_error(f"Failed to start API server: {e}") + raise typer.Exit(code=1) from None diff --git a/src/ciberwebscan/config/models.py b/src/ciberwebscan/config/models.py index 0c02157..b5d1eba 100644 --- a/src/ciberwebscan/config/models.py +++ b/src/ciberwebscan/config/models.py @@ -283,6 +283,74 @@ class ExportConfig(BaseModel): pretty: bool = True +# ============================================================================= +# Download Configuration +# ============================================================================= + + +class DownloadConfig(BaseModel): + """Download/streaming result settings.""" + + enabled: bool = True + retention_seconds: Annotated[int, Field(ge=60, le=86400)] = 1800 + max_file_size_mb: Annotated[int, Field(ge=1, le=10240)] = 500 + max_retries: Annotated[int, Field(ge=1, le=10)] = 3 + cleanup_interval_seconds: Annotated[int, Field(ge=60, le=3600)] = 300 + require_same_user: bool = True + stream_chunk_size: Annotated[int, Field(ge=1024, le=10 * 1024 * 1024)] = 1024 * 1024 + + +# ============================================================================= +# API Configuration +# ============================================================================= + + +class APIAuthConfig(BaseModel): + """API authentication settings.""" + + # API Key settings + api_keys: list[str] = Field( + default=[], + description="List of valid API keys (can be comma-separated string)", + ) + + @field_validator("api_keys", mode="before") + @classmethod + def parse_api_keys(cls, v: str | list[str]) -> list[str]: + """Parse API keys from comma-separated string or list.""" + if isinstance(v, str): + return [k.strip() for k in v.split(",") if k.strip()] + return [k for k in v if k] + + +class APIRateLimitConfig(BaseModel): + """API rate limiting settings.""" + + enabled: bool = 
True + requests_per_minute: Annotated[int, Field(ge=1, le=10000)] = 60 + + +class APIConfig(BaseModel): + """API server configuration.""" + + host: str = "0.0.0.0" + port: Annotated[int, Field(ge=1, le=65535)] = 8000 + auth: APIAuthConfig = Field(default_factory=lambda: APIAuthConfig()) + rate_limit: APIRateLimitConfig = Field(default_factory=lambda: APIRateLimitConfig()) + cors_origins: list[str] = Field( + default=["*"], + description="Allowed CORS origins", + ) + + @field_validator("cors_origins", mode="before") + @classmethod + def parse_cors_origins(cls, v: str | list[str]) -> list[str]: + """Parse CORS origins from comma-separated string or list.""" + if isinstance(v, str): + return [o.strip() for o in v.split(",") if o.strip()] + return v + + # ============================================================================= # Logging Configuration # ============================================================================= @@ -344,6 +412,16 @@ class AppConfig(BaseModel): export: format: jsonl streaming: true + api: + host: 0.0.0.0 + port: 8000 + auth: + api_keys: "key1,key2,key3" + logging: + level: INFO + file: app.log + cache: + enabled: true ``` """ @@ -353,6 +431,8 @@ class AppConfig(BaseModel): analysis: AnalysisConfig = Field(default_factory=lambda: AnalysisConfig()) attack: AttackConfig = Field(default_factory=lambda: AttackConfig()) export: ExportConfig = Field(default_factory=lambda: ExportConfig()) + download: DownloadConfig = Field(default_factory=lambda: DownloadConfig()) + api: APIConfig = Field(default_factory=lambda: APIConfig()) logging: LoggingConfig = Field(default_factory=lambda: LoggingConfig()) cache: CacheConfig = Field(default_factory=lambda: CacheConfig()) diff --git a/src/ciberwebscan/services/analyze_service.py b/src/ciberwebscan/services/analyze_service.py index d1edb0c..1a3b84f 100644 --- a/src/ciberwebscan/services/analyze_service.py +++ b/src/ciberwebscan/services/analyze_service.py @@ -67,7 +67,7 @@ class AnalyzeOptions: 
deep_scan: bool = False # CVE options - cve_sources: list[str] = field(default_factory=lambda: ["nvd"]) + cve_sources: Sequence[str] = field(default_factory=lambda: ["nvd"]) cve_limit: int = 100 cve_severity: str | None = None # Filter by severity diff --git a/src/ciberwebscan/services/cleanup_scheduler.py b/src/ciberwebscan/services/cleanup_scheduler.py new file mode 100644 index 0000000..bd8265e --- /dev/null +++ b/src/ciberwebscan/services/cleanup_scheduler.py @@ -0,0 +1,91 @@ +""" +Background job scheduler for download token cleanup. + +Handles periodic cleanup of expired download tokens and associated file data. +Runs as a background task in the API lifetime events. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from ciberwebscan.config.loader import get_config +from ciberwebscan.services.download_service import DownloadService + +logger = logging.getLogger(__name__) + + +class DownloadCleanupScheduler: + """Scheduler for cleaning up expired download tokens.""" + + def __init__(self) -> None: + """Initialize the scheduler.""" + self._task: asyncio.Task[Any] | None = None + self._running = False + self._service = DownloadService() + + def start(self) -> None: + """Start the cleanup scheduler.""" + if self._running: + logger.warning("Cleanup scheduler already running") + return + + self._running = True + self._task = asyncio.create_task(self._run_loop()) + logger.info("Download cleanup scheduler started") + + def stop(self) -> None: + """Stop the cleanup scheduler.""" + if not self._running: + return + + self._running = False + if self._task: + self._task.cancel() + logger.info("Download cleanup scheduler stopped") + + async def _run_loop(self) -> None: + """Main scheduler loop - runs cleanup at configured intervals.""" + config = get_config() + interval = config.download.cleanup_interval_seconds + + while self._running: + try: + await asyncio.sleep(interval) + + if not self._running: + break + + # Run 
cleanup + result = self._service.cleanup_expired_tokens() + + if result.success: + count = result.data or 0 + if count > 0: + logger.info( + f"Cleanup job: removed {count} expired download tokens" + ) + else: + logger.error(f"Cleanup job failed: {result.error}") + + except asyncio.CancelledError: + logger.debug("Cleanup scheduler task cancelled") + break + except Exception as e: + logger.error(f"Error in cleanup scheduler loop: {e}", exc_info=True) + # Continue running despite errors + continue + + +# Global scheduler instance +_scheduler: DownloadCleanupScheduler | None = None + + +def get_scheduler() -> DownloadCleanupScheduler: + """Get or create the global scheduler instance.""" + global _scheduler + if _scheduler is None: + _scheduler = DownloadCleanupScheduler() + return _scheduler diff --git a/src/ciberwebscan/services/download_service.py b/src/ciberwebscan/services/download_service.py new file mode 100644 index 0000000..e48f9ba --- /dev/null +++ b/src/ciberwebscan/services/download_service.py @@ -0,0 +1,334 @@ +""" +Download service for managing file downloads and streaming. + +Handles token generation, validation, expiration, and cleanup +of download tokens. Uses in-memory storage with asyncio locks. 
+""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import uuid +from collections.abc import Iterator +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from ciberwebscan.api.models.responses import DownloadInfo, DownloadTokenResponse +from ciberwebscan.config.loader import get_config +from ciberwebscan.services.base import BaseService, ServiceResult + + +def _run_async(coro): + """ + Run async coroutine safely, handling both async and sync contexts + In async context (FastAPI): uses ThreadPoolExecutor to avoid event loop issues + In sync context: uses asyncio.run() + """ + try: + asyncio.get_running_loop() + # We're in an async context, use thread executor + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(lambda: asyncio.run(coro)) + return future.result(timeout=30) + except RuntimeError: + # No event loop running, safe to use asyncio.run() + return asyncio.run(coro) + + +class _DownloadRegistry: + """In-memory registry for download tokens and file data.""" + + def __init__(self) -> None: + self._tokens: dict[str, DownloadInfo] = {} + self._file_data: dict[str, bytes] = {} + self._lock = asyncio.Lock() + + async def store( + self, + token: str, + info: DownloadInfo, + file_data: bytes, + ) -> None: + """Store token with metadata and file data.""" + async with self._lock: + self._tokens[token] = info + self._file_data[token] = file_data + + async def get_info(self, token: str) -> DownloadInfo | None: + """Retrieve token metadata.""" + async with self._lock: + return self._tokens.get(token) + + async def get_file_data(self, token: str) -> bytes | None: + """Retrieve file data.""" + async with self._lock: + return self._file_data.get(token) + + async def update_attempts(self, token: str) -> bool: + """ + Decrement remaining attempts and return True if still valid. + Returns False if max attempts exceeded. 
+ """ + async with self._lock: + if token not in self._tokens: + return False + info = self._tokens[token] + if info.attempts_remaining <= 0: + return False + info.attempts_remaining -= 1 + self._tokens[token] = info + return True + + async def delete(self, token: str) -> bool: + """Delete token and associated file data.""" + async with self._lock: + if token not in self._tokens: + return False + del self._tokens[token] + self._file_data.pop(token, None) + return True + + async def cleanup_expired(self) -> int: + """Delete all expired tokens. Returns count of deleted tokens.""" + async with self._lock: + now = datetime.now(timezone.utc) + expired_tokens = [ + token for token, info in self._tokens.items() if info.expires_at <= now + ] + for token in expired_tokens: + del self._tokens[token] + self._file_data.pop(token, None) + return len(expired_tokens) + + async def get_expired_count(self) -> int: + """Get count of expired tokens without deleting them.""" + async with self._lock: + now = datetime.now(timezone.utc) + return sum(1 for info in self._tokens.values() if info.expires_at <= now) + + +# Global registry instance +_registry = _DownloadRegistry() + + +class DownloadService(BaseService): + """Service for managing file downloads and streaming.""" + + def generate_download_token( + self, + file_path: Path | str, + user_id: str, + file_format: str = "json", + ) -> ServiceResult[DownloadTokenResponse]: + """ + Generate a download token for a file. 
+ + Args: + file_path: Path to file to download + user_id: ID of user requesting download + file_format: Format of the exported file (json/jsonl/csv) + + Returns: + ServiceResult with DownloadTokenResponse containing token and URL + """ + try: + config = get_config() + file_path = Path(file_path) + + # Validate file exists + if not file_path.exists(): + return ServiceResult( + success=False, + error=f"File not found: {file_path}", + ) + + # Read file data + file_data = file_path.read_bytes() + file_size_mb = len(file_data) / (1024 * 1024) + + # Validate size + if file_size_mb > config.download.max_file_size_mb: + return ServiceResult( + success=False, + error=f"File size {file_size_mb:.2f}MB exceeds limit of {config.download.max_file_size_mb}MB", + ) + + # Generate token + token = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + expires_at = now + timedelta(seconds=config.download.retention_seconds) + + # Create metadata + info = DownloadInfo( + token=token, + user_id=user_id, + file_size_bytes=len(file_data), + created_at=now, + expires_at=expires_at, + attempts_remaining=config.download.max_retries, + file_format=file_format, + ) + + # Store token (async operation) + _run_async(_registry.store(token, info, file_data)) + + download_url = f"/api/v1/download/{token}" + response = DownloadTokenResponse( + token=token, + expires_at=expires_at, + download_url=download_url, + ) + + self.logger.info( + f"Generated download token {token} for user {user_id} " + f"(file size: {file_size_mb:.2f}MB, expires: {expires_at.isoformat()})" + ) + + return ServiceResult(success=True, data=response) + + except Exception as e: + self.logger.error(f"Error generating download token: {e}") + return ServiceResult(success=False, error=str(e)) + + def validate_download_request( + self, + token: str, + user_id: str, + ) -> ServiceResult[bool]: + """ + Validate a download request token. 
+ + Args: + token: Download token to validate + user_id: ID of user requesting download + + Returns: + ServiceResult with True if valid, False otherwise + """ + try: + config = get_config() + + # Get token info + info = _run_async(_registry.get_info(token)) + + if info is None: + return ServiceResult(success=False, error="Token not found") + + # Check expiration + now = datetime.now(timezone.utc) + if info.expires_at <= now: + self.logger.warning(f"Expired token accessed: {token}") + return ServiceResult(success=False, error="Token expired") + + # Check user match if required + if config.download.require_same_user and info.user_id != user_id: + self.logger.warning( + f"Unauthorized download attempt: token owner {info.user_id}, " + f"requester {user_id}" + ) + return ServiceResult( + success=False, + error="Unauthorized: token belongs to different user", + ) + + # Check attempts remaining + if info.attempts_remaining <= 0: + self.logger.warning(f"Max retries exceeded for token: {token}") + return ServiceResult( + success=False, + error="Maximum download attempts exceeded", + ) + + # Decrement attempts + still_valid = _run_async(_registry.update_attempts(token)) + if not still_valid: + return ServiceResult( + success=False, + error="Download attempts exhausted", + ) + + remaining = info.attempts_remaining - 1 + self.logger.info( + f"Valid download request for token {token} " + f"({remaining} attempts remaining)" + ) + + return ServiceResult(success=True, data=True) + + except Exception as e: + self.logger.error(f"Error validating download request: {e}") + return ServiceResult(success=False, error=str(e)) + + def get_file_stream( + self, + token: str, + ) -> ServiceResult[Iterator[bytes]]: + """ + Get file data as a streaming iterator. 
+ + Args: + token: Download token + + Returns: + ServiceResult with Iterator yielding file chunks + """ + try: + config = get_config() + file_data = _run_async(_registry.get_file_data(token)) + + if file_data is None: + return ServiceResult( + success=False, + error="File data not found", + ) + + def chunk_iterator() -> Iterator[bytes]: + """Yield file data in chunks.""" + chunk_size = config.download.stream_chunk_size + for i in range(0, len(file_data), chunk_size): + yield file_data[i : i + chunk_size] + + return ServiceResult(success=True, data=chunk_iterator()) + + except Exception as e: + self.logger.error(f"Error getting file stream: {e}") + return ServiceResult(success=False, error=str(e)) + + def cleanup_expired_tokens(self) -> ServiceResult[int]: + """ + Delete all expired tokens. Typically called by background job. + + Returns: + ServiceResult with count of tokens cleaned up + """ + try: + count = _run_async(_registry.cleanup_expired()) + if count > 0: + self.logger.info(f"Cleanup job: deleted {count} expired tokens") + return ServiceResult(success=True, data=count) + except Exception as e: + self.logger.error(f"Error during token cleanup: {e}") + return ServiceResult(success=False, error=str(e), data=0) + + def delete_token(self, token: str) -> ServiceResult[bool]: + """ + Delete a token and its associated file data after download completes. 
+ + Args: + token: Download token to delete + + Returns: + ServiceResult with success status + """ + try: + result = _run_async(_registry.delete(token)) + if result: + self.logger.info(f"Token deleted after successful download: {token}") + return ServiceResult(success=True, data=True) + else: + self.logger.error(f"Token not found for deletion: {token}") + return ServiceResult(success=False, error="Token not found") + except Exception as e: + self.logger.error(f"Error deleting token {token}: {e}") + return ServiceResult(success=False, error=str(e)) diff --git a/tests/integration/api/__init__.py b/tests/integration/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/api/test_cleanup_scheduler.py b/tests/integration/api/test_cleanup_scheduler.py new file mode 100644 index 0000000..6c1d948 --- /dev/null +++ b/tests/integration/api/test_cleanup_scheduler.py @@ -0,0 +1,78 @@ +""" +Integration tests for download cleanup scheduler. +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest + +from ciberwebscan.services.cleanup_scheduler import ( + DownloadCleanupScheduler, + get_scheduler, +) +from ciberwebscan.services.download_service import DownloadService + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def scheduler() -> DownloadCleanupScheduler: + """Create a scheduler instance.""" + return DownloadCleanupScheduler() + + +@pytest.fixture +def test_file() -> Path: + """Create a temporary test file.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + f.write('{"test": "data"}') + return Path(f.name) + + +# ============================================================================= +# Tests +# ============================================================================= + + +class 
TestDownloadCleanupScheduler: + """Tests for cleanup scheduler.""" + + @pytest.mark.asyncio + async def test_scheduler_starts_and_stops( + self, scheduler: DownloadCleanupScheduler + ): + """Scheduler starts and stops correctly.""" + scheduler.start() + assert scheduler._running is True + assert scheduler._task is not None + + scheduler.stop() + assert scheduler._running is False + + def test_scheduler_singleton_pattern(self): + """Scheduler follows singleton pattern via get_scheduler.""" + sched1 = get_scheduler() + sched2 = get_scheduler() + + assert sched1 is sched2, "Should return same instance" + + def test_cleanup_service_integration(self, test_file: Path): + """Scheduler integrates with DownloadService.""" + service = DownloadService() + + # Generate token + result = service.generate_download_token( + file_path=test_file, user_id="test_user", file_format="json" + ) + + assert result.success + assert result.data.token is not None + + # Cleanup should work + cleanup_result = service.cleanup_expired_tokens() + assert cleanup_result.success diff --git a/tests/integration/api/test_download_endpoint.py b/tests/integration/api/test_download_endpoint.py new file mode 100644 index 0000000..e68b0bc --- /dev/null +++ b/tests/integration/api/test_download_endpoint.py @@ -0,0 +1,83 @@ +""" +Integration tests for download endpoint. 
+""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app +from ciberwebscan.services.download_service import DownloadService + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the API.""" + app = create_app() + return TestClient(app) + + +@pytest.fixture +def test_file() -> Path: + """Create a temporary test file for downloading.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + f.write('{"test": "data", "result": "sample"}') + return Path(f.name) + + +# ============================================================================= +# Tests +# ============================================================================= + + +class TestDownloadEndpoint: + """Integration tests for GET /api/download/{token} endpoint.""" + + def test_download_endpoint_registered(self, client: TestClient): + """Verify endpoint is registered (returns 401 for auth, not 404 for route).""" + # The endpoint returns 401 if not authenticated, not 404 if route doesn't exist + response = client.get("/api/download/test-token") + # Should NOT be 404 (route not found) - should be 401 (auth required) + assert response.status_code != 404, "Endpoint not registered" + + def test_download_requires_auth(self, client: TestClient, test_file: Path): + """Endpoint requires authentication.""" + service = DownloadService() + result = service.generate_download_token( + file_path=test_file, user_id="test_user", file_format="json" + ) + token = result.data.token + + # Try without auth - should be rejected + response = client.get(f"/api/download/{token}") + assert response.status_code in [401, 403], f"Got {response.status_code}" + + def 
test_download_endpoint_with_api_key(self, client: TestClient, test_file: Path): + """Endpoint can be called with API key auth.""" + from ciberwebscan.config.loader import get_config + + config = get_config() + if not config.api.auth.api_keys: + pytest.skip("No API keys configured") + + service = DownloadService() + result = service.generate_download_token( + file_path=test_file, user_id="test_user", file_format="json" + ) + token = result.data.token + api_key = config.api.auth.api_keys[0] + + # Try with valid API key - should not be auth error + response = client.get(f"/api/download/{token}", headers={"X-API-Key": api_key}) + # Should be 200 (success) or 400/404 (token error), not 401 (auth error) + assert response.status_code != 401, ( + f"Auth should work with valid API key, got {response.status_code}" + ) diff --git a/tests/integration/api/test_download_endpoint_clean.py b/tests/integration/api/test_download_endpoint_clean.py new file mode 100644 index 0000000..e68b0bc --- /dev/null +++ b/tests/integration/api/test_download_endpoint_clean.py @@ -0,0 +1,83 @@ +""" +Integration tests for download endpoint. 
+""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app +from ciberwebscan.services.download_service import DownloadService + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the API.""" + app = create_app() + return TestClient(app) + + +@pytest.fixture +def test_file() -> Path: + """Create a temporary test file for downloading.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + f.write('{"test": "data", "result": "sample"}') + return Path(f.name) + + +# ============================================================================= +# Tests +# ============================================================================= + + +class TestDownloadEndpoint: + """Integration tests for GET /api/download/{token} endpoint.""" + + def test_download_endpoint_registered(self, client: TestClient): + """Verify endpoint is registered (returns 401 for auth, not 404 for route).""" + # The endpoint returns 401 if not authenticated, not 404 if route doesn't exist + response = client.get("/api/download/test-token") + # Should NOT be 404 (route not found) - should be 401 (auth required) + assert response.status_code != 404, "Endpoint not registered" + + def test_download_requires_auth(self, client: TestClient, test_file: Path): + """Endpoint requires authentication.""" + service = DownloadService() + result = service.generate_download_token( + file_path=test_file, user_id="test_user", file_format="json" + ) + token = result.data.token + + # Try without auth - should be rejected + response = client.get(f"/api/download/{token}") + assert response.status_code in [401, 403], f"Got {response.status_code}" + + def 
test_download_endpoint_with_api_key(self, client: TestClient, test_file: Path): + """Endpoint can be called with API key auth.""" + from ciberwebscan.config.loader import get_config + + config = get_config() + if not config.api.auth.api_keys: + pytest.skip("No API keys configured") + + service = DownloadService() + result = service.generate_download_token( + file_path=test_file, user_id="test_user", file_format="json" + ) + token = result.data.token + api_key = config.api.auth.api_keys[0] + + # Try with valid API key - should not be auth error + response = client.get(f"/api/download/{token}", headers={"X-API-Key": api_key}) + # Should be 200 (success) or 400/404 (token error), not 401 (auth error) + assert response.status_code != 401, ( + f"Auth should work with valid API key, got {response.status_code}" + ) diff --git a/tests/integration/cli/test_attack_cli.py b/tests/integration/cli/test_attack_cli.py index 5e82e72..94cee02 100644 --- a/tests/integration/cli/test_attack_cli.py +++ b/tests/integration/cli/test_attack_cli.py @@ -16,6 +16,15 @@ import httpx import pytest +# the integration suite relies on the small FastAPI test server and +# uvicorn to run it. python-multipart is also needed by the server for +# form handling. we do not want these packages to be pulled in for +# users who only install the CLI, so skip the entire module if they are +# absent. 
+pytest.importorskip("fastapi") +pytest.importorskip("uvicorn") +pytest.importorskip("python_multipart") + # Test server URL TEST_SERVER_URL = "http://127.0.0.1:5555" diff --git a/tests/testserver.py b/tests/testserver.py index 0419b51..3f3e38e 100644 --- a/tests/testserver.py +++ b/tests/testserver.py @@ -9,9 +9,19 @@ import logging +import pytest from fastapi import FastAPI, Form, Query from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse +# when running under a bare CLI installation the API packages may not be +# available; tests that depend on this server should skip themselves +# instead of crashing during import. pytest.importorskip will raise a +# Skip exception but evaluating it here would skip import of the whole +# module which is fine because nothing else uses it without checking. +pytest.importorskip("fastapi") +pytest.importorskip("python_multipart") + + # Disable uvicorn logging in tests log = logging.getLogger("uvicorn") log.setLevel(logging.ERROR) diff --git a/tests/unit/api/__init__.py b/tests/unit/api/__init__.py new file mode 100644 index 0000000..455d0c6 --- /dev/null +++ b/tests/unit/api/__init__.py @@ -0,0 +1 @@ +"""Unit tests for API module.""" diff --git a/tests/unit/api/routes/__init__.py b/tests/unit/api/routes/__init__.py new file mode 100644 index 0000000..abe297f --- /dev/null +++ b/tests/unit/api/routes/__init__.py @@ -0,0 +1 @@ +"""Unit tests for API routes.""" diff --git a/tests/unit/api/routes/test_analyze_routes.py b/tests/unit/api/routes/test_analyze_routes.py new file mode 100644 index 0000000..0b958cc --- /dev/null +++ b/tests/unit/api/routes/test_analyze_routes.py @@ -0,0 +1,128 @@ +""" +Unit tests for the analysis API endpoint (POST /api/analyze). 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app +from ciberwebscan.export.models import AnalysisReport, ExportMeta + + +@pytest.fixture +def client(): + """Create a test client with auth dependency overridden.""" + from ciberwebscan.api.auth import AuthenticatedUser, get_current_user + + app = create_app() + + def mock_get_current_user() -> AuthenticatedUser: + return AuthenticatedUser( + identifier="test-user", + auth_method="api_key", + scopes=["read", "write"], + ) + + app.dependency_overrides[get_current_user] = mock_get_current_user + yield TestClient(app) + app.dependency_overrides.clear() + + +@pytest.fixture +def mock_analysis_report() -> AnalysisReport: + """Minimal valid AnalysisReport for serialization.""" + return AnalysisReport(meta=ExportMeta(target_url="https://example.com")) + + +def _make_service_result(data=None, success: bool = True, error: str | None = None): + result = MagicMock() + result.success = success + result.data = data + result.error = error + return result + + +class TestAnalyzeEndpoint: + """Tests for POST /api/analyze.""" + + def test_post_analyze_returns_200(self, client: TestClient, mock_analysis_report): + with patch( + "ciberwebscan.api.routes.analyze.AnalyzeService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.analyze.return_value = _make_service_result( + data=mock_analysis_report + ) + + response = client.post("/api/analyze", json={"url": "https://example.com"}) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["meta"]["target_url"] == "https://example.com" + + def test_post_analyze_forwards_options( + self, client: TestClient, mock_analysis_report + ): + with patch( + "ciberwebscan.api.routes.analyze.AnalyzeService" + ) as 
mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.analyze.return_value = _make_service_result( + data=mock_analysis_report + ) + + response = client.post( + "/api/analyze", + json={ + "url": "https://example.com", + "ssl": False, + "cve_sources": ["nvd", "circl"], + "headers": {"X-Test": "1"}, + }, + ) + + assert response.status_code == 200 + options = mock_service.analyze.call_args[0][0] + assert options.ssl is False + assert options.cve_sources == ["nvd", "circl"] + assert options.headers == {"X-Test": "1"} + + def test_post_analyze_missing_url_returns_422(self, client: TestClient): + response = client.post("/api/analyze", json={"ssl": True}) + assert response.status_code == 422 + + def test_post_analyze_service_failure_returns_500(self, client: TestClient): + with patch( + "ciberwebscan.api.routes.analyze.AnalyzeService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.analyze.return_value = _make_service_result( + success=False, + error="Analysis failed in service", + ) + + response = client.post("/api/analyze", json={"url": "https://example.com"}) + + assert response.status_code == 500 + assert "Analysis failed in service" in response.json()["detail"] + + def test_post_analyze_unexpected_exception_returns_500(self, client: TestClient): + with patch( + "ciberwebscan.api.routes.analyze.AnalyzeService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.analyze.side_effect = RuntimeError("boom") + + response = client.post("/api/analyze", json={"url": "https://example.com"}) + + assert response.status_code == 500 + assert "boom" in response.json()["detail"] diff --git a/tests/unit/api/routes/test_attack_routes.py b/tests/unit/api/routes/test_attack_routes.py new file mode 100644 index 0000000..0c27f4c --- /dev/null +++ b/tests/unit/api/routes/test_attack_routes.py 
@@ -0,0 +1,297 @@ +""" +Unit tests for the attack simulation API endpoint (POST /api/attack). +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app +from ciberwebscan.services.base import ValidationError as ServiceValidationError + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def client(): + """Test client with authentication bypassed.""" + from ciberwebscan.api.auth import AuthenticatedUser, get_current_user + + app = create_app() + + def mock_get_current_user() -> AuthenticatedUser: + return AuthenticatedUser( + identifier="test-user", + auth_method="api_key", + scopes=["read", "write"], + ) + + app.dependency_overrides[get_current_user] = mock_get_current_user + yield TestClient(app) + app.dependency_overrides.clear() + + +@pytest.fixture +def mock_attack_result(): + """A minimal AttackResult mock that can be serialised by FastAPI.""" + from ciberwebscan.export.models import AttackResult + + return AttackResult( + target_url="https://example.com", + vulnerabilities=[], + total_payloads_tested=10, + total_findings=0, + xss_findings=0, + sqli_findings=0, + traversal_findings=0, + enumeration_findings=0, + duration_seconds=1.5, + ) + + +def _make_service_result(data=None, success=True, error=None): + """Build a mock ServiceResult.""" + result = MagicMock() + result.success = success + result.data = data + result.error = error + return result + + +BASE_PAYLOAD = { + "url": "https://example.com", + "xss": True, + "user_consent": True, +} + + +# ============================================================================= +# Happy-path tests +# ============================================================================= + + +class TestAttackEndpointSuccess: + """Tests for 
successful attack endpoint calls.""" + + def test_post_attack_xss_returns_200(self, client, mock_attack_result): + """POST /api/attack with xss=True returns 200 and attack data.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.return_value = _make_service_result( + data=mock_attack_result + ) + + response = client.post("/api/attack", json=BASE_PAYLOAD) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["target_url"] == "https://example.com" + assert data["data"]["total_findings"] == 0 + + def test_post_attack_all_attacks_shortcut(self, client, mock_attack_result): + """all_attacks=True enables all four attack types.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.return_value = _make_service_result( + data=mock_attack_result + ) + + response = client.post( + "/api/attack", + json={ + "url": "https://example.com", + "all_attacks": True, + "user_consent": True, + }, + ) + + assert response.status_code == 200 + # Verify ALL four attack types were set to True in AttackOptions + call_kwargs = mock_service.attack.call_args[0][0] + assert call_kwargs.xss is True + assert call_kwargs.sqli is True + assert call_kwargs.traversal is True + assert call_kwargs.enumeration is True + + def test_post_attack_intensity_high(self, client, mock_attack_result): + """intensity=high is forwarded to AttackOptions correctly.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.return_value = _make_service_result( + data=mock_attack_result + ) + + response = client.post( + "/api/attack", + 
json={**BASE_PAYLOAD, "intensity": "high"}, + ) + + assert response.status_code == 200 + options = mock_service.attack.call_args[0][0] + assert options.intensity == "high" + + def test_post_attack_sqli_and_traversal(self, client, mock_attack_result): + """Multiple attack types can be enabled at the same time.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.return_value = _make_service_result( + data=mock_attack_result + ) + + response = client.post( + "/api/attack", + json={ + "url": "https://example.com", + "sqli": True, + "traversal": True, + "user_consent": True, + }, + ) + + assert response.status_code == 200 + options = mock_service.attack.call_args[0][0] + assert options.sqli is True + assert options.traversal is True + + +# ============================================================================= +# Validation / consent tests +# ============================================================================= + + +class TestAttackEndpointValidation: + """Tests for request validation on the attack endpoint.""" + + def test_missing_user_consent_returns_422(self, client): + """user_consent=false is rejected by Pydantic at the request level.""" + response = client.post( + "/api/attack", + json={"url": "https://example.com", "xss": True, "user_consent": False}, + ) + # Pydantic field_validator raises ValueError → FastAPI returns 422 + assert response.status_code == 422 + + def test_missing_url_returns_422(self, client): + """Missing url field returns 422 Unprocessable Entity.""" + response = client.post( + "/api/attack", + json={"xss": True, "user_consent": True}, + ) + assert response.status_code == 422 + + def test_invalid_intensity_returns_422(self, client): + """Invalid intensity value (not low/medium/high) returns 422.""" + response = client.post( + "/api/attack", + json={**BASE_PAYLOAD, "intensity": "extreme"}, + ) 
+ assert response.status_code == 422 + + def test_service_validation_error_returns_400(self, client): + """ServiceValidationError (e.g. attack disabled in config) returns 400.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.side_effect = ServiceValidationError( + "Attack simulation is disabled in configuration." + ) + + response = client.post("/api/attack", json=BASE_PAYLOAD) + + assert response.status_code == 400 + assert "disabled" in response.json()["detail"].lower() + + def test_service_validation_error_whitelist_returns_400(self, client): + """ServiceValidationError for whitelist violation returns 400.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.side_effect = ServiceValidationError( + "Target host 'example.com' is not in the attack whitelist." 
+ ) + + response = client.post("/api/attack", json=BASE_PAYLOAD) + + assert response.status_code == 400 + assert "whitelist" in response.json()["detail"].lower() + + +# ============================================================================= +# Error handling tests +# ============================================================================= + + +class TestAttackEndpointErrors: + """Tests for error handling in the attack endpoint.""" + + def test_service_result_failure_returns_500(self, client): + """When service returns success=False, the endpoint raises HTTP 500.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.return_value = _make_service_result( + success=False, error="Attack execution failed" + ) + + response = client.post("/api/attack", json=BASE_PAYLOAD) + + assert response.status_code == 500 + # detail = result.error when set, otherwise "Attack simulation failed" + assert "Attack execution failed" in response.json()["detail"] + + def test_unexpected_exception_returns_500(self, client): + """Unhandled exceptions from the service layer return HTTP 500.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.side_effect = RuntimeError("Something went wrong") + + response = client.post("/api/attack", json=BASE_PAYLOAD) + + assert response.status_code == 500 + assert "Something went wrong" in response.json()["detail"] + + def test_response_shape(self, client, mock_attack_result): + """Response always contains success, data, and timestamp fields.""" + with patch( + "ciberwebscan.api.routes.attack.AttackService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + mock_service.attack.return_value = _make_service_result( + 
data=mock_attack_result + ) + + response = client.post("/api/attack", json=BASE_PAYLOAD) + + assert response.status_code == 200 + body = response.json() + assert "success" in body + assert "data" in body + assert "timestamp" in body diff --git a/tests/unit/api/routes/test_auth_routes.py b/tests/unit/api/routes/test_auth_routes.py new file mode 100644 index 0000000..e51cd42 --- /dev/null +++ b/tests/unit/api/routes/test_auth_routes.py @@ -0,0 +1,77 @@ +""" +Unit tests for auth route endpoints. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app + + +@pytest.fixture +def client(): + """Create a test client with auth dependency overridden.""" + from ciberwebscan.api.auth import AuthenticatedUser, get_current_user + + app = create_app() + + def mock_get_current_user() -> AuthenticatedUser: + return AuthenticatedUser( + identifier="test-user", + auth_method="api_key", + scopes=["read", "write"], + ) + + app.dependency_overrides[get_current_user] = mock_get_current_user + yield TestClient(app) + app.dependency_overrides.clear() + + +class TestAuthRouteEndpoints: + """Tests for /api/auth routes.""" + + def test_get_me_returns_current_user(self, client: TestClient): + response = client.get("/api/auth/me") + + assert response.status_code == 200 + body = response.json() + assert body["identifier"] == "test-user" + assert body["auth_method"] == "api_key" + assert body["authenticated"] is True + + def test_generate_key_without_admin_scope_returns_403(self, client: TestClient): + response = client.post("/api/auth/generate-key") + + assert response.status_code == 403 + assert "Admin access required" in response.json()["detail"] + + def test_generate_key_with_admin_scope_returns_200(self): + from ciberwebscan.api.auth import AuthenticatedUser, get_current_user + + app = create_app() + + def mock_admin_user() -> AuthenticatedUser: + return 
AuthenticatedUser( + identifier="admin-user", + auth_method="api_key", + scopes=["admin"], + ) + + app.dependency_overrides[get_current_user] = mock_admin_user + + with patch( + "ciberwebscan.api.routes.auth.generate_api_key", + return_value="fixed-generated-key", + ): + response = TestClient(app).post("/api/auth/generate-key") + + app.dependency_overrides.clear() + + assert response.status_code == 200 + body = response.json() + assert body["api_key"] == "fixed-generated-key" + assert "Store this key securely" in body["message"] diff --git a/tests/unit/api/routes/test_config_routes.py b/tests/unit/api/routes/test_config_routes.py new file mode 100644 index 0000000..a7c4cf4 --- /dev/null +++ b/tests/unit/api/routes/test_config_routes.py @@ -0,0 +1,488 @@ +""" +Tests for configuration management endpoints. +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app + + +@pytest.fixture +def client(): + """Create a test client with a mocked user.""" + from ciberwebscan.api.auth import AuthenticatedUser, get_current_user + + app = create_app() + + # Mock the authentication dependency + def mock_get_current_user() -> AuthenticatedUser: + return AuthenticatedUser( + identifier="test-user", + auth_method="api_key", + scopes=["read", "write"], + ) + + app.dependency_overrides[get_current_user] = mock_get_current_user + client = TestClient(app) + + yield client + + # Clean up dependency overrides + app.dependency_overrides.clear() + + +@pytest.fixture +def temp_config_dir(): + """Create a temporary directory for config files.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +class TestConfigEndpoints: + """Test configuration management endpoints.""" + + def test_get_all_config(self, client): + """Test GET /api/config returns all configuration.""" + with patch( + 
"ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the get_all result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = { + "scraping": {"timeout": 30}, + "analysis": {"enabled": True}, + } + mock_service.get_all.return_value = mock_result + + response = client.get("/api/config") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "data" in data + assert "scraping" in data["data"] + + def test_get_config_section(self, client): + """Test GET /api/config/sections/{section} returns section config.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the get_section result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = {"timeout": 30, "retries": 3} + mock_service.get_section.return_value = mock_result + + response = client.get("/api/config/sections/scraping") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["timeout"] == 30 + + def test_get_config_section_not_found(self, client): + """Test GET /api/config/sections/{section} with invalid section.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the get_section failure + mock_result = MagicMock() + mock_result.success = False + mock_result.error = "Section not found: invalid" + mock_service.get_section.return_value = mock_result + + response = client.get("/api/config/sections/invalid") + + assert response.status_code == 404 + + def test_get_config_value(self, client): + """Test GET /api/config/value returns specific value with metadata.""" + with patch( + 
"ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock ConfigValue + from ciberwebscan.services.config_service import ConfigValue + + config_value = ConfigValue( + key="scraping.timeout", + value=30, + default=30, + source="default", + description="Timeout for scraping operations", + ) + + mock_result = MagicMock() + mock_result.success = True + mock_result.data = config_value + mock_service.get.return_value = mock_result + + response = client.get("/api/config/value?path=scraping.timeout") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["value"] == 30 + assert data["data"]["source"] == "default" + + def test_get_config_value_not_found(self, client): + """Test GET /api/config/value with invalid key.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the get failure + mock_result = MagicMock() + mock_result.success = False + mock_result.error = "Configuration key not found: invalid.key" + mock_service.get.return_value = mock_result + + response = client.get("/api/config/value?path=invalid.key") + + assert response.status_code == 404 + + def test_update_config(self, client): + """Test PUT /api/config updates a configuration value.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock ConfigValue + from ciberwebscan.services.config_service import ConfigValue + + config_value = ConfigValue( + key="scraping.timeout", + value=60, + default=30, + source="runtime", + description="Timeout for scraping operations", + ) + + mock_result = MagicMock() + mock_result.success = True + mock_result.data = config_value + mock_service.set.return_value = 
mock_result + + response = client.put( + "/api/config", + json={"path": "scraping.timeout", "value": 60, "save": False}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["value"] == 60 + assert data["data"]["source"] == "runtime" + + def test_update_config_with_save(self, client): + """Test PUT /api/config with save=True.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock ConfigValue + from ciberwebscan.services.config_service import ConfigValue + + config_value = ConfigValue( + key="scraping.timeout", + value=60, + default=30, + source="runtime", + description="Timeout for scraping operations", + ) + + mock_result = MagicMock() + mock_result.success = True + mock_result.data = config_value + mock_service.set.return_value = mock_result + + save_result = MagicMock() + save_result.success = True + save_result.data = Path("/tmp/config.yaml") + mock_service.save.return_value = save_result + + response = client.put( + "/api/config", + json={"path": "scraping.timeout", "value": 60, "save": True}, + ) + + assert response.status_code == 200 + mock_service.save.assert_called_once() + + def test_reset_config_all(self, client): + """Test POST /api/config/reset resets all configuration.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the reset result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = True + mock_service.reset.return_value = mock_result + + response = client.post( + "/api/config/reset", + json={"path": None, "save": False}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["reset"] is True + assert data["data"]["path"] == "all" + 
+ def test_reset_config_key(self, client): + """Test POST /api/config/reset resets specific key.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the reset result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = True + mock_service.reset.return_value = mock_result + + response = client.post( + "/api/config/reset", + json={"path": "scraping.timeout", "save": False}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["path"] == "scraping.timeout" + + def test_list_config_keys(self, client): + """Test GET /api/config/keys lists all configuration keys.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the list_keys result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = [ + "scraping.timeout", + "scraping.retries", + "analysis.enabled", + ] + mock_service.list_keys.return_value = mock_result + + response = client.get("/api/config/keys") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["count"] == 3 + assert "scraping.timeout" in data["data"]["keys"] + + def test_list_config_keys_with_section(self, client): + """Test GET /api/config/keys?section=... 
filters by section.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the list_keys result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = ["scraping.timeout", "scraping.retries"] + mock_service.list_keys.return_value = mock_result + + response = client.get("/api/config/keys?section=scraping") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_service.list_keys.assert_called_once_with("scraping") + + def test_export_config_yaml(self, client, temp_config_dir): + """Test POST /api/config/export exports to YAML.""" + output_path = temp_config_dir / "config.yaml" + + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the export result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = output_path + mock_service.export_config.return_value = mock_result + + response = client.post( + "/api/config/export", + json={"path": str(output_path), "format": "yaml"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["operation"] == "export" + assert data["data"]["format"] == "yaml" + + def test_export_config_json(self, client, temp_config_dir): + """Test POST /api/config/export exports to JSON.""" + output_path = temp_config_dir / "config.json" + + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the export result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = output_path + mock_service.export_config.return_value = mock_result + + response = client.post( + "/api/config/export", + json={"path": 
str(output_path), "format": "json"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["data"]["format"] == "json" + + def test_export_config_invalid_format(self, client, temp_config_dir): + """Test POST /api/config/export with invalid format.""" + output_path = temp_config_dir / "config.txt" + + response = client.post( + "/api/config/export", + json={"path": str(output_path), "format": "txt"}, + ) + + assert response.status_code == 400 + + def test_load_config(self, client, temp_config_dir): + """Test POST /api/config/load loads configuration from file.""" + config_path = temp_config_dir / "config.yaml" + + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the load result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = {"scraping": {"timeout": 30}} + mock_service.load.return_value = mock_result + + response = client.post( + "/api/config/load", + json={"path": str(config_path)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "scraping" in data["data"] + + def test_load_config_file_not_found(self, client, temp_config_dir): + """Test POST /api/config/load with non-existent file.""" + config_path = temp_config_dir / "nonexistent.yaml" + + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the load failure + mock_result = MagicMock() + mock_result.success = False + mock_result.error = f"Config file not found: {config_path}" + mock_service.load.return_value = mock_result + + response = client.post( + "/api/config/load", + json={"path": str(config_path)}, + ) + + assert response.status_code == 404 + + def test_save_config(self, client, temp_config_dir): + """Test POST /api/config/save saves 
configuration to file.""" + save_path = temp_config_dir / "config.yaml" + + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the save result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = save_path + mock_service.save.return_value = mock_result + + response = client.post( + "/api/config/save", + json={"path": str(save_path)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["data"]["operation"] == "save" + + def test_save_config_default_path(self, client): + """Test POST /api/config/save uses default path.""" + with patch( + "ciberwebscan.api.routes.config.ConfigService" + ) as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + # Mock the save result + mock_result = MagicMock() + mock_result.success = True + mock_result.data = Path.home() / ".ciberwebscan" / "config.yaml" + mock_service.save.return_value = mock_result + + response = client.post("/api/config/save") + + assert response.status_code == 200 + mock_service.save.assert_called_once_with(None) diff --git a/tests/unit/api/routes/test_health.py b/tests/unit/api/routes/test_health.py new file mode 100644 index 0000000..9a0bebb --- /dev/null +++ b/tests/unit/api/routes/test_health.py @@ -0,0 +1,261 @@ +""" +Tests for health check endpoints. + +Tests for /health and /health/ready endpoints. 
+""" + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from ciberwebscan.api.models.responses import HealthCheckResponse +from ciberwebscan.api.routes.health import router + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def app() -> FastAPI: + """Create a FastAPI app with health routes.""" + app = FastAPI() + app.include_router(router) + return app + + +@pytest.fixture +def client(app: FastAPI) -> TestClient: + """Create test client.""" + return TestClient(app) + + +# ============================================================================= +# HealthCheckResponse Model Tests +# ============================================================================= + + +class TestHealthCheckResponseModel: + """Tests for HealthCheckResponse Pydantic model.""" + + def test_health_response_creation(self): + """Test HealthCheckResponse model instantiation.""" + now = datetime.now(timezone.utc) + response = HealthCheckResponse( + status="healthy", + timestamp=now, + version="2.0.0", + message="Test message", + ) + + assert response.status == "healthy" + assert response.timestamp == now + assert response.version == "2.0.0" + assert response.message == "Test message" + + def test_health_response_serialization(self): + """Test HealthCheckResponse JSON serialization.""" + now = datetime.now(timezone.utc) + response = HealthCheckResponse( + status="ready", + timestamp=now, + version="1.0.0", + message="Ready", + ) + + data = response.model_dump() + assert data["status"] == "ready" + assert data["version"] == "1.0.0" + assert data["message"] == "Ready" + assert "timestamp" in data + + +# ============================================================================= +# Health Check 
Endpoint Tests +# ============================================================================= + + +class TestHealthCheckEndpoint: + """Tests for /health endpoint.""" + + def test_health_check_returns_200(self, client: TestClient): + """Test health check returns 200 status code.""" + response = client.get("/health") + assert response.status_code == 200 + + def test_health_check_returns_healthy_status(self, client: TestClient): + """Test health check returns healthy status.""" + response = client.get("/health") + data = response.json() + + assert data["status"] == "healthy" + + def test_health_check_returns_version(self, client: TestClient): + """Test health check includes version.""" + response = client.get("/health") + data = response.json() + + assert "version" in data + assert isinstance(data["version"], str) + assert len(data["version"]) > 0 + + def test_health_check_returns_timestamp(self, client: TestClient): + """Test health check includes timestamp.""" + response = client.get("/health") + data = response.json() + + assert "timestamp" in data + # Verify timestamp is parseable + timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) + assert timestamp is not None + + def test_health_check_returns_message(self, client: TestClient): + """Test health check includes message.""" + response = client.get("/health") + data = response.json() + + assert data["message"] == "CiberWebScan API is running" + + def test_health_check_response_structure(self, client: TestClient): + """Test health check response has all required fields.""" + response = client.get("/health") + data = response.json() + + expected_keys = {"status", "timestamp", "version", "message", "uptime_seconds"} + assert set(data.keys()) == expected_keys + + @patch("ciberwebscan.api.routes.health.__version__", "3.0.0-test") + def test_health_check_uses_package_version(self, client: TestClient): + """Test health check uses actual package version.""" + response = client.get("/health") + 
data = response.json() + + assert data["version"] == "3.0.0-test" + + +# ============================================================================= +# Readiness Check Endpoint Tests +# ============================================================================= + + +class TestReadinessCheckEndpoint: + """Tests for /health/ready endpoint.""" + + def test_readiness_check_returns_200(self, client: TestClient): + """Test readiness check returns 200 status code.""" + response = client.get("/health/ready") + assert response.status_code == 200 + + def test_readiness_check_returns_ready_status(self, client: TestClient): + """Test readiness check returns ready status.""" + response = client.get("/health/ready") + data = response.json() + + assert data["status"] == "ready" + + def test_readiness_check_returns_version(self, client: TestClient): + """Test readiness check includes version.""" + response = client.get("/health/ready") + data = response.json() + + assert "version" in data + assert isinstance(data["version"], str) + + def test_readiness_check_returns_timestamp(self, client: TestClient): + """Test readiness check includes timestamp.""" + response = client.get("/health/ready") + data = response.json() + + assert "timestamp" in data + timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) + assert timestamp is not None + + def test_readiness_check_returns_message(self, client: TestClient): + """Test readiness check includes message.""" + response = client.get("/health/ready") + data = response.json() + + assert data["message"] == "CiberWebScan API is ready to accept requests" + + def test_readiness_check_response_structure(self, client: TestClient): + """Test readiness check response has all required fields.""" + response = client.get("/health/ready") + data = response.json() + + expected_keys = {"status", "timestamp", "version", "message", "uptime_seconds"} + assert set(data.keys()) == expected_keys + + 
@patch("ciberwebscan.api.routes.health.__version__", "2.5.0-beta") + def test_readiness_check_uses_package_version(self, client: TestClient): + """Test readiness check uses actual package version.""" + response = client.get("/health/ready") + data = response.json() + + assert data["version"] == "2.5.0-beta" + + +# ============================================================================= +# Edge Cases & Integration +# ============================================================================= + + +class TestHealthEndpointsEdgeCases: + """Edge case tests for health endpoints.""" + + def test_health_endpoint_is_get_only(self, client: TestClient): + """Test health endpoint only accepts GET requests.""" + assert client.post("/health").status_code == 405 + assert client.put("/health").status_code == 405 + assert client.delete("/health").status_code == 405 + assert client.patch("/health").status_code == 405 + + def test_readiness_endpoint_is_get_only(self, client: TestClient): + """Test readiness endpoint only accepts GET requests.""" + assert client.post("/health/ready").status_code == 405 + assert client.put("/health/ready").status_code == 405 + assert client.delete("/health/ready").status_code == 405 + assert client.patch("/health/ready").status_code == 405 + + def test_health_and_ready_have_different_status(self, client: TestClient): + """Test health and ready endpoints return different status values.""" + health_response = client.get("/health").json() + ready_response = client.get("/health/ready").json() + + assert health_response["status"] == "healthy" + assert ready_response["status"] == "ready" + + def test_health_and_ready_have_different_messages(self, client: TestClient): + """Test health and ready endpoints return different messages.""" + health_response = client.get("/health").json() + ready_response = client.get("/health/ready").json() + + assert health_response["message"] != ready_response["message"] + + def test_timestamps_are_recent(self, client: 
TestClient): + """Test that timestamps are recent (within last minute).""" + before = datetime.now(timezone.utc) + response = client.get("/health") + after = datetime.now(timezone.utc) + + data = response.json() + timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) + + assert before <= timestamp <= after + + def test_multiple_requests_return_different_timestamps(self, client: TestClient): + """Test that each request gets a fresh timestamp.""" + response1 = client.get("/health") + response2 = client.get("/health") + + # Timestamps should be >= (could be same if fast enough) + ts1 = response1.json()["timestamp"] + ts2 = response2.json()["timestamp"] + + # Both should be valid timestamps + datetime.fromisoformat(ts1.replace("Z", "+00:00")) + datetime.fromisoformat(ts2.replace("Z", "+00:00")) diff --git a/tests/unit/api/routes/test_scrape_routes.py b/tests/unit/api/routes/test_scrape_routes.py new file mode 100644 index 0000000..8d6721d --- /dev/null +++ b/tests/unit/api/routes/test_scrape_routes.py @@ -0,0 +1,160 @@ +""" +Unit tests for scraping API endpoints. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from ciberwebscan.api.app import create_app +from ciberwebscan.export.models import ScrapeResult + + +@pytest.fixture +def client(): + """Create a test client with auth dependency overridden.""" + from ciberwebscan.api.auth import AuthenticatedUser, get_current_user + + app = create_app() + + def mock_get_current_user() -> AuthenticatedUser: + return AuthenticatedUser( + identifier="test-user", + auth_method="api_key", + scopes=["read", "write"], + ) + + app.dependency_overrides[get_current_user] = mock_get_current_user + yield TestClient(app) + app.dependency_overrides.clear() + + +@pytest.fixture +def mock_scrape_result() -> ScrapeResult: + """Minimal valid ScrapeResult.""" + return ScrapeResult( + url="https://example.com/", + status_code=200, + content_type="text/html", + title="Example", + text_content="ok", + ) + + +def _make_service_result(data=None, success: bool = True, error: str | None = None): + result = MagicMock() + result.success = success + result.data = data + result.error = error + result.duration_seconds = 1.2 + return result + + +class TestScrapeEndpoint: + """Tests for POST /api/scrape.""" + + def test_post_scrape_returns_200(self, client: TestClient, mock_scrape_result): + with patch("ciberwebscan.api.routes.scrape.ScrapeService") as mock_service_cls: + mock_service = MagicMock() + mock_service_cls.return_value = mock_service + mock_service.scrape.return_value = _make_service_result( + data=mock_scrape_result + ) + + response = client.post("/api/scrape", json={"url": "https://example.com"}) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["url"] == "https://example.com/" + + def test_post_scrape_service_failure_returns_500(self, client: TestClient): + with patch("ciberwebscan.api.routes.scrape.ScrapeService") as 
mock_service_cls: + mock_service = MagicMock() + mock_service_cls.return_value = mock_service + mock_service.scrape.return_value = _make_service_result( + success=False, + error="scrape failed", + ) + + response = client.post("/api/scrape", json={"url": "https://example.com"}) + + assert response.status_code == 500 + assert "scrape failed" in response.json()["detail"] + + def test_post_scrape_none_data_returns_500(self, client: TestClient): + with patch("ciberwebscan.api.routes.scrape.ScrapeService") as mock_service_cls: + mock_service = MagicMock() + mock_service_cls.return_value = mock_service + mock_service.scrape.return_value = _make_service_result(data=None) + + response = client.post("/api/scrape", json={"url": "https://example.com"}) + + assert response.status_code == 500 + assert "returned no data" in response.json()["detail"] + + +class TestScrapeBatchEndpoint: + """Tests for POST /api/scrape/batch.""" + + def test_post_scrape_batch_returns_200( + self, + client: TestClient, + mock_scrape_result: ScrapeResult, + ): + second_result = ScrapeResult( + url="https://example.org/", + status_code=200, + content_type="text/html", + title="Example Org", + text_content="ok", + ) + + with patch("ciberwebscan.api.routes.scrape.ScrapeService") as mock_service_cls: + mock_service = MagicMock() + mock_service_cls.return_value = mock_service + mock_service.scrape_multiple.return_value = _make_service_result( + data=[mock_scrape_result, second_result] + ) + + response = client.post( + "/api/scrape/batch", + json={ + "urls": [ + "https://example.com", + "https://example.org", + "https://example.net", + ] + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["total_success"] == 2 + assert body["data"]["total_failed"] == 1 + assert body["data"]["failed_urls"][0]["url"] == "https://example.net/" + + def test_post_scrape_batch_service_failure_returns_500(self, client: TestClient): + with 
patch("ciberwebscan.api.routes.scrape.ScrapeService") as mock_service_cls: + mock_service = MagicMock() + mock_service_cls.return_value = mock_service + mock_service.scrape_multiple.return_value = _make_service_result( + success=False, + error="batch failed", + ) + + response = client.post( + "/api/scrape/batch", + json={"urls": ["https://example.com"]}, + ) + + assert response.status_code == 500 + assert "batch failed" in response.json()["detail"] + + def test_post_scrape_batch_empty_urls_returns_422(self, client: TestClient): + response = client.post("/api/scrape/batch", json={"urls": []}) + assert response.status_code == 422 diff --git a/tests/unit/api/test_auth.py b/tests/unit/api/test_auth.py new file mode 100644 index 0000000..9ab7551 --- /dev/null +++ b/tests/unit/api/test_auth.py @@ -0,0 +1,353 @@ +""" +Tests for API authentication module. + +Tests for API Key authentication with security best practices. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from ciberwebscan.api.auth import ( + _secure_compare_key, + generate_api_key, + get_auth_config, + verify_api_key, +) +from ciberwebscan.api.routes.auth import router as auth_router + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def test_api_key() -> str: + """Test API key.""" + return "test-api-key-12345" + + +def _create_mock_config(api_keys: list[str] | None = None) -> MagicMock: + """Create a mock config object with api.auth settings.""" + mock_config = MagicMock() + mock_config.api.auth.api_keys = api_keys or [] + return mock_config + + +def _create_mock_request(client_ip: str = "127.0.0.1") -> Mock: + """Create a mock request object.""" + mock_request = Mock(spec=Request) + mock_request.headers = {} + 
mock_request.client = Mock() + mock_request.client.host = client_ip + mock_request.method = "GET" + mock_request.url = Mock() + mock_request.url.path = "/test" + return mock_request + + +@pytest.fixture +def auth_config_patch(test_api_key: str): + """Patch get_config to return test auth configuration.""" + mock_config = _create_mock_config(api_keys=[test_api_key]) + with patch("ciberwebscan.api.auth.get_config", return_value=mock_config): + yield + + +@pytest.fixture +def app() -> FastAPI: + """Create a FastAPI app with auth routes.""" + app = FastAPI() + app.include_router(auth_router, prefix="/auth") + return app + + +@pytest.fixture +def client(app: FastAPI) -> TestClient: + """Create test client.""" + return TestClient(app) + + +# ============================================================================= +# AuthConfig Tests +# ============================================================================= + + +class TestAuthConfig: + """Tests for AuthConfig loading.""" + + def test_get_auth_config_with_config(self, auth_config_patch, test_api_key): + """Test config loading from global config.""" + config = get_auth_config() + + assert config.api_key_enabled is True + assert test_api_key in config.api_keys + + def test_get_auth_config_no_keys(self): + """Test config with no keys configured.""" + mock_config = _create_mock_config(api_keys=[]) + with patch("ciberwebscan.api.auth.get_config", return_value=mock_config): + config = get_auth_config() + assert config.api_key_enabled is False + assert config.api_keys == [] + + +# ============================================================================= +# Secure Key Comparison Tests +# ============================================================================= + + +class TestSecureKeyComparison: + """Tests for constant-time key comparison.""" + + def test_secure_compare_valid_key(self): + """Test constant-time comparison finds valid key.""" + stored_keys = ["key1-abcdef", "key2-ghijkl", "key3-mnopqr"] + result = 
_secure_compare_key("key2-ghijkl", stored_keys) + + assert result == "key2-ghi" # Returns first 8 chars + + def test_secure_compare_invalid_key(self): + """Test constant-time comparison rejects invalid key.""" + stored_keys = ["key1-abcdef", "key2-ghijkl"] + result = _secure_compare_key("invalid-key", stored_keys) + + assert result is None + + def test_secure_compare_empty_list(self): + """Test comparison with empty key list.""" + result = _secure_compare_key("any-key", []) + + assert result is None + + def test_secure_compare_similar_keys(self): + """Test comparison correctly distinguishes similar keys.""" + stored_keys = ["test-key-1"] + + # Should not match similar but different key + assert _secure_compare_key("test-key-2", stored_keys) is None + # Should match exact key + assert _secure_compare_key("test-key-1", stored_keys) == "test-key" + + +# ============================================================================= +# API Key Tests +# ============================================================================= + + +class TestApiKeyAuthentication: + """Tests for API key authentication.""" + + @pytest.mark.asyncio + async def test_verify_valid_api_key(self, auth_config_patch, test_api_key): + """Test valid API key verification.""" + mock_request = _create_mock_request() + user = await verify_api_key(mock_request, test_api_key) + + assert user is not None + assert user.auth_method == "api_key" + assert "full_access" in user.scopes + + @pytest.mark.asyncio + async def test_verify_invalid_api_key(self, auth_config_patch): + """Test invalid API key returns None.""" + mock_request = _create_mock_request() + user = await verify_api_key(mock_request, "invalid-key") + + assert user is None + + @pytest.mark.asyncio + async def test_verify_no_api_key(self, auth_config_patch): + """Test no API key returns None.""" + mock_request = _create_mock_request() + user = await verify_api_key(mock_request, None) + + assert user is None + + @pytest.mark.asyncio + async def 
test_verify_logs_failed_attempt(self, auth_config_patch): + """Test that failed authentication attempts are logged.""" + mock_request = _create_mock_request(client_ip="192.168.1.100") + + with patch("ciberwebscan.api.auth.logger") as mock_logger: + await verify_api_key(mock_request, "bad-key-attempt") + + # Should log warning for failed attempt + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args + assert "Invalid API key attempt" in call_args[0][0] + + @pytest.mark.asyncio + async def test_verify_logs_success(self, auth_config_patch, test_api_key): + """Test that successful authentication is logged.""" + mock_request = _create_mock_request() + + with patch("ciberwebscan.api.auth.logger") as mock_logger: + await verify_api_key(mock_request, test_api_key) + + # Should log info for success + mock_logger.info.assert_called() + + +# ============================================================================= +# Auth Endpoint Tests +# ============================================================================= + + +class TestAuthEndpoints: + """Tests for authentication endpoints.""" + + def test_me_endpoint_with_api_key( + self, client: TestClient, auth_config_patch, test_api_key + ): + """Test /auth/me with API key authentication.""" + response = client.get( + "/auth/me", + headers={"X-API-Key": test_api_key}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["auth_method"] == "api_key" + assert data["authenticated"] is True + + def test_me_endpoint_without_auth(self, client: TestClient, auth_config_patch): + """Test /auth/me without authentication fails.""" + response = client.get("/auth/me") + + assert response.status_code == 401 + + def test_status_endpoint_removed(self, client: TestClient, auth_config_patch): + """Test /auth/status endpoint no longer exists.""" + response = client.get("/auth/status") + + assert response.status_code == 404 + + +# 
============================================================================= +# Protected Route Tests +# ============================================================================= + + +class TestProtectedRoutes: + """Tests for route protection.""" + + @pytest.fixture + def protected_app(self) -> FastAPI: + """Create app with protected routes.""" + from ciberwebscan.api.routes import analyze, scrape + + app = FastAPI() + app.include_router(auth_router, prefix="/auth") + app.include_router(scrape.router, prefix="/api") + app.include_router(analyze.router, prefix="/api") + return app + + @pytest.fixture + def protected_client(self, protected_app: FastAPI) -> TestClient: + """Create test client for protected routes.""" + return TestClient(protected_app) + + def test_scrape_requires_auth( + self, protected_client: TestClient, auth_config_patch + ): + """Test /api/scrape requires authentication.""" + response = protected_client.post( + "/api/scrape", + json={"url": "https://example.com"}, + ) + + assert response.status_code == 401 + + def test_scrape_with_api_key( + self, protected_client: TestClient, auth_config_patch, test_api_key + ): + """Test /api/scrape works with API key.""" + response = protected_client.post( + "/api/scrape", + json={"url": "https://example.com"}, + headers={"X-API-Key": test_api_key}, + ) + + # May return 500 if service fails, but auth should pass + assert response.status_code != 401 + + def test_analyze_requires_auth( + self, protected_client: TestClient, auth_config_patch + ): + """Test /api/analyze requires authentication.""" + response = protected_client.post( + "/api/analyze", + json={"url": "https://example.com"}, + ) + + assert response.status_code == 401 + + +# ============================================================================= +# Utility Tests +# ============================================================================= + + +class TestUtilities: + """Tests for utility functions.""" + + def test_generate_api_key(self): + 
"""Test API key generation.""" + key1 = generate_api_key() + key2 = generate_api_key() + + assert isinstance(key1, str) + assert len(key1) >= 32 + assert key1 != key2 # Should be unique + + def test_generate_api_key_endpoint( + self, client: TestClient, auth_config_patch, test_api_key + ): + """Test API key generation endpoint.""" + response = client.post( + "/auth/generate-key", + headers={"X-API-Key": test_api_key}, + ) + + assert response.status_code == 200 + data = response.json() + assert "api_key" in data + assert len(data["api_key"]) >= 32 + + +# ============================================================================= +# Security Tests +# ============================================================================= + + +class TestSecurityMeasures: + """Tests for security measures.""" + + def test_auth_required_returns_401_not_403( + self, client: TestClient, auth_config_patch + ): + """Test unauthenticated requests get 401, not 403.""" + response = client.get("/auth/me") + + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + + def test_invalid_key_returns_401(self, client: TestClient, auth_config_patch): + """Test invalid API key returns 401.""" + response = client.get( + "/auth/me", + headers={"X-API-Key": "definitely-not-a-valid-key"}, + ) + + assert response.status_code == 401 + + def test_generate_key_requires_auth(self, client: TestClient, auth_config_patch): + """Test generate-key endpoint requires authentication.""" + response = client.post("/auth/generate-key") + + assert response.status_code == 401 diff --git a/tests/unit/api/test_middleware.py b/tests/unit/api/test_middleware.py new file mode 100644 index 0000000..209ff0c --- /dev/null +++ b/tests/unit/api/test_middleware.py @@ -0,0 +1,452 @@ +""" +Tests for API middleware. + +Tests for RequestLoggingMiddleware and RateLimitingMiddleware. 
+""" + +from __future__ import annotations + +import logging + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from ciberwebscan.api.middleware import ( + RateLimitingMiddleware, + add_rate_limiting_middleware, + add_request_logging_middleware, +) + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def app() -> FastAPI: + """Create a basic FastAPI app for testing.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(): + return {"message": "success"} + + @app.get("/error") + def error_endpoint(): + raise ValueError("Test error") + + return app + + +@pytest.fixture +def app_with_logging(app: FastAPI) -> FastAPI: + """Create app with logging middleware.""" + add_request_logging_middleware(app) + return app + + +@pytest.fixture +def app_with_rate_limiting(app: FastAPI) -> FastAPI: + """Create app with rate limiting middleware (5 requests per minute).""" + add_rate_limiting_middleware(app, requests_per_minute=5) + return app + + +@pytest.fixture +def client_logging(app_with_logging: FastAPI) -> TestClient: + """Create test client with logging middleware.""" + return TestClient(app_with_logging, raise_server_exceptions=False) + + +@pytest.fixture +def client_rate_limiting(app_with_rate_limiting: FastAPI) -> TestClient: + """Create test client with rate limiting middleware.""" + return TestClient(app_with_rate_limiting, raise_server_exceptions=False) + + +# ============================================================================= +# RequestLoggingMiddleware Tests +# ============================================================================= + + +class TestRequestLoggingMiddleware: + """Tests for RequestLoggingMiddleware.""" + + def test_logs_successful_request( + self, client_logging: TestClient, caplog: pytest.LogCaptureFixture + ): + """Test that successful requests 
are logged with correct info.""" + with caplog.at_level(logging.INFO): + response = client_logging.get("/test") + + assert response.status_code == 200 + assert any("GET /test - 200" in record.message for record in caplog.records) + + def test_logs_request_with_timing( + self, client_logging: TestClient, caplog: pytest.LogCaptureFixture + ): + """Test that request timing is included in logs.""" + with caplog.at_level(logging.INFO): + client_logging.get("/test") + + # Check that timing format (X.XXXs) is in the log + log_messages = [r.message for r in caplog.records] + assert any("s)" in msg and "GET /test" in msg for msg in log_messages) + + def test_logs_error_request( + self, client_logging: TestClient, caplog: pytest.LogCaptureFixture + ): + """Test that error requests are logged with 500 status.""" + with caplog.at_level(logging.INFO): + response = client_logging.get("/error") + + assert response.status_code == 500 + assert any("GET /error - 500" in record.message for record in caplog.records) + + def test_logs_not_found_request( + self, client_logging: TestClient, caplog: pytest.LogCaptureFixture + ): + """Test that 404 requests are logged correctly.""" + with caplog.at_level(logging.INFO): + response = client_logging.get("/nonexistent") + + assert response.status_code == 404 + assert any( + "GET /nonexistent - 404" in record.message for record in caplog.records + ) + + def test_log_record_has_extra_fields( + self, client_logging: TestClient, caplog: pytest.LogCaptureFixture + ): + """Test that log records contain expected extra fields.""" + with caplog.at_level(logging.INFO): + client_logging.get("/test") + + # Find the relevant log record + log_record = next((r for r in caplog.records if "GET /test" in r.message), None) + assert log_record is not None + assert hasattr(log_record, "method") + assert log_record.method == "GET" + assert hasattr(log_record, "path") + assert log_record.path == "/test" + assert hasattr(log_record, "status_code") + assert 
log_record.status_code == 200 + assert hasattr(log_record, "duration") + assert isinstance(log_record.duration, float) + + +# ============================================================================= +# RateLimitingMiddleware Tests +# ============================================================================= + + +class TestRateLimitingMiddleware: + """Tests for RateLimitingMiddleware.""" + + def test_allows_requests_within_limit(self, client_rate_limiting: TestClient): + """Test that requests within the limit are allowed.""" + for _ in range(5): + response = client_rate_limiting.get("/test") + assert response.status_code == 200 + + def test_blocks_requests_exceeding_limit(self, client_rate_limiting: TestClient): + """Test that requests exceeding the limit are blocked with 429.""" + # Make 5 allowed requests + for _ in range(5): + response = client_rate_limiting.get("/test") + assert response.status_code == 200 + + # 6th request should be blocked + response = client_rate_limiting.get("/test") + assert response.status_code == 429 + + def test_429_response_contains_error_message( + self, client_rate_limiting: TestClient + ): + """Test that 429 response contains proper error details.""" + # Exhaust limit + for _ in range(5): + client_rate_limiting.get("/test") + + response = client_rate_limiting.get("/test") + assert response.status_code == 429 + + data = response.json() + assert "error" in data + assert data["error"] == "Rate limit exceeded" + assert "retry_after_seconds" in data + assert isinstance(data["retry_after_seconds"], int) + + def test_429_response_has_retry_after_header( + self, client_rate_limiting: TestClient + ): + """Test that 429 response includes Retry-After header.""" + # Exhaust limit + for _ in range(5): + client_rate_limiting.get("/test") + + response = client_rate_limiting.get("/test") + assert response.status_code == 429 + assert "Retry-After" in response.headers + retry_after = int(response.headers["Retry-After"]) + assert 0 < 
retry_after <= 60 + + def test_different_clients_have_separate_limits(self, app: FastAPI): + """Test that different clients have independent rate limits.""" + add_rate_limiting_middleware(app, requests_per_minute=2) + client = TestClient(app) + + # Client 1 makes 2 requests + for _ in range(2): + response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.1"}) + assert response.status_code == 200 + + # Client 1 should be blocked + response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.1"}) + assert response.status_code == 429 + + # Client 2 should still be allowed + for _ in range(2): + response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.2"}) + assert response.status_code == 200 + + def test_x_forwarded_for_header_used_for_client_identification(self, app: FastAPI): + """Test that X-Forwarded-For header is used for client identification.""" + add_rate_limiting_middleware(app, requests_per_minute=2) + client = TestClient(app) + + # Make requests from "client1" via X-Forwarded-For + for _ in range(2): + response = client.get("/test", headers={"X-Forwarded-For": "192.168.1.1"}) + assert response.status_code == 200 + + # 3rd request from same "client" should be blocked + response = client.get("/test", headers={"X-Forwarded-For": "192.168.1.1"}) + assert response.status_code == 429 + + # Request from different "client" should be allowed + response = client.get("/test", headers={"X-Forwarded-For": "192.168.1.2"}) + assert response.status_code == 200 + + def test_x_forwarded_for_uses_first_ip(self, app: FastAPI): + """Test that first IP in X-Forwarded-For chain is used.""" + add_rate_limiting_middleware(app, requests_per_minute=2) + client = TestClient(app) + + # Make requests with chained X-Forwarded-For + for _ in range(2): + response = client.get( + "/test", + headers={"X-Forwarded-For": "10.0.0.1, 192.168.1.1, 172.16.0.1"}, + ) + assert response.status_code == 200 + + # Should be blocked based on first IP (10.0.0.1) + response = 
client.get( + "/test", + headers={"X-Forwarded-For": "10.0.0.1, 192.168.1.1"}, + ) + assert response.status_code == 429 + + # Different first IP should work + response = client.get( + "/test", + headers={"X-Forwarded-For": "10.0.0.2, 192.168.1.1"}, + ) + assert response.status_code == 200 + + def test_rate_limiting_logs_warning_on_exceeded( + self, client_rate_limiting: TestClient, caplog: pytest.LogCaptureFixture + ): + """Test that rate limit exceeded events are logged as warnings.""" + # Exhaust limit + for _ in range(5): + client_rate_limiting.get("/test") + + with caplog.at_level(logging.WARNING): + client_rate_limiting.get("/test") + + assert any("Rate limit exceeded" in record.message for record in caplog.records) + + +class TestRateLimitingMiddlewareWindowRotation: + """Tests for rate limiting window rotation behavior.""" + + def test_window_rotation_allows_new_requests(self, app: FastAPI): + """Test that new window allows requests again.""" + add_rate_limiting_middleware(app, requests_per_minute=2) + client = TestClient(app) + + # Find the middleware instance + middleware_instance = None + current = app.middleware_stack + while hasattr(current, "app"): + if isinstance(current, RateLimitingMiddleware): + middleware_instance = current + break + current = current.app + + if middleware_instance is None: + pytest.skip("Could not access middleware instance") + + # Exhaust limit + client.get("/test") + client.get("/test") + response = client.get("/test") + assert response.status_code == 429 + + # Manually rotate window + middleware_instance.window = 0 + middleware_instance.counts.clear() + + # Should be allowed now + response = client.get("/test") + assert response.status_code == 200 + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + + +class TestAddMiddlewareFunctions: + """Tests for middleware helper functions.""" + + def 
test_add_request_logging_middleware(self): + """Test that add_request_logging_middleware adds the middleware.""" + app = FastAPI() + initial_middleware_count = len(app.user_middleware) + + add_request_logging_middleware(app) + + assert len(app.user_middleware) == initial_middleware_count + 1 + + def test_add_rate_limiting_middleware(self): + """Test that add_rate_limiting_middleware adds the middleware.""" + app = FastAPI() + initial_middleware_count = len(app.user_middleware) + + add_rate_limiting_middleware(app) + + assert len(app.user_middleware) == initial_middleware_count + 1 + + def test_add_rate_limiting_middleware_with_custom_limit(self): + """Test that custom rate limit is applied.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + add_rate_limiting_middleware(app, requests_per_minute=3) + client = TestClient(app) + + # Make 3 allowed requests + for _ in range(3): + response = client.get("/test") + assert response.status_code == 200 + + # 4th should be blocked + response = client.get("/test") + assert response.status_code == 429 + + def test_add_rate_limiting_middleware_default_limit(self): + """Test that default rate limit is 60 requests per minute.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + add_rate_limiting_middleware(app) + client = TestClient(app) + + # Should allow 60 requests + for _ in range(60): + response = client.get("/test") + assert response.status_code == 200 + + # 61st should be blocked + response = client.get("/test") + assert response.status_code == 429 + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestMiddlewareIntegration: + """Integration tests with both middlewares active.""" + + def test_both_middlewares_work_together(self, caplog: pytest.LogCaptureFixture): + """Test that logging and rate 
limiting work together.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + # Add rate limiting first, then logging (logging wraps rate limiting) + add_rate_limiting_middleware(app, requests_per_minute=2) + add_request_logging_middleware(app) + + client = TestClient(app) + + with caplog.at_level(logging.INFO): + # First two requests should succeed + for _ in range(2): + response = client.get("/test") + assert response.status_code == 200 + + # Third request should be rate limited + response = client.get("/test") + assert response.status_code == 429 + + # Verify logging happened for all requests + log_messages = [r.message for r in caplog.records] + assert sum("GET /test - 200" in msg for msg in log_messages) == 2 + assert sum("GET /test - 429" in msg for msg in log_messages) == 1 + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestEdgeCases: + """Edge case tests for middleware.""" + + def test_logging_handles_missing_client(self, app: FastAPI): + """Test logging middleware handles missing client info gracefully.""" + add_request_logging_middleware(app) + # TestClient always provides client info, so this is tested implicitly + + def test_rate_limiting_empty_x_forwarded_for(self, app: FastAPI): + """Test rate limiting with empty X-Forwarded-For header.""" + add_rate_limiting_middleware(app, requests_per_minute=2) + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + client = TestClient(app) + + # Empty X-Forwarded-For should fall back to client IP + response = client.get("/test", headers={"X-Forwarded-For": ""}) + assert response.status_code == 200 + + def test_rate_limiting_whitespace_only_x_forwarded_for(self, app: FastAPI): + """Test rate limiting with whitespace-only X-Forwarded-For header.""" + add_rate_limiting_middleware(app, requests_per_minute=2) + + 
@app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + client = TestClient(app) + + # Whitespace-only X-Forwarded-For should fall back to client IP + response = client.get("/test", headers={"X-Forwarded-For": " "}) + assert response.status_code == 200 diff --git a/tests/unit/core/attacks/test_traversal.py b/tests/unit/core/attacks/test_traversal.py index 7e99513..5329335 100644 --- a/tests/unit/core/attacks/test_traversal.py +++ b/tests/unit/core/attacks/test_traversal.py @@ -246,7 +246,8 @@ def test_analyze_path_structure(self, traversal_attacker): assert traversal_attacker._is_file_parameter("path") is True assert traversal_attacker._is_file_parameter("template") is True - def test_basic_payloads_and_execution( + @pytest.mark.asyncio + async def test_basic_payloads_and_execution( self, traversal_attacker, attack_context, safe_response ): """Minimal checks for payload generation and execution flow.""" @@ -260,10 +261,5 @@ def test_basic_payloads_and_execution( attack_context.http_client.post.return_value = safe_response # Run execute (async) to ensure integration path works - async def _run_execute(): - return await traversal_attacker.execute(attack_context) - - import asyncio - - vulnerabilities = asyncio.get_event_loop().run_until_complete(_run_execute()) + vulnerabilities = await traversal_attacker.execute(attack_context) assert isinstance(vulnerabilities, list) diff --git a/tests/unit/services/test_download_service.py b/tests/unit/services/test_download_service.py new file mode 100644 index 0000000..42096db --- /dev/null +++ b/tests/unit/services/test_download_service.py @@ -0,0 +1,284 @@ +""" +Tests for DownloadService class. 
+""" + +from __future__ import annotations + +import asyncio +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ciberwebscan.api.models.responses import DownloadTokenResponse +from ciberwebscan.config.loader import get_config +from ciberwebscan.services.download_service import DownloadService, _registry + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def service() -> DownloadService: + """Create a test service instance.""" + return DownloadService() + + +@pytest.fixture +def test_file() -> Path: + """Create a temporary JSON test file (delete=False; not auto-removed).""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + f.write('{"test": "data"}') + return Path(f.name) + + +@pytest.fixture(autouse=True) +def clear_registry(): + """Run the registry's expired-token cleanup before and after each test.""" + asyncio.run(_registry.cleanup_expired()) + yield + asyncio.run(_registry.cleanup_expired()) + + +# ============================================================================= +# Test: Generate Token +# ============================================================================= + + +class TestGenerateDownloadToken: + """Tests for token generation.""" + + def test_generate_token_success(self, service: DownloadService, test_file: Path): + """TEST 1: Generate valid token with correct metadata.""" + result = service.generate_download_token( + file_path=test_file, + user_id="test_user_123", + file_format="json", + ) + + assert result.success is True + assert result.data is not None + assert isinstance(result.data, DownloadTokenResponse) + assert result.data.token is not None + assert len(result.data.token) == 36 # UUID v4 length + assert result.data.download_url == f"/api/v1/download/{result.data.token}" + assert result.data.expires_at > datetime.now(timezone.utc) + + def 
test_generate_token_file_not_found(self, service: DownloadService): + """TEST 2: Reject non-existent file.""" + result = service.generate_download_token( + file_path="/nonexistent/file.json", + user_id="test_user", + file_format="json", + ) + + assert result.success is False + assert "not found" in result.error.lower() + + def test_generate_token_file_too_large(self, service: DownloadService): + """TEST 5: Reject file exceeding max size.""" + config = get_config() + max_size = config.download.max_file_size_mb + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: + # Create file 1 MiB larger than max_file_size_mb + f.write("x" * int((max_size + 1) * 1024 * 1024)) + large_file = Path(f.name) + + try: + result = service.generate_download_token( + file_path=large_file, + user_id="test_user", + file_format="json", + ) + + assert result.success is False + assert "exceeds limit" in result.error.lower() + finally: + large_file.unlink() + + +# ============================================================================= +# Test: Validate Download Request +# ============================================================================= + + +class TestValidateDownloadRequest: + """Tests for request validation.""" + + def test_validate_expired_token(self, service: DownloadService, test_file: Path): + """TEST 2: Reject expired token.""" + # Generate token + gen_result = service.generate_download_token( + file_path=test_file, + user_id="test_user", + file_format="json", + ) + token = gen_result.data.token + + # Manually expire the token by back-dating expires_at and re-storing it + # This is a hack for testing - normally the scheduler would handle this + async def expire_token(): + info = await _registry.get_info(token) + if info: + info.expires_at = datetime.now(timezone.utc) - timedelta(seconds=1) + await _registry.store(token, info, b"test") + + asyncio.run(expire_token()) + + # Validate should fail + result = service.validate_download_request(token, "test_user") + assert 
result.success is False + assert "expired" in result.error.lower() + + def test_validate_user_mismatch(self, service: DownloadService, test_file: Path): + """TEST 3: Reject mismatched user ID.""" + config = get_config() + original_require = config.download.require_same_user + + try: + # Enable user checking (original value restored in finally) + config.download.require_same_user = True + + # Generate token for one user + gen_result = service.generate_download_token( + file_path=test_file, + user_id="user_1", + file_format="json", + ) + token = gen_result.data.token + + # Try with different user + result = service.validate_download_request(token, "user_2") + assert result.success is False + assert "unauthorized" in result.error.lower() + finally: + config.download.require_same_user = original_require + + def test_validate_max_retries_exceeded( + self, service: DownloadService, test_file: Path + ): + """TEST 4: Reject request exceeding max retries.""" + config = get_config() + max_retries = config.download.max_retries + + # Generate token + gen_result = service.generate_download_token( + file_path=test_file, + user_id="test_user", + file_format="json", + ) + token = gen_result.data.token + + # Exhaust retries + for i in range(max_retries): + result = service.validate_download_request(token, "test_user") + if i < max_retries - 1: + assert result.success is True + else: + # Last allowed attempt still succeeds (same check as the if branch) + assert result.success is True + + # Next attempt should fail + result = service.validate_download_request(token, "test_user") + assert result.success is False + assert "exhausted" in result.error.lower() or "exceeded" in result.error.lower() + + def test_validate_token_not_found(self, service: DownloadService): + """Reject non-existent token.""" + result = service.validate_download_request("nonexistent-token", "test_user") + assert result.success is False + assert "not found" in result.error.lower() + + +# ============================================================================= +# Test: 
Cleanup +# ============================================================================= + + +class TestCleanupExpiredTokens: + """Tests for token cleanup.""" + + def test_cleanup_removes_expired(self, service: DownloadService, test_file: Path): + """TEST 6: Cleanup removes expired tokens.""" + # Generate two tokens + gen_result1 = service.generate_download_token( + file_path=test_file, + user_id="test_user", + file_format="json", + ) + token1 = gen_result1.data.token + + gen_result2 = service.generate_download_token( + file_path=test_file, + user_id="test_user", + file_format="json", + ) + token2 = gen_result2.data.token + + # Expire first token by back-dating expires_at and re-storing it + async def expire_first(): + info = await _registry.get_info(token1) + info.expires_at = datetime.now(timezone.utc) - timedelta(seconds=1) + await _registry.store(token1, info, b"test") + + asyncio.run(expire_first()) + + # Cleanup should report exactly 1 removed token (result.data) + result = service.cleanup_expired_tokens() + assert result.success is True + assert result.data == 1 + + # Verify token1 is gone + info1 = asyncio.run(_registry.get_info(token1)) + assert info1 is None + + # Verify token2 still exists + info2 = asyncio.run(_registry.get_info(token2)) + assert info2 is not None + + def test_cleanup_preserves_valid(self, service: DownloadService, test_file: Path): + """TEST 7: Cleanup preserves valid tokens.""" + # Generate a valid token + gen_result = service.generate_download_token( + file_path=test_file, + user_id="test_user", + file_format="json", + ) + token = gen_result.data.token + + # Cleanup should remove 0 tokens + result = service.cleanup_expired_tokens() + assert result.success is True + assert result.data == 0 + + # Verify token still exists + info = asyncio.run(_registry.get_info(token)) + assert info is not None + + +# ============================================================================= +# Test: UUID Uniqueness +# ============================================================================= + + +class 
TestTokenUniqueness: + """Tests for token uniqueness.""" + + def test_token_uuid_uniqueness(self, service: DownloadService, test_file: Path): + """TEST 8: 100 generated tokens are unique.""" + tokens = set() + + for i in range(100): + result = service.generate_download_token( + file_path=test_file, + user_id=f"user_{i}", + file_format="json", + ) + assert result.success is True + tokens.add(result.data.token) + + # All tokens should be unique + assert len(tokens) == 100 diff --git a/tests/unit/services/test_download_token_deletion.py b/tests/unit/services/test_download_token_deletion.py new file mode 100644 index 0000000..c1bcf28 --- /dev/null +++ b/tests/unit/services/test_download_token_deletion.py @@ -0,0 +1,134 @@ +"""Test token deletion functionality after download.""" + +from __future__ import annotations + +import asyncio +import os +import tempfile + +import pytest + +from ciberwebscan.services.download_service import DownloadService, _registry + + +class TestTokenDeletion: + """Test token deletion after successful download.""" + + @pytest.fixture + def temp_file(self): + """Yield a temporary JSON file path; the file is removed afterwards.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f: + f.write(b'{"test": "data"}') + temp_path = f.name + yield temp_path + if os.path.exists(temp_path): + os.unlink(temp_path) + + @pytest.fixture + def download_service(self): + """Create download service.""" + return DownloadService() + + @pytest.fixture(autouse=True) + def clear_registry(self): + """Run the registry's expired-token cleanup before and after each test.""" + asyncio.run(_registry.cleanup_expired()) + yield + asyncio.run(_registry.cleanup_expired()) + + def test_delete_token_removes_token(self, download_service, temp_file): + """Test that delete_token removes token from registry.""" + # Generate a token + result = download_service.generate_download_token( + file_path=temp_file, user_id="user123", file_format="json" + ) + assert result.success + token = result.data.token + + # Verify 
token exists + validate_result = download_service.validate_download_request( + token=token, user_id="user123" + ) + assert validate_result.success + + # Delete token + delete_result = download_service.delete_token(token) + assert delete_result.success + + # Verify token is gone (validation now reports "not found") + validate_after = download_service.validate_download_request( + token=token, user_id="user123" + ) + assert not validate_after.success + assert "not found" in validate_after.error.lower() + + def test_delete_nonexistent_token_returns_error(self, download_service): + """Test that deleting a non-existent token returns an error.""" + result = download_service.delete_token("nonexistent-token") + assert not result.success + assert "not found" in result.error.lower() + + def test_token_deleted_prevents_retry(self, download_service, temp_file): + """Test that a deleted token cannot be used for retries.""" + # Generate token + result = download_service.generate_download_token( + file_path=temp_file, user_id="user123", file_format="json" + ) + assert result.success + token = result.data.token + + # Delete token immediately + delete_result = download_service.delete_token(token) + assert delete_result.success + + # Try to use token - should fail + validate_result = download_service.validate_download_request( + token=token, user_id="user123" + ) + assert not validate_result.success + + def test_multiple_tokens_independent_deletion(self, download_service, temp_file): + """Test that deleting one token doesn't affect others.""" + # Generate two tokens + result1 = download_service.generate_download_token( + file_path=temp_file, user_id="user123", file_format="json" + ) + result2 = download_service.generate_download_token( + file_path=temp_file, user_id="user456", file_format="json" + ) + assert result1.success + assert result2.success + + token1 = result1.data.token + token2 = result2.data.token + + # Delete first token + delete_result = download_service.delete_token(token1) + assert delete_result.success + + # 
Verify first token is gone + validate1 = download_service.validate_download_request( + token=token1, user_id="user123" + ) + assert not validate1.success + + # Verify second token still exists + validate2 = download_service.validate_download_request( + token=token2, user_id="user456" + ) + assert validate2.success + + def test_delete_token_idempotent(self, download_service, temp_file): + """Test that deleting an already-deleted token fails (not idempotent).""" + # Generate and delete token + result = download_service.generate_download_token( + file_path=temp_file, user_id="user123", file_format="json" + ) + token = result.data.token + + delete1 = download_service.delete_token(token) + assert delete1.success + + # Try to delete again + delete2 = download_service.delete_token(token) + assert not delete2.success # Should fail as token doesn't exist