Skip to content

Commit 4e0dc2d

Browse files
refactor(file-utils): enhance file handling utilities for images and PDFs
- Breaking change: data URLs must now include a MIME type (e.g. `data:image/png;base64,...`) so that supported file formats can be detected. - Introduced a new `io_utils.py` module for determining MIME types and reading bytes from various sources. - Refactored `load_source` in `file_utils.py` to utilize the new `source_file_type` function for improved file type detection. - Updated `ImageSource` and `PdfSource` classes to leverage the `read_bytes` function for loading data. - Removed the deprecated `load_image` and `load_pdf` functions, streamlining the codebase. - Enhanced tests for `PdfSource` and `ImageSource` to validate loading from paths and data URLs, ensuring robust error handling.
1 parent 723742a commit 4e0dc2d

File tree

6 files changed

+183
-151
lines changed

6 files changed

+183
-151
lines changed

src/askui/utils/file_utils.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,49 @@
11
from pathlib import Path
22
from typing import Union
33

4-
from filetype import guess # type: ignore[import-untyped]
5-
from PIL import Image
4+
from PIL import Image as PILImage
65

76
from askui.utils.image_utils import ImageSource
7+
from askui.utils.io_utils import source_file_type
88
from askui.utils.pdf_utils import PdfSource
99

10+
# to avoid circular imports from image_utils and pdf_utils on read_bytes
1011
Source = Union[ImageSource, PdfSource]
1112

13+
ALLOWED_IMAGE_TYPES = [
14+
"image/png",
15+
"image/jpeg",
16+
"image/gif",
17+
"image/webp",
18+
]
1219

13-
def load_source(source: Union[str, Path, Image.Image]) -> Source:
14-
"""Load a source and return appropriate Source object based on file type."""
20+
PDF_TYPE = "application/pdf"
1521

16-
if isinstance(source, Image.Image):
17-
return ImageSource(source)
22+
ALLOWED_MIMETYPES = [PDF_TYPE] + ALLOWED_IMAGE_TYPES
1823

19-
filepath = Path(source)
20-
if not filepath.is_file():
21-
msg = f"No such file or directory: '{source}'"
22-
raise FileNotFoundError(msg)
2324

24-
kind = guess(str(filepath))
25-
if kind and kind.mime == "application/pdf":
26-
return PdfSource(source)
27-
if kind and kind.mime.startswith("image/"):
25+
def load_source(source: Union[str, Path, PILImage.Image]) -> Source:
26+
"""Load a source and return it as an ImageSource or PdfSource.
27+
28+
Args:
29+
source (Union[str, Path, PILImage.Image]): The source to load.
30+
31+
Returns:
32+
Source: The loaded source as an ImageSource or PdfSource.
33+
34+
Raises:
35+
ValueError: If the source is not a valid image or PDF file.
36+
"""
37+
if isinstance(source, PILImage.Image):
38+
return ImageSource(source)
39+
40+
file_type = source_file_type(source)
41+
if file_type in ALLOWED_IMAGE_TYPES:
2842
return ImageSource(source)
29-
msg = f"Unsupported file type: {filepath.suffix}"
43+
if file_type == PDF_TYPE:
44+
return PdfSource(source)
45+
msg = f"Unsupported file type: {file_type}"
3046
raise ValueError(msg)
3147

3248

33-
__all__ = ["load_source", "Source"]
49+
__all__ = ["Source", "load_source"]

src/askui/utils/image_utils.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,52 +11,20 @@
1111
from PIL import Image as PILImage
1212
from pydantic import ConfigDict, RootModel, field_validator
1313

14+
from askui.utils.io_utils import read_bytes
15+
1416
# Regex to capture any kind of valid base64 data url (with optional media type and ;base64)
1517
# e.g., data:image/png;base64,... or data:;base64,... or data:,... or just ,...
1618
_DATA_URL_GENERIC_RE = re.compile(r"^(?:data:)?[^,]*?,(.*)$", re.DOTALL)
1719

1820

