Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,97 @@ Downloads a specified encrypted file, decrypts it and then behaves identically t
The request body for this route is the same as for
`POST /_matrix/media_proxy/unstable/download_encrypted`.

### `POST /_matrix/media_proxy/unstable/scan_file`

Scans a file directly without downloading it from a Matrix homeserver. The file
content is sent in the request body as a `multipart/form-data` upload.

#### Request

The request must use `Content-Type: multipart/form-data` with the following parts:

| Part name | Required | Type | Description |
|-----------|----------|------|-------------|
| `body` | **Yes** | Binary (file content) | The raw file to scan. |
| `file` | No | JSON string | Decryption metadata for an encrypted file. Follows the [`EncryptedFile`](https://spec.matrix.org/v1.2/client-server-api/#extensions-to-mroommessage-msgtypes) structure from the Matrix specification. Only needed when the file in `body` is encrypted. |

#### Request examples

Scan an unencrypted file with `curl`:

```bash
curl -X POST \
http://localhost:8080/_matrix/media_proxy/unstable/scan_file \
-F "body=@document.pdf;type=application/pdf"
```

Scan an encrypted file (provide decryption metadata via the `file` part):

```bash
curl -X POST \
http://localhost:8080/_matrix/media_proxy/unstable/scan_file \
-F "body=@encrypted_file.bin;type=application/octet-stream" \
-F "file={\"v\":\"v2\",\"key\":{...},\"iv\":\"...\",\"hashes\":{...}};type=application/json"
```

Scan a file with Python (`requests`):

```python
import requests

resp = requests.post(
"http://localhost:8080/_matrix/media_proxy/unstable/scan_file",
files={"body": ("image.png", open("image.png", "rb"), "image/png")},
)
print(resp.json()) # {"clean": true, "info": "File is clean"}
```

Scan an encrypted file with Python (`requests`), providing decryption metadata via the `file` part:

```python
import json
import requests

encrypted_file_metadata = {
"v": "v2",
"key": {
"alg": "A256CTR",
"ext": True,
"k": "base64-encoded-key",
"key_ops": ["encrypt", "decrypt"],
"kty": "oct",
},
"iv": "base64-encoded-iv",
"hashes": {
"sha256": "base64-encoded-hash",
},
}

resp = requests.post(
"http://localhost:8080/_matrix/media_proxy/unstable/scan_file",
files={
"body": ("encrypted.bin", open("encrypted.bin", "rb"), "application/octet-stream"),
"file": ("metadata.json", json.dumps(encrypted_file_metadata), "application/json"),
},
)
print(resp.json()) # {"clean": true, "info": "File is clean"}
```

#### Response

| Parameter | Type | Description |
|-----------|------|-------------|
| `clean` | bool | `true` if the file passed the scan, `false` otherwise. |
| `info` | str | Human-readable result description. |

Example response:

```json
{
"clean": false,
"info": "***VIRUS DETECTED***"
}
```

### `GET /_matrix/media_proxy/unstable/public_key`

Expand Down
50 changes: 49 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ humanfriendly = ">=10.0"
# Required for calculating cache keys deterministically. Type annotations aren't
# discoverable in versions older than 1.6.3.
canonicaljson = ">=1.6.3"
# Required for non-blocking file I/O.
aiofile = ">=3.8.0"
setuptools_rust = ">=1.3"

[tool.poetry.dev-dependencies]
Expand Down
1 change: 1 addition & 0 deletions src/matrix_content_scanner/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _build_app(self) -> web.Application:
[
web.get("/scan" + _MEDIA_PATH_REGEXP, scan_handler.handle_plain),
web.post("/scan_encrypted", scan_handler.handle_encrypted),
web.post("/scan_file", scan_handler.handle_file),
web.get(
"/download" + _MEDIA_PATH_REGEXP, download_handler.handle_plain
),
Expand Down
118 changes: 88 additions & 30 deletions src/matrix_content_scanner/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import logging
import os
import subprocess
import uuid
from asyncio import Future
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import attr
import magic
from aiofile import async_open
from cachetools import TTLCache
from canonicaljson import encode_canonical_json
from humanfriendly import format_size
Expand Down Expand Up @@ -320,6 +322,76 @@ async def _scan_file(

return media

async def scan_content(
    self, content: bytes, metadata: Optional[JsonDict] = None
) -> None:
    """Scan a file that was supplied directly as bytes.

    The bytes are handed to `_do_scan`, which decrypts them if metadata is
    provided, writes them to disk, scans the resulting file and removes it again.

    This method neither looks up nor stores anything in a result cache, and does
    not deduplicate concurrent requests for the same content.

    Args:
        content: The raw bytes of the file, which may still be encrypted.
        metadata: The metadata attached to the file (e.g. decryption key), or None
            if the file isn't encrypted.

    Raises:
        FileDirtyError if the scan returned a non-zero exit code.
    """
    return_code = await self._do_scan(content, metadata)

    # A failure is flagged as non-cacheable when the exit code is one of those the
    # scanner is configured to ignore.
    may_cache = return_code not in self._exit_codes_to_ignore

    # Any non-zero exit code means the file did not pass the scan.
    if return_code != 0:
        raise FileDirtyError(cacheable=may_cache)

async def _do_scan(
    self,
    content: bytes,
    metadata: Optional[JsonDict] = None,
    file_id: Optional[str] = None,
) -> int:
    """Core scan pipeline shared by all request paths.

    Handles: decrypt (if needed) → write to disk → mimetype check → scan → cleanup.

    Args:
        content: The raw file bytes (encrypted or plaintext).
        metadata: Decryption metadata, or None if the file is unencrypted.
        file_id: Identifier used as the temp filename on disk. If None (or empty),
            a random UUID is generated instead. Passing the media path
            (server_name/media_id) keeps that directory structure on disk.

    Returns:
        The exit code from the scan script (0 = clean).
    """
    # If the file is encrypted, we need to decrypt it before we can scan it.
    if metadata is not None:
        content = self._decrypt_file(content, metadata)

    # Write the (decrypted) file to disk so it can be inspected by path.
    file_path = await self._write_file_to_disk(
        file_id or str(uuid.uuid4()), content
    )

    try:
        # Check the file's MIME type to see if it's allowed.
        self._check_mimetype(file_path)
        # Run the scan script against the file.
        exit_code = await self._run_scan(file_path)
        logger.info("Scan has finished")
    finally:
        # Always remove the file, including when the MIME type check or the scan
        # raised, so that nothing is left behind on disk.
        self._remove_file_from_disk(file_path)

    return exit_code

def _remove_file_from_disk(self, file_path: str) -> None:
    """Runs the configured removal command against the given file.

    The removal command is split on whitespace and the file path is appended as
    its last argument. The command's exit status is not checked.

    Args:
        file_path: The full path of the file to remove.
    """
    logger.info("Removing file")
    removal_command_parts = self._removal_command.split()
    removal_command_parts.append(file_path)
    # NOTE(review): subprocess.run is synchronous, so this blocks the caller (and
    # the event loop when called from a coroutine) until the command exits.
    subprocess.run(removal_command_parts)

async def _scan_media(
self,
media: MediaDescription,
Expand All @@ -344,21 +416,7 @@ async def _scan_media(
FileDirtyError if the result of the scan said that the file is dirty, or if
the media path is malformed.
"""

# Decrypt the content if necessary.
media_content = media.content
if metadata is not None:
# If the file is encrypted, we need to decrypt it before we can scan it.
media_content = self._decrypt_file(media_content, metadata)

# Check the file's MIME type to see if it's allowed.
self._check_mimetype(media_content)

# Write the file to disk.
file_path = self._write_file_to_disk(media_path, media_content)

# Scan the file and see if the result is positive or negative.
exit_code = await self._run_scan(file_path)
exit_code = await self._do_scan(media.content, metadata, file_id=media_path)
result = exit_code == 0

# If the exit code isn't part of the ones we should ignore, cache the result.
Expand All @@ -369,13 +427,6 @@ async def _scan_media(
)
cacheable = False

# Delete the file now that we've scanned it.
logger.info("Scan has finished, removing file")
removal_command_parts = self._removal_command.split()
removal_command_parts.append(file_path)
subprocess.run(removal_command_parts)

# Raise an error if the result isn't clean.
if result is False:
raise FileDirtyError(cacheable=cacheable)

Expand Down Expand Up @@ -445,7 +496,7 @@ def _decrypt_file(self, body: bytes, metadata: JsonDict) -> bytes:
info=str(e),
)

def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
async def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
"""Writes the given content to disk. The final file name will be a concatenation
of `temp_directory` and the media's `server_name/media_id` path.

Expand Down Expand Up @@ -475,8 +526,16 @@ def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
# Create any directory we need.
os.makedirs(full_path.parent, exist_ok=True)

with open(full_path, "wb") as fp:
fp.write(body)
try:
async with async_open(full_path, "wb") as fp:
await fp.write(body if isinstance(body, bytes) else bytes(body))
except Exception:
# Delete the file if the write fails.
try:
os.unlink(full_path)
except OSError:
pass
raise

return str(full_path)

Expand Down Expand Up @@ -506,16 +565,15 @@ async def _run_scan(self, file_name: str) -> int:

return retcode

def _check_mimetype(self, media_content: bytes) -> None:
"""Detects the MIME type of the provided bytes, and checks that this type is allowed
def _check_mimetype(self, filepath: str) -> None:
"""Detects the MIME type of the provided file, and checks that this type is allowed
(if an allow list is provided in the configuration)
Args:
media_content: The file's content. If the file is encrypted, this is its
decrypted content.
filepath: The full file path.
Raises:
FileMimeTypeForbiddenError if one of the checks fail.
"""
detected_mimetype = magic.from_buffer(media_content, mime=True)
detected_mimetype = magic.from_file(filepath, mime=True)
logger.debug("Detected MIME type for file is %s", detected_mimetype)

# If there's an allow list for MIME types, check that the MIME type that's been
Expand Down
26 changes: 26 additions & 0 deletions src/matrix_content_scanner/servlets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,32 @@ async def get_media_metadata_from_request(
return media_path, metadata


async def get_media_metadata_from_filebody(
    file_body: JsonDict,
    crypto_handler: crypto.CryptoHandler,
) -> JsonDict:
    """Extracts, optionally decrypts, and validates encrypted file metadata from an
    already-parsed file body.

    Args:
        file_body: The parsed JSON content to extract the metadata from.
        crypto_handler: The crypto handler to use if we need to decrypt an Olm-encrypted
            body.

    Returns:
        The metadata extracted from `file_body`, after validation.

    Raises:
        ContentScannerRestError(400) if the metadata didn't pass schema validation.
    """
    # NOTE(review): _metadata_from_body may raise errors of its own (e.g. on a
    # malformed encrypted body) — confirm against its definition.
    metadata = _metadata_from_body(file_body, crypto_handler)

    validate_encrypted_file_metadata(metadata)

    # Unlike get_media_metadata_from_request, we intentionally skip extracting
    # the file URL from the metadata because the caller already has the media content.

    return metadata


def _metadata_from_body(
body: JsonDict, crypto_handler: crypto.CryptoHandler
) -> JsonDict:
Expand Down
Loading