Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions python_multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ def __init__(self, name: bytes | None) -> None: ...
def set_none(self) -> None: ...

class FileProtocol(_FormProtocol, Protocol):
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ...
def __init__(
self,
file_name: bytes | None,
field_name: bytes | None,
config: FileConfig,
content_type: bytes | None = None,
) -> None: ...

OnFieldCallback = Callable[[FieldProtocol], None]
OnFileCallback = Callable[[FileProtocol], None]
Expand Down Expand Up @@ -355,9 +361,17 @@ class File:
field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
the file was uploaded with Content-Type application/octet-stream.
config: The configuration for this File. See above for valid configuration keys and their corresponding values.
content_type: The Content-Type of the uploaded file as specified in the multipart headers. This can be None
if no Content-Type header was provided in the multipart data.
""" # noqa: E501

def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
def __init__(
self,
file_name: bytes | None,
field_name: bytes | None = None,
config: FileConfig = {},
content_type: bytes | None = None,
) -> None:
# Save configuration, set other variables default.
self.logger = logging.getLogger(__name__)
self._config = config
Expand All @@ -369,6 +383,9 @@ def __init__(self, file_name: bytes | None, field_name: bytes | None = None, con
self._field_name = field_name
self._file_name = file_name

# Save the content type.
self._content_type = content_type

# Our actual file name is None by default, since, depending on our
# config, we may not actually use the provided name.
self._actual_file_name: bytes | None = None
Expand All @@ -393,6 +410,13 @@ def file_name(self) -> bytes | None:
"""The file name given in the upload request."""
return self._file_name

@property
def content_type(self) -> bytes | None:
"""The Content-Type of the uploaded file as declared in the multipart headers.
Returns None if no Content-Type header was provided.
"""
return self._content_type

@property
def actual_file_name(self) -> bytes | None:
"""The file name that this file is saved as. Will be None if it's not
Expand Down Expand Up @@ -571,7 +595,10 @@ def close(self) -> None:
self._fileobj.close()

def __repr__(self) -> str:
return f"{self.__class__.__name__}(file_name={self.file_name!r}, field_name={self.field_name!r})"
return (
f"{self.__class__.__name__}(file_name={self.file_name!r}, "
f"field_name={self.field_name!r}, content_type={self.content_type!r})"
)


class BaseParser:
Expand Down Expand Up @@ -1692,11 +1719,16 @@ def on_headers_finished() -> None:
file_name = options.get(b"filename")
# TODO: check for errors

# Get the Content-Type of the file, if provided.
file_content_type = headers.get(b"Content-Type")

# Create the proper class.
if file_name is None:
f_multi = FieldClass(field_name)
else:
f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
f_multi = FileClass(
file_name, field_name, config=cast("FileConfig", self.config), content_type=file_content_type
)
is_file = True

# Parse the given Content-Transfer-Encoding to determine what
Expand Down
31 changes: 31 additions & 0 deletions tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,37 @@
from python_multipart.multipart import File


def test_file_content_type() -> None:
"""Test that content_type is properly stored and accessible."""
# Test with content_type provided
file_with_ct = File(b"test.png", b"image", content_type=b"image/png")
assert file_with_ct.content_type == b"image/png"
assert file_with_ct.file_name == b"test.png"
assert file_with_ct.field_name == b"image"

# Test without content_type (defaults to None)
file_without_ct = File(b"test.txt", b"document")
assert file_without_ct.content_type is None
assert file_without_ct.file_name == b"test.txt"

# Test with explicit None content_type
file_explicit_none = File(b"test.txt", content_type=None)
assert file_explicit_none.content_type is None


def test_file_repr_with_content_type() -> None:
"""Test that the repr includes content_type."""
file_with_ct = File(b"test.png", b"image", content_type=b"image/png")
repr_str = repr(file_with_ct)
assert "content_type=b'image/png'" in repr_str
assert "file_name=b'test.png'" in repr_str
assert "field_name=b'image'" in repr_str

file_without_ct = File(b"test.txt", b"doc")
repr_str = repr(file_without_ct)
assert "content_type=None" in repr_str


def test_upload_dir_with_leading_slash_in_filename(tmp_path: Path) -> None:
upload_dir = tmp_path / "upload"
upload_dir.mkdir()
Expand Down
51 changes: 51 additions & 0 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,57 @@ def on_file(f: FileProtocol) -> None:
f.finalize()
self.assert_file_data(files[0], b"Test")

def test_file_content_type_from_multipart(self) -> None:
"""Test that Content-Type header from multipart is passed to File."""
data = (
b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.png"\r\n'
b"Content-Type: image/png\r\n\r\n"
b"Test\r\n----boundary--\r\n"
)

files: list[File] = []

def on_file(f: FileProtocol) -> None:
files.append(cast(File, f))

on_field = Mock()
on_end = Mock()

f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary")

f.write(data)
f.finalize()

self.assertEqual(len(files), 1)
self.assertEqual(files[0].content_type, b"image/png")
self.assertEqual(files[0].file_name, b"test.png")
self.assertEqual(files[0].field_name, b"file")
self.assert_file_data(files[0], b"Test")

def test_file_content_type_none_when_not_provided(self) -> None:
"""Test that content_type is None when no Content-Type header in multipart."""
data = (
b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\n\r\n'
b"Test\r\n----boundary--\r\n"
)

files: list[File] = []

def on_file(f: FileProtocol) -> None:
files.append(cast(File, f))

on_field = Mock()
on_end = Mock()

f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary")

f.write(data)
f.finalize()

self.assertEqual(len(files), 1)
self.assertIsNone(files[0].content_type)
self.assertEqual(files[0].file_name, b"test.txt")

def test_handles_None_fields(self) -> None:
fields: list[Field] = []

Expand Down