diff --git a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 43f26dad3..bd70c3b90 100644 --- a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -10,7 +10,7 @@ from concurrent.futures import CancelledError from contextlib import contextmanager from dataclasses import dataclass -from pathlib import Path, PosixPath +from pathlib import Path from queue import Queue from urllib.parse import urlparse @@ -18,7 +18,13 @@ import pexpect import requests from jumpstarter_driver_composite.client import CompositeClient -from jumpstarter_driver_opendal.client import FlasherClient, OpendalClient, operator_for_path +from jumpstarter_driver_opendal.client import ( + FlasherClient, + OpendalClient, + clean_filename, + operator_for_path, + path_with_query, +) from jumpstarter_driver_opendal.common import PathBuf from jumpstarter_driver_pyserial.client import Console from opendal import Metadata, Operator @@ -167,10 +173,10 @@ def flash( # noqa: C901 "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token ) operator_scheme = "http" - path = Path(parsed.path) + path = path_with_query(parsed) else: path, operator, operator_scheme = operator_for_path(path) - image_url = self.http.get_url() + "/" + path.name + image_url = self.http.get_url() + "/" + self._filename(path) # start counting time for the flash operation start_time = time.time() @@ -966,9 +972,9 @@ def _transfer_bg_thread( original_url: Original URL for HTTP fallback headers: HTTP headers for requests """ - self.logger.info(f"Writing image to storage in the background: {src_path}") + filename = self._filename(src_path) + self.logger.info(f"Writing image to storage in the background: {filename}") try: - filename = Path(src_path).name if isinstance(src_path, (str, os.PathLike)) else src_path.name if src_operator_scheme == "fs": file_hash = self._sha256_file(src_operator, src_path) @@ -1019,7 +1025,7 @@ def _create_metadata_and_json( ) -> tuple[Metadata | None, str]: """Create a metadata json string from a metadata object""" metadata = None - metadata_dict = {"path": str(src_path)} + metadata_dict = {"path": clean_filename(src_path)} try: metadata = src_operator.stat(src_path) @@ -1088,8 +1094,8 @@ def dump( raise NotImplementedError("Dump is not implemented for this driver yet") def _filename(self, path: PathBuf) -> str: - """Extract filename from url or path""" - if path.startswith("oci://"): + """Extract filename from url or path, stripping any query parameters""" + if isinstance(path, str) and path.startswith("oci://"): oci_path = path[6:] # Remove "oci://" prefix if ":" in oci_path: repository, tag = oci_path.rsplit(":", 1) @@ -1098,10 +1104,8 @@ def _filename(self, path: PathBuf) -> str: else: repo_name = oci_path.split("/")[-1] if "/" in oci_path else oci_path return repo_name - elif path.startswith(("http://", "https://")): - return urlparse(path).path.split("/")[-1] else: - return Path(path).name + return clean_filename(path) def _upload_artifact(self, storage, path: PathBuf, operator: Operator): """Upload artifact to storage""" @@ -1636,19 +1640,16 @@ def _get_decompression_command(filename_or_url) -> str: Determine the appropriate decompression command based on file extension Args: - filename (str): Name of the file to check + filename_or_url (str): Name of the file or URL to check Returns: - str: Decompression command ('zcat', 'xzcat', or 'cat' for uncompressed) + str: Decompression command ('zcat |', 'xzcat |', 'zstdcat |', or '' for uncompressed) """ - if type(filename_or_url) is PosixPath: - filename = filename_or_url.name - elif filename_or_url.startswith(("http://", "https://")): - filename = urlparse(filename_or_url).path.split("/")[-1] - - filename = filename.lower() + filename = clean_filename(filename_or_url).lower() if filename.endswith((".gz", ".gzip")): return "zcat |" elif filename.endswith(".xz"): return "xzcat |" + elif filename.endswith(".zst"): + return "zstdcat |" return "" diff --git a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py index 44f1214c4..d7a7e357c 100644 --- a/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py +++ b/python/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py @@ -242,19 +242,24 @@ def stop(self): def get_url(self): return "http://exporter" - client.http = DummyService() - client.tftp = DummyService() - client.call = lambda *args, **kwargs: None + client.http = DummyService() # ty: ignore[unresolved-attribute] + client.tftp = DummyService() # ty: ignore[unresolved-attribute] + client.call = lambda *args, **kwargs: None # ty: ignore[invalid-assignment] captured = {} - def capture_perform(*args): - captured["image_url"] = args[3] - captured["should_download_to_httpd"] = args[4] - captured["oci_username"] = args[14] - captured["oci_password"] = args[15] + def capture_perform( + partition, block_device, path, image_url, should_download_to_httpd, + storage_thread, error_queue, cacert_file, insecure_tls, headers, + bearer_token, method, fls_version, fls_binary_url, + oci_username, oci_password, power_off=True, + ): + captured["image_url"] = image_url + captured["should_download_to_httpd"] = should_download_to_httpd + captured["oci_username"] = oci_username + captured["oci_password"] = oci_password - client._perform_flash_operation = capture_perform + client._perform_flash_operation = capture_perform # ty: ignore[invalid-assignment] client.flash( "https://example.com/image.raw.xz", @@ -428,6 +433,159 @@ def test_categorize_exception_preserves_cause_for_wrapped_exceptions(): assert "File not found" in str(result) +def test_filename_strips_query_params_from_url_path(): + """Test _filename strips query parameters from paths with signed URL params""" + client = MockFlasherClient() + + # Full HTTP URL + assert client._filename("https://cdn.example.com/images/image.raw.xz") == "image.raw.xz" + + # Full HTTP URL with query parameters (e.g. CloudFront signed URL) + assert ( + client._filename("https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz") + == "image.raw.xz" + ) + + # Path string with query parameters (as returned by operator_for_path after fix) + assert client._filename("/images/image.raw.xz?Expires=123&Signature=abc") == "image.raw.xz" + + # Plain path without query parameters + assert client._filename("/images/image.raw.xz") == "image.raw.xz" + + # OCI path + assert client._filename("oci://quay.io/org/myimage:latest") == "myimage-latest" + + +def test_decompression_command_with_query_params(): + """Test _get_decompression_command handles paths with query parameters""" + from pathlib import PosixPath + + from .client import _get_decompression_command + + # Standard PosixPath + assert _get_decompression_command(PosixPath("/images/image.raw.xz")) == "xzcat |" + assert _get_decompression_command(PosixPath("/images/image.raw.gz")) == "zcat |" + assert _get_decompression_command(PosixPath("/images/image.raw")) == "" + + # Full HTTP URL + assert _get_decompression_command("https://cdn.example.com/images/image.raw.xz") == "xzcat |" + + # Zstandard compression + assert _get_decompression_command(PosixPath("/images/image.raw.zst")) == "zstdcat |" + assert _get_decompression_command("https://cdn.example.com/images/image.raw.zst") == "zstdcat |" + + # String path with query parameters (as returned by operator_for_path for signed URLs) + assert _get_decompression_command("/images/image.raw.xz?Expires=123&Signature=abc") == "xzcat |" + assert _get_decompression_command("/images/image.raw.gz?Expires=123") == "zcat |" + assert _get_decompression_command("/images/image.raw.zst?Expires=123") == "zstdcat |" + assert _get_decompression_command("/images/image.raw?Expires=123") == "" + + +def test_flash_signed_url_preserves_query_params(): + """Test that flash with a signed HTTP URL preserves query parameters for image_url""" + client = MockFlasherClient() + + class DummyService: + def __init__(self): + self.storage = object() + + def start(self): + pass + + def stop(self): + pass + + def get_url(self): + return "http://exporter" + + client.http = DummyService() # ty: ignore[unresolved-attribute] + client.tftp = DummyService() # ty: ignore[unresolved-attribute] + client.call = lambda *args, **kwargs: None # ty: ignore[invalid-assignment] + + captured = {} + + def capture_perform( + partition, block_device, path, image_url, should_download_to_httpd, + storage_thread, error_queue, cacert_file, insecure_tls, headers, + bearer_token, method, fls_version, fls_binary_url, + oci_username, oci_password, power_off=True, + ): + captured["image_url"] = image_url + captured["should_download_to_httpd"] = should_download_to_httpd + + client._perform_flash_operation = capture_perform # ty: ignore[invalid-assignment] + + # Direct HTTP URL with query params (no force_exporter_http) should preserve full URL + signed_url = "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + client.flash(signed_url, method="fls", fls_version="") + + assert captured["image_url"] == signed_url + assert captured["should_download_to_httpd"] is False + + +def test_flash_bearer_token_signed_url_preserves_query_params(): + """Test that flash with force_exporter_http=True and bearer token preserves query params. + + When a signed URL is used with a bearer token, the flash() method enters the + bearer token code path (lines 162-174 in client.py) which reconstructs the path + from parsed.path + '?' + parsed.query. This test verifies query params are preserved + and the path passed to the storage thread is correct. + """ + client = MockFlasherClient() + + class DummyService: + def __init__(self): + self.storage = object() + + def start(self): + pass + + def stop(self): + pass + + def get_url(self): + return "http://exporter" + + def get_host(self): + return "127.0.0.1" + + client.http = DummyService() # ty: ignore[unresolved-attribute] + client.tftp = DummyService() # ty: ignore[unresolved-attribute] + client.call = lambda *args, **kwargs: None # ty: ignore[invalid-assignment] + + captured = {} + + def capture_perform( + partition, block_device, path, image_url, should_download_to_httpd, + storage_thread, error_queue, cacert_file, insecure_tls, headers, + bearer_token, method, fls_version, fls_binary_url, + oci_username, oci_password, power_off=True, + ): + captured["path"] = path + captured["image_url"] = image_url + captured["should_download_to_httpd"] = should_download_to_httpd + + client._perform_flash_operation = capture_perform # ty: ignore[invalid-assignment] + # Mock the background transfer thread to prevent it from actually running + client._transfer_bg_thread = lambda *args, **kwargs: None # ty: ignore[invalid-assignment] + + signed_url = "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + client.flash( + signed_url, + force_exporter_http=True, + bearer_token="test-token-123", + method="fls", + fls_version="", + ) + + # With force_exporter_http=True and bearer_token, should download to httpd + assert captured["should_download_to_httpd"] is True + # The path should have query params preserved (reconstructed from parsed.path + '?' + parsed.query) + assert captured["path"] == "/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + # The image_url should point to the exporter with the clean filename (no query params) + assert captured["image_url"] == "http://exporter/image.raw.xz" + + def test_resolve_flash_parameters(): """Test flash parameter resolution for single file, partitions, and error cases""" client = MockFlasherClient() diff --git a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index eca6aa7d1..dea62551f 100644 --- a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -44,17 +44,42 @@ async def aclose(self): pass +def clean_filename(path: PathBuf) -> str: + """Extract a clean filename from a path or URL, stripping query parameters. + + Handles paths returned by operator_for_path() which may contain + query parameters for signed URLs (e.g. /path/to/image.raw.xz?Expires=...&Signature=...). + """ + path_str = str(path) + if path_str.startswith(("http://", "https://")): + return urlparse(path_str).path.split("/")[-1] + if "?" in path_str: + path_str = path_str.split("?", 1)[0] + return Path(path_str).name + + +def path_with_query(parsed_url) -> str: + """Reconstruct path preserving query parameters for signed URL support.""" + if parsed_url.query: + return f"{parsed_url.path}?{parsed_url.query}" + return parsed_url.path + + def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]: - """Create an operator for the given path + """Create an operator for the given path. + + For HTTP URLs, query parameters are preserved in the returned path so that + signed URLs (e.g. CloudFront with Expires/Signature/Key-Pair-Id) work correctly. + Return a tuple of: - - the path + - the path (str for HTTP, Path for filesystem) - the operator for the given path - - the scheme of the operator. + - the scheme of the operator """ if type(path) is str and path.startswith(("http://", "https://")): parsed_url = urlparse(path) operator = Operator("http", root="/", endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}") - return Path(parsed_url.path), operator, "http" + return path_with_query(parsed_url), operator, "http" else: return Path(path).resolve(), Operator("fs", root="/"), "fs" diff --git a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py index 2668b1760..bf13ac753 100644 --- a/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py +++ b/python/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py @@ -322,3 +322,66 @@ def test_copy_and_rename_tracking(tmp_path): assert "copied_dir" in created_paths assert "renamed_dir" in created_paths assert len(created_paths) == 4 + + +def test_clean_filename(): + """Test clean_filename extracts filenames and strips query parameters""" + from pathlib import PosixPath + + from .client import clean_filename + + # Plain filesystem path + assert clean_filename("/images/image.raw.xz") == "image.raw.xz" + assert clean_filename(PosixPath("/images/image.raw.xz")) == "image.raw.xz" + + # Filesystem path with query params (as returned by operator_for_path for signed URLs) + assert clean_filename("/images/image.raw.xz?Expires=123&Signature=abc") == "image.raw.xz" + + # Full HTTP URL without query params + assert clean_filename("https://cdn.example.com/images/image.raw.xz") == "image.raw.xz" + assert clean_filename("http://cdn.example.com/images/image.raw.xz") == "image.raw.xz" + + # Full HTTP URL with query params (e.g. CloudFront signed URL) + assert ( + clean_filename("https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz") + == "image.raw.xz" + ) + + # Edge case: no directory component + assert clean_filename("image.raw.xz") == "image.raw.xz" + assert clean_filename("image.raw.xz?Expires=123") == "image.raw.xz" + + # Edge case: compressed extensions + assert clean_filename("/path/to/image.raw.gz?token=abc") == "image.raw.gz" + assert clean_filename("/path/to/image.raw.gzip?token=abc") == "image.raw.gzip" + + # Edge case: query params with unencoded slashes (e.g. base64 signatures) + assert clean_filename("/images/image.raw.xz?Expires=123&Signature=abc/def/ghi") == "image.raw.xz" + + +def test_operator_for_path_preserves_query_params(): + """Test that operator_for_path preserves query parameters for HTTP URLs""" + from .client import operator_for_path + + # HTTP URL without query parameters + path, operator, scheme = operator_for_path("https://cdn.example.com/images/image.raw.xz") + assert scheme == "http" + assert path == "/images/image.raw.xz" + + # HTTP URL with query parameters (e.g. CloudFront signed URL) + path, operator, scheme = operator_for_path( + "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + ) + assert scheme == "http" + assert path == "/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz" + assert "Expires=123" in path + assert "Signature=abc" in path + assert "Key-Pair-Id=xyz" in path + + # Filesystem path (use resolve() for the expected value since macOS + # resolves /tmp to /private/tmp) + from pathlib import Path + + path, operator, scheme = operator_for_path("/tmp/image.raw.xz") + assert scheme == "fs" + assert path == Path("/tmp/image.raw.xz").resolve() diff --git a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py index 4bd10cfde..55bdfcaa0 100644 --- a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py +++ b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client.py @@ -5,7 +5,7 @@ import click from jumpstarter_driver_composite.client import CompositeClient -from jumpstarter_driver_opendal.client import FlasherClient, operator_for_path +from jumpstarter_driver_opendal.client import FlasherClient, clean_filename, operator_for_path from jumpstarter_driver_power.client import PowerClient from opendal import Operator @@ -39,7 +39,7 @@ def _upload_file_if_needed(self, file_path: str, operator: Operator | None = Non path_buf = Path(file_path) operator_scheme = "unknown" - filename = Path(path_buf).name + filename = clean_filename(path_buf) if self._should_upload_file(self.storage, filename, path_buf, operator, operator_scheme): if operator_scheme == "http": diff --git a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client_test.py b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client_test.py index 125f0b103..349f5150f 100644 --- a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client_test.py +++ b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/client_test.py @@ -161,3 +161,17 @@ def test_flash_no_target_no_partition_spec(ridesx_client): """Non-OCI path without colon or target should give a generic helpful error""" with pytest.raises(click.ClickException, match="requires a target partition"): ridesx_client.flash("/path/to/boot.img") + + +def test_upload_file_if_needed_strips_query_params(ridesx_client): + """Verify _upload_file_if_needed produces a clean filename for signed URLs""" + from jumpstarter_driver_opendal.client import clean_filename + + # Simulate the path_buf that would come from operator_for_path with a signed URL + path_with_query = "/images/image.raw.xz?Expires=123&Signature=abc/def&Key-Pair-Id=xyz" + result = clean_filename(path_with_query) + assert result == "image.raw.xz" + + # Also verify the direct path case + result = clean_filename("/images/image.raw.xz") + assert result == "image.raw.xz"