19-
def load_image(source: Union[str, Path, Image.Image]) -> Image.Image:
20-
"""Load and validate an image from a PIL Image, a path, or any form of base64 data URL.
21-
22-
Args:
23-
source (Union[str, Path, Image.Image]): The image source to load from.
24-
Can be a PIL Image, file path (`str` or `pathlib.Path`), or data URL.
25-
26-
Returns:
27-
Image.Image: A valid PIL Image object.
28-
29-
Raises:
30-
ValueError: If the input is not a valid or recognizable image.
31-
"""
32-
if isinstance(source, Image.Image):
33-
return source
34-
35-
if isinstance(source, Path) or (not source.startswith(("data:", ","))):
36-
try:
37-
return Image.open(source)
38-
except (OSError, FileNotFoundError, UnidentifiedImageError) as e:
39-
error_msg = f"Could not open image from file path: {source}"
40-
raise ValueError(error_msg) from e
41-
42-
else:
43-
match = _DATA_URL_GENERIC_RE.match(source)
44-
if match:
45-
try:
46-
image_data = base64.b64decode(match.group(1))
47-
return Image.open(io.BytesIO(image_data))
48-
except (binascii.Error, UnidentifiedImageError):
49-
try:
50-
return Image.open(source)
51-
except (FileNotFoundError, UnidentifiedImageError) as e:
52-
error_msg = (
53-
f"Could not decode or identify image from input:"
54-
f"{source[:100]}{'...' if len(source) > 100 else ''}"
55-
)
56-
raise ValueError(error_msg) from e
57-
58-
error_msg = f"Unsupported image input type: {type(source)}"
59-
raise ValueError(error_msg)
21+
def _bytes_to_image(image_bytes: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL image.

    Args:
        image_bytes (bytes): The encoded image content (e.g. the bytes of a
            PNG or JPEG file).

    Returns:
        Image.Image: The image opened by PIL from an in-memory buffer.

    Raises:
        ValueError: If PIL cannot identify an image in the given bytes.
    """
    buffer = io.BytesIO(image_bytes)
    try:
        opened = Image.open(buffer)
    except (FileNotFoundError, UnidentifiedImageError) as exc:
        error_msg = "Could not identify image from bytes"
        raise ValueError(error_msg) from exc
    return opened
6028

6129

6230
def image_to_data_url(image: PILImage.Image) -> str:
@@ -391,8 +359,12 @@ def __init__(self, root: Img, **kwargs: dict[str, Any]) -> None:
391359

392360
@field_validator("root", mode="before")
393361
@classmethod
394-
def validate_root(cls, v: Any) -> PILImage.Image:
395-
return load_image(v)
362+
def validate_root(cls, v: Any) -> Image.Image:
363+
if isinstance(v, Image.Image):
364+
return v
365+
366+
image_bytes = read_bytes(v)
367+
return _bytes_to_image(image_bytes)
396368

397369
def to_data_url(self) -> str:
398370
"""Convert the image to a data URL.
@@ -422,7 +394,6 @@ def to_bytes(self) -> bytes:
422394

423395

424396
__all__ = [
425-
"load_image",
426397
"image_to_data_url",
427398
"data_url_to_image",
428399
"draw_point_on_image",

src/askui/utils/io_utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import base64
2+
import binascii
3+
import re
4+
from pathlib import Path
5+
from typing import Union
6+
7+
from filetype import guess # type: ignore[import-untyped]
8+
9+
# Matches a data URL that declares a MIME type: group 1 is the MIME type,
# group 2 is everything after the first comma (the payload).
_DATA_URL_WITH_MIMETYPE_RE = re.compile(r"^data:([^;,]+)[^,]*?,(.*)$", re.DOTALL)


def source_file_type(source: Union[str, Path]) -> str:
    """Determine the MIME type of a source.

    The source can be a file path or a data URL. For a data URL the MIME type
    declared in the URL itself is returned; the payload is not inspected. For
    anything else the source is converted to `str` and handed to
    `filetype.guess`.

    Args:
        source (Union[str, Path]): The source to determine the type of.
            Can be a file path (`str` or `pathlib.Path`) or a data URL.

    Returns:
        str: The MIME type of the source, or "unknown" if it cannot be
            determined. MIME types taken from a data URL are lower-cased.
    """
    # when source is a data url
    if isinstance(source, str) and source.startswith("data:"):
        match = _DATA_URL_WITH_MIMETYPE_RE.match(source)
        if match is None:
            # No MIME type declared (e.g. "data:;base64,...") or no payload
            # separator at all.
            return "unknown"
        # MIME type names are case-insensitive; normalise so callers can
        # compare against lowercase constants such as "image/png".
        return match.group(1).lower()

    kind = guess(str(source))
    if kind is not None and kind.mime is not None:
        return str(kind.mime)

    return "unknown"
36+
37+
38+
def read_bytes(source: Union[str, Path]) -> bytes:
    """Read the bytes of a source.

    The source can be a file path or a data URL. A string is treated as a
    data URL when it starts with "data:" or ","; every other string, and every
    `pathlib.Path`, is treated as a file path. Data URLs must declare a MIME
    type (`data:<mime-type>[;params],<base64 data>`).

    Args:
        source (Union[str, Path]): The source to read the bytes from.

    Returns:
        bytes: The content of the source as bytes.

    Raises:
        ValueError: If the path does not point to an existing file, if the
            data URL is malformed or lacks a MIME type, if its payload is not
            valid base64, or if `source` is neither a `str` nor a `Path`.
    """
    is_data_url = isinstance(source, str) and source.startswith(("data:", ","))

    # when source is a file path and not a data url
    if isinstance(source, Path) or (isinstance(source, str) and not is_data_url):
        filepath = Path(source)
        try:
            is_file = filepath.is_file()
        except OSError:
            # Depending on platform and Python version, is_file() may propagate
            # errors such as "file name too long" (e.g. when a raw base64
            # string is passed). Report those as a missing file as well.
            is_file = False
        if not is_file:
            err_msg = f"No such file or directory: '{source}'"
            raise ValueError(err_msg)

        return filepath.read_bytes()

    # when source is a data url
    if is_data_url:
        preview = f"{source[:100]}{'...' if len(source) > 100 else ''}"
        match = _DATA_URL_WITH_MIMETYPE_RE.match(source)
        if match is None:
            # Looks like a data URL but has no MIME type or no payload
            # separator; say so instead of blaming the Python type.
            error_msg = (
                "Invalid data URL, expected 'data:<mime-type>[;base64],<data>': "
                f"{preview}"
            )
            raise ValueError(error_msg)
        try:
            return base64.b64decode(match.group(2))
        except binascii.Error as e:
            error_msg = f"Could not decode base64 data from input: {preview}"
            raise ValueError(error_msg) from e

    msg = f"Unsupported source type: {type(source)}"
    raise ValueError(msg)
75+
76+
77+
# Both helpers are imported by sibling modules, so both are public API.
__all__ = ["read_bytes", "source_file_type"]

src/askui/utils/pdf_utils.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from pydantic import ConfigDict, RootModel, field_validator
55

6+
from askui.utils.io_utils import read_bytes
7+
68
Pdf = Union[str, Path]
79
"""Type of the input PDFs for `askui.VisionAgent.get()`, etc.
810
@@ -11,27 +13,6 @@
1113
"""
1214

1315

14-
def load_pdf(source: Union[str, Path]) -> bytes:
15-
"""Load a PDF from a path and return its bytes.
16-
17-
Args:
18-
source (Union[str, Path]): The PDF source to load from.
19-
20-
Returns:
21-
bytes: The PDF content as bytes.
22-
23-
Raises:
24-
FileNotFoundError: If the file is not found.
25-
ValueError: If the file is too large.
26-
"""
27-
filepath = Path(source)
28-
if not filepath.is_file():
29-
err_msg = f"No such file or directory: '{source}'"
30-
raise FileNotFoundError(err_msg)
31-
32-
return filepath.read_bytes()
33-
34-
3516
class PdfSource(RootModel):
3617
"""A class that represents a PDF source.
3718
It provides methods to convert it to different formats.
@@ -55,11 +36,10 @@ def __init__(self, root: Pdf, **kwargs: dict[str, Any]) -> None:
5536
@field_validator("root", mode="before")
5637
@classmethod
5738
def validate_root(cls, v: Any) -> bytes:
58-
return load_pdf(v)
39+
return read_bytes(v)
5940

6041

6142
__all__ = [
6243
"PdfSource",
6344
"Pdf",
64-
"load_pdf",
6545
]
Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,41 @@
1+
import base64
12
import pathlib
23

34
import pytest
45

5-
from askui.utils.pdf_utils import load_pdf
6+
from askui.utils.file_utils import PdfSource
67

78

89
class TestLoadPdf:
910
def test_load_pdf_from_path(self, path_fixtures_dummy_pdf: pathlib.Path) -> None:
1011
# Test loading from Path
11-
loaded = load_pdf(path_fixtures_dummy_pdf)
12-
assert isinstance(loaded, bytes)
13-
assert len(loaded) > 0
12+
loaded = PdfSource(path_fixtures_dummy_pdf)
13+
assert isinstance(loaded.root, bytes)
14+
assert len(loaded.root) > 0
1415

1516
# Test loading from str path
16-
loaded = load_pdf(str(path_fixtures_dummy_pdf))
17-
assert isinstance(loaded, bytes)
18-
assert len(loaded) > 0
17+
loaded = PdfSource(str(path_fixtures_dummy_pdf))
18+
assert isinstance(loaded.root, bytes)
19+
assert len(loaded.root) > 0
1920

2021
def test_load_pdf_nonexistent_file(self) -> None:
21-
with pytest.raises(FileNotFoundError):
22-
load_pdf("nonexistent_file.pdf")
22+
with pytest.raises(ValueError):
23+
PdfSource("nonexistent_file.pdf")
24+
25+
def test_pdf_source_from_data_url(
26+
self, path_fixtures_dummy_pdf: pathlib.Path
27+
) -> None:
28+
# Load test PDF and convert to base64
29+
with pathlib.Path.open(path_fixtures_dummy_pdf, "rb") as f:
30+
pdf_bytes = f.read()
31+
pdf_str = base64.b64encode(pdf_bytes).decode()
32+
33+
# Test different base64 formats
34+
formats = [
35+
f"data:application/pdf;base64,{pdf_str}",
36+
]
37+
38+
for fmt in formats:
39+
source = PdfSource(fmt)
40+
assert isinstance(source.root, bytes)
41+
assert len(source.root) > 0

0 commit comments

Comments
 (0)