diff --git a/requirements.txt b/requirements.txt index e1cb5b42..67bf34e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bitsandbytes==0.45.5 ExifRead==3.3.1 imagesize==1.4.1 pillow==11.2.1 +pillow-jxl-plugin~=1.3.4 pyparsing==3.2.1 PySide6==6.9.0 transformers==4.48.3 diff --git a/taggui/auto_captioning/auto_captioning_model.py b/taggui/auto_captioning/auto_captioning_model.py index 63d2a2a4..93e262a8 100644 --- a/taggui/auto_captioning/auto_captioning_model.py +++ b/taggui/auto_captioning/auto_captioning_model.py @@ -5,6 +5,7 @@ import numpy as np import torch +import pillow_jxl from PIL import Image as PilImage from PIL.ImageOps import exif_transpose from transformers import (AutoModelForVision2Seq, AutoProcessor, diff --git a/taggui/models/image_list_model.py b/taggui/models/image_list_model.py index 042249c8..806e0e66 100644 --- a/taggui/models/image_list_model.py +++ b/taggui/models/image_list_model.py @@ -12,28 +12,36 @@ import imagesize from PySide6.QtCore import (QAbstractListModel, QModelIndex, QMimeData, QPoint, QRect, QSize, Qt, QUrl, Signal, Slot) -from PySide6.QtGui import QIcon, QImageReader, QPixmap +from PySide6.QtGui import QIcon, QImage, QImageReader, QPixmap from PySide6.QtWidgets import QMessageBox +import pillow_jxl +from PIL import Image as pilimage # Import Pillow's Image class + from utils.image import Image, ImageMarking, Marking +from utils.jxlutil import get_jxl_size from utils.settings import DEFAULT_SETTINGS, settings from utils.utils import get_confirmation_dialog_reply, pluralize import utils.target_dimension as target_dimension UNDO_STACK_SIZE = 32 +def pil_to_qimage(pil_image): + """Convert PIL image to QImage properly""" + pil_image = pil_image.convert("RGBA") + data = pil_image.tobytes("raw", "RGBA") + qimage = QImage(data, pil_image.width, pil_image.height, QImage.Format_RGBA8888) + return qimage def get_file_paths(directory_path: Path) -> set[Path]: """ - Recursively get all file paths in a directory, including those in + Recursively get all file paths in a directory, including subdirectories. """ file_paths = set() - for path in directory_path.iterdir(): + for path in directory_path.rglob("*"): # Use rglob for recursive search if path.is_file(): file_paths.add(path) - elif path.is_dir(): - file_paths.update(get_file_paths(path)) return file_paths @@ -85,7 +93,7 @@ def mimeData(self, indexes): def rowCount(self, parent=None) -> int: return len(self.images) - def data(self, index, role=None) -> Image | str | QIcon | QSize: + def data(self, index: QModelIndex, role=None) -> Image | str | QIcon | QSize: image = self.images[index.row()] if role == Qt.ItemDataRole.UserRole: return image @@ -101,21 +109,37 @@ def data(self, index, role=None) -> Image | str | QIcon | QSize: # it. Otherwise, generate a thumbnail and save it to the image. if image.thumbnail: return image.thumbnail - image_reader = QImageReader(str(image.path)) - # Rotate the image based on the orientation tag. - image_reader.setAutoTransform(True) - if image.crop: - crop = image.crop - else: - crop = QRect(QPoint(0, 0), image_reader.size()) - if crop.height() > crop.width()*3: - # keep it reasonable, higher than 3x the width doesn't make sense - crop.setTop((crop.height() - crop.width()*3)//2) # center crop - crop.setHeight(crop.width()*3) - image_reader.setClipRect(crop) - pixmap = QPixmap.fromImageReader(image_reader).scaledToWidth( - self.image_list_image_width, - Qt.TransformationMode.SmoothTransformation) + crop = image.crop + try: + if image.path.suffix.lower() == ".jxl": + pil_image = pilimage.open(image.path) # Uses pillow-jxl + qimage = pil_to_qimage(pil_image) + if not crop: + crop = QRect(QPoint(0, 0), qimage.size()) + if crop.height() > crop.width()*3: + # keep it reasonable, higher than 3x the width doesn't make sense + crop.setTop((crop.height() - crop.width()*3)//2) # center crop + crop.setHeight(crop.width()*3) + + pixmap = QPixmap.fromImage(qimage).scaledToWidth( + self.image_list_image_width, + Qt.TransformationMode.SmoothTransformation) + else: + image_reader = QImageReader(str(image.path)) + # Rotate the image based on the orientation tag. + image_reader.setAutoTransform(True) + if not crop: + crop = QRect(QPoint(0, 0), image_reader.size()) + if crop.height() > crop.width()*3: + # keep it reasonable, higher than 3x the width doesn't make sense + crop.setTop((crop.height() - crop.width()*3)//2) # center crop + crop.setHeight(crop.width()*3) + image_reader.setClipRect(crop) + pixmap = QPixmap.fromImageReader(image_reader).scaledToWidth( + self.image_list_image_width, + Qt.TransformationMode.SmoothTransformation) + except Exception as e: + print(f"Error loading image {image.path}: {e}") thumbnail = QIcon(pixmap) image.thumbnail = thumbnail return thumbnail @@ -167,9 +191,10 @@ def load_directory(self, directory_path: Path): if path.suffix == '.json'} for image_path in image_paths: try: - dimensions = imagesize.get(image_path) - # Check the Exif orientation tag and rotate the dimensions if - # necessary. + if str(image_path).endswith('jxl'): + dimensions = get_jxl_size(image_path) + else: + dimensions = pilimage.open(image_path).size with open(image_path, 'rb') as image_file: try: exif_tags = exifread.process_file( diff --git a/taggui/utils/jxlutil.py b/taggui/utils/jxlutil.py new file mode 100644 index 00000000..5ea4c981 --- /dev/null +++ b/taggui/utils/jxlutil.py @@ -0,0 +1,184 @@ +# Modified from https://github.com/Fraetor/jxl_decode +# Added partial read support for up to 200x speedup +import os + +class JXLBitstream: + """ + A stream of bits with methods for easy handling. + """ + + def __init__(self, file, offset=0, offsets=[]) -> None: + self.shift = 0 + self.bitstream = [] + self.file = file + self.offset = offset + self.offsets = offsets + if self.offsets: + self.offset = self.offsets[0][1] + self.previous_data_len = 0 + self.index = 0 + self.file.seek(self.offset) + + def get_bits(self, length: int = 1) -> int: + if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_to_read_length = length + if self.shift < self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(0, length) + self.bitstream += self.file.read(self.partial_to_read_length) + else: + self.bitstream += self.file.read(length) + bitmask = 2**length - 1 + bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask + self.shift += length + return bits + + def partial_read(self, current_length, length): + self.previous_data_len += self.offsets[self.index][2] + to_read_length = self.previous_data_len - (self.shift + current_length) + self.bitstream += self.file.read(to_read_length) + current_length += to_read_length + self.partial_to_read_length -= to_read_length + self.index += 1 + self.file.seek(self.offsets[self.index][1]) + if self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(current_length, length) + + +def decode_codestream(file, offset=0, offsets=[]): + """ + Decodes the actual codestream. + JXL codestream specification: https://www.iso.org/standard/85066.html + """ + + # Convert codestream to int within an object to get some handy methods. + codestream = JXLBitstream(file, offset=offset, offsets=offsets) + + # Skip signature + codestream.get_bits(16) + + # SizeHeader + div8 = codestream.get_bits(1) + if div8: + height = 8 * (1 + codestream.get_bits(5)) + else: + distribution = codestream.get_bits(2) + match distribution: + case 0: + height = 1 + codestream.get_bits(9) + case 1: + height = 1 + codestream.get_bits(13) + case 2: + height = 1 + codestream.get_bits(18) + case 3: + height = 1 + codestream.get_bits(30) + ratio = codestream.get_bits(3) + if div8 and not ratio: + width = 8 * (1 + codestream.get_bits(5)) + elif not ratio: + distribution = codestream.get_bits(2) + match distribution: + case 0: + width = 1 + codestream.get_bits(9) + case 1: + width = 1 + codestream.get_bits(13) + case 2: + width = 1 + codestream.get_bits(18) + case 3: + width = 1 + codestream.get_bits(30) + else: + match ratio: + case 1: + width = height + case 2: + width = (height * 12) // 10 + case 3: + width = (height * 4) // 3 + case 4: + width = (height * 3) // 2 + case 5: + width = (height * 16) // 9 + case 6: + width = (height * 5) // 4 + case 7: + width = (height * 2) // 1 + return width, height + + +def decode_container(file): + """ + Parses the ISOBMFF container, extracts the codestream, and decodes it. + JXL container specification: http://www-internal/2022/18181-2 + """ + + def parse_box(file, file_start) -> dict: + file.seek(file_start) + LBox = int.from_bytes(file.read(4), "big") + XLBox = None + if 1 < LBox <= 8: + raise ValueError(f"Invalid LBox at byte {file_start}.") + if LBox == 1: + file.seek(file_start + 8) + XLBox = int.from_bytes(file.read(8), "big") + if XLBox <= 16: + raise ValueError(f"Invalid XLBox at byte {file_start}.") + if XLBox: + header_length = 16 + box_length = XLBox + else: + header_length = 8 + if LBox == 0: + box_length = os.fstat(file.fileno()).st_size - file_start + else: + box_length = LBox + file.seek(file_start + 4) + box_type = file.read(4) + file.seek(file_start) + return { + "length": box_length, + "type": box_type, + "offset": header_length, + } + + file.seek(0) + # Reject files missing required boxes. These two boxes are required to be at + # the start and contain no values, so we can manually check there presence. + # Signature box. (Redundant as has already been checked.) + if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"): + raise ValueError("Invalid signature box.") + # File Type box. + if file.read(20) != bytes.fromhex( + "00000014 66747970 6A786C20 00000000 6A786C20" + ): + raise ValueError("Invalid file type box.") + + offset = 0 + offsets = [] + data_offset_not_found = True + container_pointer = 32 + file_size = os.fstat(file.fileno()).st_size + while data_offset_not_found: + box = parse_box(file, container_pointer) + match box["type"]: + case b"jxlc": + offset = container_pointer + box["offset"] + data_offset_not_found = False + case b"jxlp": + file.seek(container_pointer + box["offset"]) + index = int.from_bytes(file.read(4), "big") + offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4]) + container_pointer += box["length"] + if container_pointer >= file_size: + data_offset_not_found = False + + if offsets: + offsets.sort(key=lambda i: i[0]) + file.seek(0) + + return decode_codestream(file, offset=offset, offsets=offsets) + + +def get_jxl_size(path): + with open(path, "rb") as file: + if file.read(2) == bytes.fromhex("FF0A"): + return decode_codestream(file) + return decode_container(file) \ No newline at end of file diff --git a/taggui/utils/settings.py b/taggui/utils/settings.py index 5358a035..979e3f85 100644 --- a/taggui/utils/settings.py +++ b/taggui/utils/settings.py @@ -3,8 +3,8 @@ # Defaults for settings that are accessed from multiple places. DEFAULT_SETTINGS = { 'font_size': 16, - # Common image formats that are supported in PySide6. - 'image_list_file_formats': 'bmp, gif, jpg, jpeg, png, tif, tiff, webp', + # Common image formats that are supported in PySide6, as well as JPEG XL. + 'image_list_file_formats': 'bmp, gif, jpg, jpeg, jxl, png, tif, tiff, webp', 'image_list_image_width': 200, 'tag_separator': ',', 'insert_space_after_tag_separator': True, diff --git a/taggui/widgets/image_viewer.py b/taggui/widgets/image_viewer.py index b3b156c5..aa3d3de0 100644 --- a/taggui/widgets/image_viewer.py +++ b/taggui/widgets/image_viewer.py @@ -2,13 +2,14 @@ from math import ceil, floor, sqrt from PySide6.QtCore import (QModelIndex, QPersistentModelIndex, QPoint, QPointF, QRect, QRectF, QSize, Qt, Signal, Slot) -from PySide6.QtGui import (QAction, QActionGroup, QColor, QIcon, +from PySide6.QtGui import (QAction, QActionGroup, QColor, QIcon, QImage, QPainter, QPainterPath, QPen, QPixmap, QTransform, QMouseEvent) from PySide6.QtWidgets import (QGraphicsItem, QGraphicsLineItem, QGraphicsPixmapItem, QGraphicsRectItem, QGraphicsTextItem, QGraphicsScene, QGraphicsView, QMenu, QVBoxLayout, QWidget) +from PIL import Image as pilimage from utils.settings import settings from models.proxy_image_list_model import ProxyImageListModel from utils.image import Image, ImageMarking, Marking @@ -672,7 +673,18 @@ def load_image(self, proxy_image_index: QModelIndex, is_complete = True): if is_complete: self.marking_items.clear() self.view.clear_scene() - pixmap = QPixmap(str(image.path)) + if image.path.suffix.lower() == ".jxl": + pil_image = pilimage.open(image.path) # Decode JXL using Pillow + pil_image = pil_image.convert("RGBA") # Ensure RGBA format + + pixmap = QPixmap(QImage( + pil_image.tobytes("raw", "RGBA"), + pil_image.width, + pil_image.height, + QImage.Format_RGBA8888 + )) + else: + pixmap = QPixmap(str(image.path)) image_item = QGraphicsPixmapItem(pixmap) image_item.setZValue(0) self.scene.setSceneRect(image_item.boundingRect()