diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 1489b7a..910c407 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -72,7 +72,13 @@ def __init__(self, name: bytes | None) -> None: ... def set_none(self) -> None: ... class FileProtocol(_FormProtocol, Protocol): - def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ... + def __init__( + self, + file_name: bytes | None, + field_name: bytes | None, + config: FileConfig, + content_type: bytes | None = None, + ) -> None: ... OnFieldCallback = Callable[[FieldProtocol], None] OnFileCallback = Callable[[FileProtocol], None] @@ -355,9 +361,17 @@ class File: field_name: The name of the form field that this file was uploaded with. This can be None, if, for example, the file was uploaded with Content-Type application/octet-stream. config: The configuration for this File. See above for valid configuration keys and their corresponding values. + content_type: The Content-Type of the uploaded file as specified in the multipart headers. This can be None + if no Content-Type header was provided in the multipart data. """ # noqa: E501 - def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None: + def __init__( + self, + file_name: bytes | None, + field_name: bytes | None = None, + config: FileConfig = {}, + content_type: bytes | None = None, + ) -> None: # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config @@ -369,6 +383,9 @@ def __init__(self, file_name: bytes | None, field_name: bytes | None = None, con self._field_name = field_name self._file_name = file_name + # Save the content type. + self._content_type = content_type + # Our actual file name is None by default, since, depending on our # config, we may not actually use the provided name. self._actual_file_name: bytes | None = None @@ -393,6 +410,13 @@ def file_name(self) -> bytes | None: """The file name given in the upload request.""" return self._file_name + @property + def content_type(self) -> bytes | None: + """The Content-Type of the uploaded file as declared in the multipart headers. + Returns None if no Content-Type header was provided. + """ + return self._content_type + @property def actual_file_name(self) -> bytes | None: """The file name that this file is saved as. Will be None if it's not @@ -571,7 +595,10 @@ def close(self) -> None: self._fileobj.close() def __repr__(self) -> str: - return f"{self.__class__.__name__}(file_name={self.file_name!r}, field_name={self.field_name!r})" + return ( + f"{self.__class__.__name__}(file_name={self.file_name!r}, " + f"field_name={self.field_name!r}, content_type={self.content_type!r})" + ) class BaseParser: @@ -1692,11 +1719,16 @@ def on_headers_finished() -> None: file_name = options.get(b"filename") # TODO: check for errors + # Get the Content-Type of the file, if provided. + file_content_type = headers.get(b"Content-Type") + # Create the proper class. if file_name is None: f_multi = FieldClass(field_name) else: - f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config)) + f_multi = FileClass( + file_name, field_name, config=cast("FileConfig", self.config), content_type=file_content_type + ) is_file = True # Parse the given Content-Transfer-Encoding to determine what diff --git a/tests/test_file.py b/tests/test_file.py index a2aa134..e6eb896 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -3,6 +3,37 @@ from python_multipart.multipart import File +def test_file_content_type() -> None: + """Test that content_type is properly stored and accessible.""" + # Test with content_type provided + file_with_ct = File(b"test.png", b"image", content_type=b"image/png") + assert file_with_ct.content_type == b"image/png" + assert file_with_ct.file_name == b"test.png" + assert file_with_ct.field_name == b"image" + + # Test without content_type (defaults to None) + file_without_ct = File(b"test.txt", b"document") + assert file_without_ct.content_type is None + assert file_without_ct.file_name == b"test.txt" + + # Test with explicit None content_type + file_explicit_none = File(b"test.txt", content_type=None) + assert file_explicit_none.content_type is None + + +def test_file_repr_with_content_type() -> None: + """Test that the repr includes content_type.""" + file_with_ct = File(b"test.png", b"image", content_type=b"image/png") + repr_str = repr(file_with_ct) + assert "content_type=b'image/png'" in repr_str + assert "file_name=b'test.png'" in repr_str + assert "field_name=b'image'" in repr_str + + file_without_ct = File(b"test.txt", b"doc") + repr_str = repr(file_without_ct) + assert "content_type=None" in repr_str + + def test_upload_dir_with_leading_slash_in_filename(tmp_path: Path) -> None: upload_dir = tmp_path / "upload" upload_dir.mkdir() diff --git a/tests/test_multipart.py b/tests/test_multipart.py index b137331..97c128c 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1193,6 +1193,57 @@ def on_file(f: FileProtocol) -> None: f.finalize() self.assert_file_data(files[0], b"Test") + def test_file_content_type_from_multipart(self) -> None: + """Test that Content-Type header from multipart is passed to File.""" + data = ( + b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.png"\r\n' + b"Content-Type: image/png\r\n\r\n" + b"Test\r\n----boundary--\r\n" + ) + + files: list[File] = [] + + def on_file(f: FileProtocol) -> None: + files.append(cast(File, f)) + + on_field = Mock() + on_end = Mock() + + f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary") + + f.write(data) + f.finalize() + + self.assertEqual(len(files), 1) + self.assertEqual(files[0].content_type, b"image/png") + self.assertEqual(files[0].file_name, b"test.png") + self.assertEqual(files[0].field_name, b"file") + self.assert_file_data(files[0], b"Test") + + def test_file_content_type_none_when_not_provided(self) -> None: + """Test that content_type is None when no Content-Type header in multipart.""" + data = ( + b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\n\r\n' + b"Test\r\n----boundary--\r\n" + ) + + files: list[File] = [] + + def on_file(f: FileProtocol) -> None: + files.append(cast(File, f)) + + on_field = Mock() + on_end = Mock() + + f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary") + + f.write(data) + f.finalize() + + self.assertEqual(len(files), 1) + self.assertIsNone(files[0].content_type) + self.assertEqual(files[0].file_name, b"test.txt") + def test_handles_None_fields(self) -> None: fields: list[Field] = []