Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def bootloader_shell(self):
pass
yield self.serial

def flash(
def flash( # noqa: C901
self,
path: PathBuf,
*,
Expand All @@ -80,10 +80,19 @@ def flash(
force_flash_bundle: str | None = None,
cacert_file: str | None = None,
insecure_tls: bool = False,
headers: dict[str, str] | None = None,
bearer_token: str | None = None,
):
if bearer_token:
bearer_token = self._validate_bearer_token(bearer_token)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
if headers:
headers = self._validate_header_dict(headers)

"""Flash image to DUT"""
should_download_to_httpd = True
image_url = ""
original_http_url = None
operator_scheme = None
# initrmafs cannot handle https yet, fallback to using the exporter's http server
if path.startswith(("http://", "https://")) and not force_exporter_http:
Expand All @@ -93,7 +102,17 @@ def flash(
else:
# use the exporter's http server for the flasher image, we should download it first
if operator is None:
path, operator, operator_scheme = operator_for_path(path)
if path.startswith(("http://", "https://")) and bearer_token:
parsed = urlparse(path)
self.logger.info(f"Using Bearer token authentication for {parsed.netloc}")
original_http_url = path
operator = Operator(
"http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token
)
operator_scheme = "http"
path = Path(parsed.path)
else:
path, operator, operator_scheme = operator_for_path(path)
image_url = self.http.get_url() + "/" + path.name

# start counting time for the flash operation
Expand All @@ -106,7 +125,16 @@ def flash(
# Start the storage write operation in the background
storage_thread = threading.Thread(
target=self._transfer_bg_thread,
args=(path, operator, operator_scheme, os_image_checksum, self.http.storage, error_queue, image_url),
args=(
path,
operator,
operator_scheme,
os_image_checksum,
self.http.storage,
error_queue,
original_http_url,
headers,
),
name="storage_transfer",
)
storage_thread.start()
Expand Down Expand Up @@ -151,9 +179,17 @@ def flash(
else:
stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file)


self._flash_with_progress(console, manifest, path, image_url, target_device,
insecure_tls, stored_cacert)
header_args = self._prepare_headers(headers, bearer_token)
self._flash_with_progress(
console,
manifest,
path,
image_url,
target_device,
insecure_tls,
stored_cacert,
header_args,
)

total_time = time.time() - start_time
# total time in minutes:seconds
Expand Down Expand Up @@ -221,7 +257,36 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str:
tls_args += f"--cacert {stored_cacert} "
return tls_args.strip()

def _flash_with_progress(self, console, manifest, path, image_url, target_path, insecure_tls, stored_cacert):
def _curl_header_args(self, headers: dict[str, str] | None) -> str:
"""Generate header arguments for curl command"""
if not headers:
return ""

parts: list[str] = []

def _sq(s: str) -> str:
return s.replace("'", "'\"'\"'")

for k, v in headers.items():
k = str(k).strip()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we prevent/fail if k/v have spaces or something incompatible?, just to avoid the command call crashing and then giving a hard to understand message to the user.

v = str(v).strip()
if not k:
continue
parts.append(f"-H '{_sq(k)}: {_sq(v)}'")

return " ".join(parts)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def _flash_with_progress(
self,
console,
manifest,
path,
image_url,
target_path,
insecure_tls,
stored_cacert,
header_args: str,
):
"""Flash image to target device with progress monitoring.

Args:
Expand All @@ -240,11 +305,11 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
tls_args = self._curl_tls_args(insecure_tls, stored_cacert)

# Check if the image URL is accessible using curl and the TLS arguments
self._check_url_access(console, prompt, image_url, tls_args)
self._check_url_access(console, prompt, image_url, tls_args, header_args)

# Flash the image, we run curl -> decompress -> dd in the background, so we can monitor dd's progress
flash_cmd = (
f'( curl -fsSL {tls_args} "{image_url}" | '
f'( curl -fsSL {tls_args} {header_args} "{image_url}" | '
f"{decompress_cmd} "
f"dd of={target_path} bs=64k iflag=fullblock oflag=direct) &"
)
Expand Down Expand Up @@ -286,7 +351,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
console.sendline("sync")
console.expect(prompt, timeout=EXPECT_TIMEOUT_SYNC)

def _check_url_access(self, console, prompt, image_url: str, tls_args: str):
def _check_url_access(self, console, prompt, image_url: str, tls_args: str, header_args: str):
"""Check if the image URL is accessible using curl.

Args:
Expand All @@ -298,7 +363,9 @@ def _check_url_access(self, console, prompt, image_url: str, tls_args: str):
Raises:
RuntimeError: If the URL is not accessible
"""
console.sendline(f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} "{image_url}"')
console.sendline(
f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} {header_args} "{image_url}"'
)
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
curl_output = console.before.decode(errors="ignore").strip()
console.sendline("echo $?")
Expand Down Expand Up @@ -357,6 +424,7 @@ def _transfer_bg_thread(
to_storage: OpendalClient,
error_queue,
original_url: str | None = None,
headers: dict[str, str] | None = None,
):
"""Transfer image to exporter storage in the background
Args:
Expand All @@ -366,6 +434,7 @@ def _transfer_bg_thread(
error_queue: Queue to put exceptions in if any
known_hash: Known hash of the image
original_url: Original URL for HTTP fallback
headers: HTTP headers for requests
"""
self.logger.info(f"Writing image to storage in the background: {src_path}")
try:
Expand All @@ -391,7 +460,9 @@ def _transfer_bg_thread(
self.logger.info(f"Uploading image to storage: {filename}")
to_storage.write_from_path(filename, src_path, src_operator)

metadata, metadata_json = self._create_metadata_and_json(src_operator, src_path, file_hash, original_url)
metadata, metadata_json = self._create_metadata_and_json(
src_operator, src_path, file_hash, original_url, headers
)
metadata_file = filename + ".metadata"
to_storage.write_bytes(metadata_file, metadata_json.encode(errors="ignore"))

Expand All @@ -414,7 +485,7 @@ def _sha256_file(self, src_operator, src_path) -> str:
return m.hexdigest()

def _create_metadata_and_json(
self, src_operator, src_path, file_hash=None, original_url=None
self, src_operator, src_path, file_hash=None, original_url=None, headers: dict[str, str] | None = None
Comment thread
coderabbitai[bot] marked this conversation as resolved.
) -> tuple[Metadata | None, str]:
"""Create a metadata json string from a metadata object"""
metadata = None
Expand All @@ -435,7 +506,10 @@ def _create_metadata_and_json(

if original_url and original_url.startswith(("http://", "https://")):
try:
response = requests.head(original_url)
if headers:
response = requests.head(original_url, headers=headers)
else:
response = requests.head(original_url)
Comment on lines +509 to +512
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot Sep 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider timeout and error handling for HEAD requests

The HEAD requests for metadata retrieval don't specify timeouts, which could cause the process to hang on unresponsive servers.

Add timeout and improve error handling:

                    if headers:
-                        response = requests.head(original_url, headers=headers)
+                        response = requests.head(original_url, headers=headers, timeout=30)
                    else:
-                        response = requests.head(original_url)
+                        response = requests.head(original_url, timeout=30)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if headers:
response = requests.head(original_url, headers=headers)
else:
response = requests.head(original_url)
if headers:
response = requests.head(original_url, headers=headers, timeout=30)
else:
response = requests.head(original_url, timeout=30)
🤖 Prompt for AI Agents
In packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py
around lines 488-491, the requests.head call has no timeout and lacks exception
handling; update both HEAD branches to pass a sensible timeout (e.g., timeout=5
or configurable constant), wrap the call in a try/except catching
requests.exceptions.RequestException, handle timeouts and connection errors by
logging a warning/error and returning a safe fallback (e.g., None or an error
response), and avoid raising unhandled exceptions; if HEAD fails consider a
controlled fallback to a GET with the same timeout.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to this patch.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mangelajo You're absolutely right! The timeout and error handling suggestion, while a good practice, is indeed outside the scope of this PR which is specifically focused on adding HTTP headers support. I apologize for the tangential suggestion.

The core functionality you're implementing—passing headers to the requests.head() calls—works correctly as written. Thanks for keeping the PR focused on its intended purpose.

🐰💭


✏️ Learnings added
Learnt from: mangelajo
PR: jumpstarter-dev/jumpstarter#610
File: packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py:488-491
Timestamp: 2025-09-15T08:18:48.549Z
Learning: In the jumpstarter project, code review suggestions should stay focused on the specific scope of the PR. Suggestions about general improvements like timeout handling or error handling that are unrelated to the core changes being made should be avoided, even if they apply to modified code lines.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


http_metadata = {}
if "content-length" in response.headers:
Expand Down Expand Up @@ -610,6 +684,71 @@ def manifest(self):
self._manifest = FlasherBundleManifestV1Alpha1.from_string(yaml_str)
return self._manifest

def _validate_header_dict(self, header_map: dict[str, str]) -> dict[str, str]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

This is probably nitpicking, but a suggestion,

header validation should probably happen inside the flash command, at the start (early failure) if headers provided and return jumpstarter exceptions.

Then in the click commands we can convert those exceptions to click.

This would cover the programatic use of flash commands with python.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

token_re = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$")
seen: set[str] = set()
for key, value in header_map.items():
key = key.strip()
value = value.strip()
if not key:
raise ArgumentError(f"Invalid header key: '{key}'")

if not token_re.match(key):
raise ArgumentError(f"Invalid header name '{key}': must be an HTTP token (RFC7230)")
if any(c in ("\r", "\n") for c in key) or any(c in ("\r", "\n") for c in value):
raise ArgumentError("Header names/values must not contain CR/LF")
kl = key.lower()
if kl in seen:
raise ArgumentError(f"Duplicate header '{key}'")
seen.add(kl)
return header_map

def _parse_headers(self, headers: list[str]) -> dict[str, str]:
header_map: dict[str, str] = {}
for h in headers:
if ":" not in h:
raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.")

key, value = h.split(":", 1)
header_map[key.strip()] = value.strip()

try:
return self._validate_header_dict(header_map)
except ArgumentError as e:
raise click.ClickException(str(e)) from e

def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str:
all_headers = headers.copy() if headers else {}
if bearer_token:
if any(k.lower() == "authorization" for k in all_headers.keys()):
self.logger.warning("Authorization header provided - ignoring bearer token")
else:
all_headers["Authorization"] = f"Bearer {bearer_token}"

if bearer_token and "Authorization" not in (headers or {}):
auth_header = {"Authorization": all_headers["Authorization"]}
self._validate_header_dict(auth_header)

return self._curl_header_args(all_headers)

def _validate_bearer_token(self, token: str | None) -> str | None:
if token is None:
return None

token = token.strip()
if not token:
raise click.ClickException("Bearer token cannot be empty")

# RFC 6750 allows token68 format (base64url-encoded) or other token formats
# Basic validation: printable ASCII excluding whitespace and special chars that could cause issues
if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token):
raise click.ClickException("Bearer token contains invalid characters")

if len(token) > 4096:
raise click.ClickException("Bearer token is too long (max 4096 characters)")

return token

Comment on lines +734 to +751
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use project exception (ArgumentError) instead of Click in driver layer.

Per prior guidance, driver code should raise project exceptions; the CLI should translate to Click. Switch _validate_bearer_token to ArgumentError.

-        if not token:
-            raise click.ClickException("Bearer token cannot be empty")
+        if not token:
+            raise ArgumentError("Bearer token cannot be empty")
@@
-        if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token):
-            raise click.ClickException("Bearer token contains invalid characters")
+        if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token):
+            raise ArgumentError("Bearer token contains invalid characters")
@@
-        if len(token) > 4096:
-            raise click.ClickException("Bearer token is too long (max 4096 characters)")
+        if len(token) > 4096:
+            raise ArgumentError("Bearer token is too long (max 4096 characters)")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _validate_bearer_token(self, token: str | None) -> str | None:
if token is None:
return None
token = token.strip()
if not token:
raise click.ClickException("Bearer token cannot be empty")
# RFC 6750 allows token68 format (base64url-encoded) or other token formats
# Basic validation: printable ASCII excluding whitespace and special chars that could cause issues
if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token):
raise click.ClickException("Bearer token contains invalid characters")
if len(token) > 4096:
raise click.ClickException("Bearer token is too long (max 4096 characters)")
return token
def _validate_bearer_token(self, token: str | None) -> str | None:
if token is None:
return None
token = token.strip()
if not token:
raise ArgumentError("Bearer token cannot be empty")
# RFC 6750 allows token68 format (base64url-encoded) or other token formats
# Basic validation: printable ASCII excluding whitespace and special chars that could cause issues
if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token):
raise ArgumentError("Bearer token contains invalid characters")
if len(token) > 4096:
raise ArgumentError("Bearer token is too long (max 4096 characters)")
return token
🤖 Prompt for AI Agents
In packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py
around lines 734 to 751, replace uses of click.ClickException with the project
ArgumentError exception: import ArgumentError from the project's exceptions
module (or the correct package path) and raise ArgumentError("message") for the
empty, invalid-character, and too-long token cases; keep all current validation
logic and messages unchanged and ensure the module imports ArgumentError at top
so the driver layer raises project-specific exceptions while the CLI layer
remains responsible for translating them to Click errors.

def cli(self):
@click.group
def base():
Expand All @@ -629,6 +768,17 @@ def base():
@click.option("--force-flash-bundle", type=str, help="Force use of a specific flasher OCI bundle")
@click.option("--cacert", type=click.Path(exists=True, dir_okay=False), help="CA certificate to use for HTTPS")
@click.option("--insecure-tls", is_flag=True, help="Skip TLS certificate verification")
@click.option(
"--header",
"header",
multiple=True,
help="Custom HTTP header in 'Key: Value' format",
)
@click.option(
"--bearer",
type=str,
help="Bearer token for HTTP authentication",
)
@debug_console_option
def flash(
file,
Expand All @@ -640,6 +790,8 @@ def flash(
force_flash_bundle,
cacert,
insecure_tls,
header,
bearer,
):
"""Flash image to DUT from file"""
if os_image_checksum_file and os.path.exists(os_image_checksum_file):
Expand All @@ -648,13 +800,18 @@ def flash(
self.logger.info(f"Read checksum from file: {os_image_checksum}")

self.set_console_debug(console_debug)

headers = self._parse_headers(header) if header else None

self.flash(
file,
partition=target,
force_exporter_http=force_exporter_http,
force_flash_bundle=force_flash_bundle,
cacert_file=cacert,
insecure_tls=insecure_tls,
headers=headers,
bearer_token=bearer,
)

@base.command()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import click
import pytest

from .client import BaseFlasherClient
from jumpstarter.common.exceptions import ArgumentError


class MockFlasherClient(BaseFlasherClient):
"""Mock client for testing without full initialization"""

def __init__(self):
self._manifest = None
self._console_debug = False
self.logger = type(
"MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None}
)()

def close(self):
pass


def test_validate_bearer_token_fails_invalid():
"""Test bearer token validation fails with invalid tokens"""
client = MockFlasherClient()

with pytest.raises(click.ClickException, match="Bearer token cannot be empty"):
client._validate_bearer_token("")

with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"):
client._validate_bearer_token("token with spaces")

with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"):
client._validate_bearer_token('token"with"quotes')


def test_curl_header_args_handles_quotes():
"""Test curl header formatting safely handles quotes"""
client = MockFlasherClient()

result = client._curl_header_args({"Authorization": "Bearer abc'def"})
assert "'\"'\"'" in result
assert result.startswith("-H '")
assert result.endswith("'")


def test_flash_fails_with_invalid_headers():
"""Test flash method fails early with invalid headers"""
client = MockFlasherClient()

with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"):
client.flash("test.raw", headers={"Invalid Header": "value"})
Loading