diff --git a/docs/api.md b/docs/api.md index e820059..a11d287 100644 --- a/docs/api.md +++ b/docs/api.md @@ -128,6 +128,97 @@ Downloads a specified encrypted file, decrypts it and then behaves identically t The request body for this route is the same as for `POST /_matrix/media_proxy/unstable/download_encrypted`. +### `POST /_matrix/media_proxy/unstable/scan_file` + +Scans a file directly without downloading it from a Matrix homeserver. The file +content is sent in the request body as a `multipart/form-data` upload. + +#### Request + +The request must use `Content-Type: multipart/form-data` with the following parts: + +| Part name | Required | Type | Description | +|-----------|----------|-a-----|-------------| +| `body` | **Yes** | Binary (file content) | The raw file to scan. | +| `file` | No | JSON string | Decryption metadata for an encrypted file. Follows the [`EncryptedFile`](https://spec.matrix.org/v1.2/client-server-api/#extensions-to-mroommessage-msgtypes) structure from the Matrix specification. Only needed when the file in `body` is encrypted. | + +#### Request examples + +Scan an unencrypted file with `curl`: + +```bash +curl -X POST \ + http://localhost:8080/_matrix/media_proxy/unstable/scan_file \ + -F "body=@document.pdf;type=application/pdf" +``` + +Scan an encrypted file (provide decryption metadata via the `file` part): + +```bash +curl -X POST \ + http://localhost:8080/_matrix/media_proxy/unstable/scan_file \ + -F "body=@encrypted_file.bin;type=application/octet-stream" \ + -F "file={\"v\":\"v2\",\"key\":{...},\"iv\":\"...\",\"hashes\":{...}};type=application/json" +``` + +Scan a file with Python (`requests`): + +```python +import requests + +resp = requests.post( + "http://localhost:8080/_matrix/media_proxy/unstable/scan_file", + files={"body": ("image.png", open("image.png", "rb"), "image/png")}, +) +print(resp.json()) # {"clean": true, "info": "File is clean"} +``` + +Scan an encrypted file with Python (`requests`), providing decryption metadata via the `file` part: + +```python +import json +import requests + +encrypted_file_metadata = { + "v": "v2", + "key": { + "alg": "A256CTR", + "ext": True, + "k": "base64-encoded-key", + "key_ops": ["encrypt", "decrypt"], + "kty": "oct", + }, + "iv": "base64-encoded-iv", + "hashes": { + "sha256": "base64-encoded-hash", + }, +} + +resp = requests.post( + "http://localhost:8080/_matrix/media_proxy/unstable/scan_file", + files={ + "body": ("encrypted.bin", open("encrypted.bin", "rb"), "application/octet-stream"), + "file": ("metadata.json", json.dumps(encrypted_file_metadata), "application/json"), + }, +) +print(resp.json()) # {"clean": true, "info": "File is clean"} +``` + +#### Response + +| Parameter | Type | Description | +|-----------|------|-------------| +| `clean` | bool | `true` if the file passed the scan, `false` otherwise. | +| `info` | str | Human-readable result description. | + +Example response: + +```json +{ + "clean": false, + "info": "***VIRUS DETECTED***" +} +``` ### `GET /_matrix/media_proxy/unstable/public_key` diff --git a/poetry.lock b/poetry.lock index a6f6d53..0fd3b7c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,19 @@ # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +[[package]] +name = "aiofile" +version = "3.9.0" +description = "Asynchronous file operations." +optional = false +python-versions = "<4,>=3.8" +files = [ + {file = "aiofile-3.9.0-py3-none-any.whl", hash = "sha256:ce2f6c1571538cbdfa0143b04e16b208ecb0e9cb4148e528af8a640ed51cc8aa"}, + {file = "aiofile-3.9.0.tar.gz", hash = "sha256:e5ad718bb148b265b6df1b3752c4d1d83024b93da9bd599df74b9d9ffcf7919b"}, +] + +[package.dependencies] +caio = ">=0.9.0,<0.10.0" + [[package]] name = "aiohappyeyeballs" version = "2.4.0" @@ -178,6 +192,40 @@ files = [ {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, ] +[[package]] +name = "caio" +version = "0.9.25" +description = "Asynchronous file IO for Linux MacOS or Windows." +optional = false +python-versions = ">=3.10" +files = [ + {file = "caio-0.9.25-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ca6c8ecda611478b6016cb94d23fd3eb7124852b985bdec7ecaad9f3116b9619"}, + {file = "caio-0.9.25-cp310-cp310-manylinux2010_x86_64.manylinux2014_x86_64.manylinux_2_12_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:db9b5681e4af8176159f0d6598e73b2279bb661e718c7ac23342c550bd78c241"}, + {file = "caio-0.9.25-cp310-cp310-manylinux_2_34_aarch64.whl", hash = "sha256:bf61d7d0c4fd10ffdd98ca47f7e8db4d7408e74649ffaf4bef40b029ada3c21b"}, + {file = "caio-0.9.25-cp310-cp310-manylinux_2_34_x86_64.whl", hash = "sha256:ab52e5b643f8bbd64a0605d9412796cd3464cb8ca88593b13e95a0f0b10508ae"}, + {file = "caio-0.9.25-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d6956d9e4a27021c8bd6c9677f3a59eb1d820cc32d0343cea7961a03b1371965"}, + {file = "caio-0.9.25-cp311-cp311-manylinux2010_x86_64.manylinux2014_x86_64.manylinux_2_12_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bf84bfa039f25ad91f4f52944452a5f6f405e8afab4d445450978cd6241d1478"}, + {file = "caio-0.9.25-cp311-cp311-manylinux_2_34_aarch64.whl", hash = "sha256:ae3d62587332bce600f861a8de6256b1014d6485cfd25d68c15caf1611dd1f7c"}, + {file = "caio-0.9.25-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:fc220b8533dcf0f238a6b1a4a937f92024c71e7b10b5a2dfc1c73604a25709bc"}, + {file = "caio-0.9.25-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fb7ff95af4c31ad3f03179149aab61097a71fd85e05f89b4786de0359dffd044"}, + {file = "caio-0.9.25-cp312-cp312-manylinux2010_x86_64.manylinux2014_x86_64.manylinux_2_12_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:97084e4e30dfa598449d874c4d8e0c8d5ea17d2f752ef5e48e150ff9d240cd64"}, + {file = "caio-0.9.25-cp312-cp312-manylinux_2_34_aarch64.whl", hash = "sha256:4fa69eba47e0f041b9d4f336e2ad40740681c43e686b18b191b6c5f4c5544bfb"}, + {file = "caio-0.9.25-cp312-cp312-manylinux_2_34_x86_64.whl", hash = "sha256:6bebf6f079f1341d19f7386db9b8b1f07e8cc15ae13bfdaff573371ba0575d69"}, + {file = "caio-0.9.25-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d6c2a3411af97762a2b03840c3cec2f7f728921ff8adda53d7ea2315a8563451"}, + {file = "caio-0.9.25-cp313-cp313-manylinux2010_x86_64.manylinux2014_x86_64.manylinux_2_12_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0998210a4d5cd5cb565b32ccfe4e53d67303f868a76f212e002a8554692870e6"}, + {file = "caio-0.9.25-cp313-cp313-manylinux_2_34_aarch64.whl", hash = "sha256:1a177d4777141b96f175fe2c37a3d96dec7911ed9ad5f02bac38aaa1c936611f"}, + {file = "caio-0.9.25-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:9ed3cfb28c0e99fec5e208c934e5c157d0866aa9c32aa4dc5e9b6034af6286b7"}, + {file = "caio-0.9.25-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:fab6078b9348e883c80a5e14b382e6ad6aabbc4429ca034e76e730cf464269db"}, + {file = "caio-0.9.25-cp314-cp314-manylinux2010_x86_64.manylinux2014_x86_64.manylinux_2_12_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:44a6b58e52d488c75cfaa5ecaa404b2b41cc965e6c417e03251e868ecd5b6d77"}, + {file = "caio-0.9.25-cp314-cp314-manylinux_2_34_aarch64.whl", hash = "sha256:628a630eb7fb22381dd8e3c8ab7f59e854b9c806639811fc3f4310c6bd711d79"}, + {file = "caio-0.9.25-cp314-cp314-manylinux_2_34_x86_64.whl", hash = "sha256:0ba16aa605ccb174665357fc729cf500679c2d94d5f1458a6f0d5ca48f2060a7"}, + {file = "caio-0.9.25-py3-none-any.whl", hash = "sha256:06c0bb02d6b929119b1cfbe1ca403c768b2013a369e2db46bfa2a5761cf82e40"}, + {file = "caio-0.9.25.tar.gz", hash = "sha256:16498e7f81d1d0f5a4c0ad3f2540e65fe25691376e0a5bd367f558067113ed10"}, +] + +[package.extras] +develop = ["aiomisc-pytest", "coveralls", "pylama[toml]", "pytest", "pytest-cov", "setuptools"] + [[package]] name = "canonicaljson" version = "2.0.0" @@ -955,4 +1003,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10.0" -content-hash = "6201c75d864f2c640ca78708a232e46320a326cc3bef8950cfcde56efe2f91e6" +content-hash = "2ed0bcdad855c181359b9c94295b6e18fe7e078247e74edc444e292f6633df43" diff --git a/pyproject.toml b/pyproject.toml index 1ee9068..e06faf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,8 @@ humanfriendly = ">=10.0" # Required for calculating cache keys deterministically. Type annotations aren't # discoverable in versions older than 1.6.3. canonicaljson = ">=1.6.3" +# Required for non-blocking file I/O. +aiofile = ">=3.8.0" setuptools_rust = ">=1.3" [tool.poetry.dev-dependencies] diff --git a/src/matrix_content_scanner/httpserver.py b/src/matrix_content_scanner/httpserver.py index d1189d2..311dde8 100644 --- a/src/matrix_content_scanner/httpserver.py +++ b/src/matrix_content_scanner/httpserver.py @@ -109,6 +109,7 @@ def _build_app(self) -> web.Application: [ web.get("/scan" + _MEDIA_PATH_REGEXP, scan_handler.handle_plain), web.post("/scan_encrypted", scan_handler.handle_encrypted), + web.post("/scan_file", scan_handler.handle_file), web.get( "/download" + _MEDIA_PATH_REGEXP, download_handler.handle_plain ), diff --git a/src/matrix_content_scanner/scanner/scanner.py b/src/matrix_content_scanner/scanner/scanner.py index 80a9f6b..7ad6a79 100644 --- a/src/matrix_content_scanner/scanner/scanner.py +++ b/src/matrix_content_scanner/scanner/scanner.py @@ -7,12 +7,14 @@ import logging import os import subprocess +import uuid from asyncio import Future from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr import magic +from aiofile import async_open from cachetools import TTLCache from canonicaljson import encode_canonical_json from humanfriendly import format_size @@ -320,6 +322,76 @@ async def _scan_file( return media + async def scan_content( + self, content: bytes, metadata: Optional[JsonDict] = None + ) -> None: + """Scan raw file bytes. The content is written to disk once (decrypted if + needed), scanned, and cleaned up. + + This does not use the result cache or concurrent-request deduplication. + + Args: + content: The raw file bytes (possibly still encrypted). + metadata: The metadata attached to the file (e.g. decryption key), or None + if the file isn't encrypted. + + Raises: + FileDirtyError if the result of the scan said that the file is dirty. + """ + exit_code = await self._do_scan(content, metadata) + result = exit_code == 0 + + cacheable = exit_code not in self._exit_codes_to_ignore + + if result is False: + raise FileDirtyError(cacheable=cacheable) + + async def _do_scan( + self, + content: bytes, + metadata: Optional[JsonDict] = None, + file_id: Optional[str] = None, + ) -> int: + """Core scan pipeline shared by all request paths. + + Handles: decrypt (if needed) → write to disk → mimetype check → scan → cleanup. + + Args: + content: The raw file bytes (encrypted or plaintext). + metadata: Decryption metadata, or None if the file is unencrypted. + file_id: Identifier used as the temp filename on disk. If None, a random + UUID is generated. Passing the media_path (server_name/media_id) + preserves the original directory structure for traceability. + + Returns: + The exit code from the scan script (0 = clean). + """ + # Decrypt the content if necessary. + if metadata is not None: + # If the file is encrypted, we need to decrypt it before we can scan it. + content = self._decrypt_file(content, metadata) + + # Write the file to disk. + file_path = await self._write_file_to_disk( + file_id or str(uuid.uuid4()), content + ) + + try: + # Check the file's MIME type to see if it's allowed. + self._check_mimetype(file_path) + # Scan the file and see if the result is positive or negative. + exit_code = await self._run_scan(file_path) + # Log the result of the scan. + logger.info("Scan has finished") + finally: + # This could be own function. + logger.info("Removing file") + removal_command_parts = self._removal_command.split() + removal_command_parts.append(file_path) + subprocess.run(removal_command_parts) + + return exit_code + async def _scan_media( self, media: MediaDescription, @@ -344,21 +416,7 @@ async def _scan_media( FileDirtyError if the result of the scan said that the file is dirty, or if the media path is malformed. """ - - # Decrypt the content if necessary. - media_content = media.content - if metadata is not None: - # If the file is encrypted, we need to decrypt it before we can scan it. - media_content = self._decrypt_file(media_content, metadata) - - # Check the file's MIME type to see if it's allowed. - self._check_mimetype(media_content) - - # Write the file to disk. - file_path = self._write_file_to_disk(media_path, media_content) - - # Scan the file and see if the result is positive or negative. - exit_code = await self._run_scan(file_path) + exit_code = await self._do_scan(media.content, metadata, file_id=media_path) result = exit_code == 0 # If the exit code isn't part of the ones we should ignore, cache the result. @@ -369,13 +427,6 @@ async def _scan_media( ) cacheable = False - # Delete the file now that we've scanned it. - logger.info("Scan has finished, removing file") - removal_command_parts = self._removal_command.split() - removal_command_parts.append(file_path) - subprocess.run(removal_command_parts) - - # Raise an error if the result isn't clean. if result is False: raise FileDirtyError(cacheable=cacheable) @@ -445,7 +496,7 @@ def _decrypt_file(self, body: bytes, metadata: JsonDict) -> bytes: info=str(e), ) - def _write_file_to_disk(self, media_path: str, body: bytes) -> str: + async def _write_file_to_disk(self, media_path: str, body: bytes) -> str: """Writes the given content to disk. The final file name will be a concatenation of `temp_directory` and the media's `server_name/media_id` path. @@ -475,8 +526,16 @@ def _write_file_to_disk(self, media_path: str, body: bytes) -> str: # Create any directory we need. os.makedirs(full_path.parent, exist_ok=True) - with open(full_path, "wb") as fp: - fp.write(body) + try: + async with async_open(full_path, "wb") as fp: + await fp.write(body if isinstance(body, bytes) else bytes(body)) + except Exception: + # Delete the file if the write fails. + try: + os.unlink(full_path) + except OSError: + pass + raise return str(full_path) @@ -506,16 +565,15 @@ async def _run_scan(self, file_name: str) -> int: return retcode - def _check_mimetype(self, media_content: bytes) -> None: - """Detects the MIME type of the provided bytes, and checks that this type is allowed + def _check_mimetype(self, filepath: str) -> None: + """Detects the MIME type of the provided file, and checks that this type is allowed (if an allow list is provided in the configuration) Args: - media_content: The file's content. If the file is encrypted, this is its - decrypted content. + filepath: The full file path. Raises: FileMimeTypeForbiddenError if one of the checks fail. """ - detected_mimetype = magic.from_buffer(media_content, mime=True) + detected_mimetype = magic.from_file(filepath, mime=True) logger.debug("Detected MIME type for file is %s", detected_mimetype) # If there's an allow list for MIME types, check that the MIME type that's been diff --git a/src/matrix_content_scanner/servlets/__init__.py b/src/matrix_content_scanner/servlets/__init__.py index 53e4ff7..2ed6168 100644 --- a/src/matrix_content_scanner/servlets/__init__.py +++ b/src/matrix_content_scanner/servlets/__init__.py @@ -180,6 +180,32 @@ async def get_media_metadata_from_request( return media_path, metadata +async def get_media_metadata_from_filebody( + file_body: JsonDict, + crypto_handler: crypto.CryptoHandler, +) -> JsonDict: + """Extracts, optionally decrypts, and validates encrypted file metadata from a + request body. + + Args: + request: The request to extract the data from. + crypto_handler: The crypto handler to use if we need to decrypt an Olm-encrypted + body. + + Raises: + ContentScannerRestError(400) if the request's body is None or if the metadata + didn't pass schema validation. + """ + metadata = _metadata_from_body(file_body, crypto_handler) + + validate_encrypted_file_metadata(metadata) + + # Unlike get_media_metadata_from_request, we intentionally skip extracting + # the file URL from the metadata because the caller already has the media content. + + return metadata + + def _metadata_from_body( body: JsonDict, crypto_handler: crypto.CryptoHandler ) -> JsonDict: diff --git a/src/matrix_content_scanner/servlets/scan.py b/src/matrix_content_scanner/servlets/scan.py index c83cc8a..44ae6ea 100644 --- a/src/matrix_content_scanner/servlets/scan.py +++ b/src/matrix_content_scanner/servlets/scan.py @@ -2,13 +2,19 @@ # # SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial # Please see LICENSE files in the repository root for full details. +import json from typing import TYPE_CHECKING, Optional, Tuple -from aiohttp import web +from aiohttp import BodyPartReader, web from multidict import MultiMapping -from matrix_content_scanner.servlets import get_media_metadata_from_request, web_handler -from matrix_content_scanner.utils.errors import FileDirtyError +from matrix_content_scanner.servlets import ( + get_media_metadata_from_filebody, + get_media_metadata_from_request, + web_handler, +) +from matrix_content_scanner.utils.constants import ErrCode +from matrix_content_scanner.utils.errors import ContentScannerRestError, FileDirtyError from matrix_content_scanner.utils.types import JsonDict if TYPE_CHECKING: @@ -60,3 +66,55 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: metadata, auth_header=request.headers.get("Authorization"), ) + + @web_handler + async def handle_file(self, request: web.Request) -> Tuple[int, JsonDict]: + """Handles GET requests to ../scan_file""" + try: + reader = await request.multipart() + except Exception: + raise ContentScannerRestError( + 400, + ErrCode.MALFORMED_MULTIPART, + "Request body was not a multipart body.", + ) + + body = None + metadata: Optional[JsonDict] = None + + # Iterate to find the fields. + while True: + field = await reader.next() + if (metadata and body) or field is None: + break + if not isinstance(field, BodyPartReader): + continue + if field.name == "file": + try: + file_json = await field.json() + if file_json is None: + raise Exception("'file' field is empty") + except json.decoder.JSONDecodeError as e: + raise ContentScannerRestError(400, ErrCode.MALFORMED_JSON, str(e)) + + metadata = await get_media_metadata_from_filebody( + file_json, self._crypto_handler + ) + elif field.name == "body": + body = await field.read() + + if body is None: + raise ContentScannerRestError( + 400, ErrCode.MALFORMED_MULTIPART, "Missing 'body' field" + ) + + # 'metadata' is optional + + try: + await self._scanner.scan_content(body, metadata) + except FileDirtyError as e: + res = {"clean": False, "info": e.info} + else: + res = {"clean": True, "info": "File is clean"} + + return 200, res diff --git a/src/matrix_content_scanner/utils/constants.py b/src/matrix_content_scanner/utils/constants.py index aac0b06..ab8931c 100644 --- a/src/matrix_content_scanner/utils/constants.py +++ b/src/matrix_content_scanner/utils/constants.py @@ -32,3 +32,5 @@ class ErrCode(str, Enum): MALFORMED_JSON = "MCS_MALFORMED_JSON" # The Mime type is not in the allowed list of Mime types. MIME_TYPE_FORBIDDEN = "MCS_MIME_TYPE_FORBIDDEN" + # The body was not a multipart. + MALFORMED_MULTIPART = "MCS_MALFORMED_MULTIPART"