From 9dd01677c7ddcdabac0a03273bfa191872e37d86 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Mon, 8 Sep 2025 16:14:40 +0300 Subject: [PATCH 1/6] flashers: support http headers Signed-off-by: Benny Zlotnik (cherry picked from commit e511e62146531b84aa91431dfdfc934dc92399b6) --- .../jumpstarter_driver_flashers/client.py | 105 ++++++++++++++++-- 1 file changed, 94 insertions(+), 11 deletions(-) diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 123e0c005..52aa3974c 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -81,6 +81,7 @@ def flash( force_flash_bundle: str | None = None, cacert_file: str | None = None, insecure_tls: bool = False, + headers: dict[str, str] | None = None, ): """Flash image to DUT""" should_download_to_httpd = True @@ -94,7 +95,15 @@ def flash( else: # use the exporter's http server for the flasher image, we should download it first if operator is None: - path, operator, operator_scheme = operator_for_path(path) + if path.startswith(("http://", "https://")) and headers: + parsed = urlparse(path) + operator = Operator( + "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", headers=headers + ) + operator_scheme = "http" + path = urlparse(path).path + else: + path, operator, operator_scheme = operator_for_path(path) image_url = self.http.get_url() + "/" + path.name # start counting time for the flash operation @@ -152,9 +161,17 @@ def flash( else: stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file) - - self._flash_with_progress(console, manifest, path, image_url, target_device, - insecure_tls, stored_cacert) + header_args = self._curl_header_args(headers) + self._flash_with_progress( + console, + manifest, + path, + image_url, + target_device, + insecure_tls, + stored_cacert, + header_args, + ) total_time = time.time() - start_time # total time in minutes:seconds @@ -222,7 +239,30 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str: tls_args += f"--cacert {stored_cacert} " return tls_args.strip() - def _flash_with_progress(self, console, manifest, path, image_url, target_path, insecure_tls, stored_cacert): + def _curl_header_args(self, headers: dict[str, str] | None) -> str: + """Generate header arguments for curl command.""" + if not headers: + return "" + parts: list[str] = [] + for k, v in headers.items(): + k = str(k).strip() + v = str(v) + if not k: + continue + parts.append(f"-H '{k}: {v}'") + return " ".join(parts) + + def _flash_with_progress( + self, + console, + manifest, + path, + image_url, + target_path, + insecure_tls, + stored_cacert, + header_args: str, + ): """Flash image to target device with progress monitoring. Args: @@ -241,11 +281,11 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path, tls_args = self._curl_tls_args(insecure_tls, stored_cacert) # Check if the image URL is accessible using curl and the TLS arguments - self._check_url_access(console, prompt, image_url, tls_args) + self._check_url_access(console, prompt, image_url, tls_args, header_args) # Flash the image, we run curl -> decompress -> dd in the background, so we can monitor dd's progress flash_cmd = ( - f'( curl -fsSL {tls_args} "{image_url}" | ' + f'( curl -fsSL {tls_args} {header_args} "{image_url}" | ' f"{decompress_cmd} " f"dd of={target_path} bs=64k iflag=fullblock oflag=direct) &" ) @@ -287,7 +327,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path, console.sendline("sync") console.expect(prompt, timeout=EXPECT_TIMEOUT_SYNC) - def _check_url_access(self, console, prompt, image_url: str, tls_args: str): + def _check_url_access(self, console, prompt, image_url: str, tls_args: str, header_args: str): """Check if the image URL is accessible using curl. Args: @@ -299,7 +339,9 @@ def _check_url_access(self, console, prompt, image_url: str, tls_args: str): Raises: RuntimeError: If the URL is not accessible """ - console.sendline(f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} "{image_url}"') + console.sendline( + f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} {header_args} "{image_url}"' + ) console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT) curl_output = console.before.decode(errors="ignore").strip() console.sendline("echo $?") @@ -415,7 +457,7 @@ def _sha256_file(self, src_operator, src_path) -> str: return m.hexdigest() def _create_metadata_and_json( - self, src_operator, src_path, file_hash=None, original_url=None + self, src_operator, src_path, file_hash=None, original_url=None, headers: dict[str, str] | None = None ) -> tuple[Metadata | None, str]: """Create a metadata json string from a metadata object""" metadata = None @@ -436,7 +478,10 @@ def _create_metadata_and_json( if original_url and original_url.startswith(("http://", "https://")): try: - response = requests.head(original_url) + if headers: + response = requests.head(original_url, headers=headers) + else: + response = requests.head(original_url) http_metadata = {} if "content-length" in response.headers: @@ -611,6 +656,33 @@ def manifest(self): self._manifest = FlasherBundleManifestV1Alpha1.from_string(yaml_str) return self._manifest + def _parse_headers(self, headers: list[str]) -> dict[str, str]: + """Parse header strings into a dict + + Args: + headers: List of header strings in 'Key: Value' format + + Returns: + Dictionary mapping header keys to values + + Raises: + click.ClickException: If header format is invalid + """ + header_map = {} + for h in headers: + if ":" not in h: + raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.") + + key, value = h.split(":", 1) + key = key.strip() + value = value.strip() + + if not key: + raise click.ClickException(f"Invalid header key in: {h!r}") + + header_map[key] = value + return header_map + def cli(self): @driver_click_group(self) def base(): @@ -630,6 +702,12 @@ def base(): @click.option("--force-flash-bundle", type=str, help="Force use of a specific flasher OCI bundle") @click.option("--cacert", type=click.Path(exists=True, dir_okay=False), help="CA certificate to use for HTTPS") @click.option("--insecure-tls", is_flag=True, help="Skip TLS certificate verification") + @click.option( + "--header", + "header", + multiple=True, + help="Custom HTTP header in 'Key: Value' format", + ) @debug_console_option def flash( file, @@ -641,6 +719,7 @@ def flash( force_flash_bundle, cacert, insecure_tls, + header, ): """Flash image to DUT from file""" if os_image_checksum_file and os.path.exists(os_image_checksum_file): @@ -649,6 +728,9 @@ def flash( self.logger.info(f"Read checksum from file: {os_image_checksum}") self.set_console_debug(console_debug) + + headers = self._parse_headers(header) if header else None + self.flash( file, partition=target, @@ -656,6 +738,7 @@ def flash( force_flash_bundle=force_flash_bundle, cacert_file=cacert, insecure_tls=insecure_tls, + headers=headers, ) @base.command() From 05981d8eb38a1706feb755a5350f86332cff9667 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sun, 14 Sep 2025 19:33:53 +0300 Subject: [PATCH 2/6] add --bearer to pass token Signed-off-by: Benny Zlotnik (cherry picked from commit f165ecffa477c5039f30db40ac3b03aab014abc1) --- .../jumpstarter_driver_flashers/client.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 52aa3974c..9bfbf5385 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -82,6 +82,7 @@ def flash( cacert_file: str | None = None, insecure_tls: bool = False, headers: dict[str, str] | None = None, + bearer_token: str | None = None, ): """Flash image to DUT""" should_download_to_httpd = True @@ -95,13 +96,17 @@ def flash( else: # use the exporter's http server for the flasher image, we should download it first if operator is None: - if path.startswith(("http://", "https://")) and headers: + if path.startswith(("http://", "https://")) and bearer_token: parsed = urlparse(path) + self.logger.info(f"Using Bearer token authentication for {parsed.netloc}") operator = Operator( - "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", headers=headers + "http", + root="/", + endpoint=f"{parsed.scheme}://{parsed.netloc}", + token=bearer_token ) operator_scheme = "http" - path = urlparse(path).path + path = Path(urlparse(path).path) else: path, operator, operator_scheme = operator_for_path(path) image_url = self.http.get_url() + "/" + path.name @@ -161,7 +166,10 @@ def flash( else: stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file) - header_args = self._curl_header_args(headers) + all_headers = headers.copy() if headers else {} + if bearer_token: + all_headers["Authorization"] = f"Bearer {bearer_token}" + header_args = self._curl_header_args(all_headers) self._flash_with_progress( console, manifest, @@ -240,7 +248,7 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str: return tls_args.strip() def _curl_header_args(self, headers: dict[str, str] | None) -> str: - """Generate header arguments for curl command.""" + """Generate header arguments for curl command""" if not headers: return "" parts: list[str] = [] @@ -708,6 +716,11 @@ def base(): multiple=True, help="Custom HTTP header in 'Key: Value' format", ) + @click.option( + "--bearer", + type=str, + help="Bearer token for HTTP authentication", + ) @debug_console_option def flash( file, @@ -720,6 +733,7 @@ def flash( cacert, insecure_tls, header, + bearer, ): """Flash image to DUT from file""" if os_image_checksum_file and os.path.exists(os_image_checksum_file): @@ -739,6 +753,7 @@ def flash( cacert_file=cacert, insecure_tls=insecure_tls, headers=headers, + bearer_token=bearer, ) @base.command() From 41cf6cfde1f2da0aa98f428771b100eda531d47b Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Tue, 16 Sep 2025 13:33:42 +0300 Subject: [PATCH 3/6] flashers: validate bearer token Signed-off-by: Benny Zlotnik (cherry picked from commit b0ca00e25ddd93e269ad11f1fb83d0931c119307) --- .../jumpstarter_driver_flashers/client.py | 51 ++++++++++++++----- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 9bfbf5385..f092442eb 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -84,6 +84,9 @@ def flash( headers: dict[str, str] | None = None, bearer_token: str | None = None, ): + if bearer_token: + bearer_token = self._validate_bearer_token(bearer_token) + """Flash image to DUT""" should_download_to_httpd = True image_url = "" @@ -100,10 +103,7 @@ def flash( parsed = urlparse(path) self.logger.info(f"Using Bearer token authentication for {parsed.netloc}") operator = Operator( - "http", - root="/", - endpoint=f"{parsed.scheme}://{parsed.netloc}", - token=bearer_token + "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token ) operator_scheme = "http" path = Path(urlparse(path).path) @@ -166,10 +166,7 @@ def flash( else: stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file) - all_headers = headers.copy() if headers else {} - if bearer_token: - all_headers["Authorization"] = f"Bearer {bearer_token}" - header_args = self._curl_header_args(all_headers) + header_args = self._prepare_headers(headers, bearer_token) self._flash_with_progress( console, manifest, @@ -247,17 +244,29 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str: tls_args += f"--cacert {stored_cacert} " return tls_args.strip() + def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str: + all_headers = headers.copy() if headers else {} + if bearer_token: + all_headers["Authorization"] = f"Bearer {bearer_token}" + return self._curl_header_args(all_headers) + def _curl_header_args(self, headers: dict[str, str] | None) -> str: """Generate header arguments for curl command""" if not headers: return "" + parts: list[str] = [] + + def _sq(s: str) -> str: + return s.replace("'", "'\"'\"'") + for k, v in headers.items(): k = str(k).strip() - v = str(v) + v = str(v).strip() if not k: continue - parts.append(f"-H '{k}: {v}'") + parts.append(f"-H '{_sq(k)}: {_sq(v)}'") + return " ".join(parts) def _flash_with_progress( @@ -666,13 +675,10 @@ def manifest(self): def _parse_headers(self, headers: list[str]) -> dict[str, str]: """Parse header strings into a dict - Args: headers: List of header strings in 'Key: Value' format - Returns: Dictionary mapping header keys to values - Raises: click.ClickException: If header format is invalid """ @@ -800,3 +806,22 @@ def _get_decompression_command(filename_or_url) -> str: elif filename.endswith(".xz"): return "xzcat |" return "" + + +def _validate_bearer_token(self, token: str | None) -> str | None: + if token is None: + return None + + token = token.strip() + if not token: + raise click.ClickException("Bearer token cannot be empty") + + # RFC 6750 allows token68 format (base64url-encoded) or other token formats + # Basic validation: printable ASCII excluding whitespace and special chars that could cause issues + if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token): + raise click.ClickException("Bearer token contains invalid characters") + + if len(token) > 4096: + raise click.ClickException("Bearer token is too long (max 4096 characters)") + + return token From 2e07d6a094da48affae66460b9b5548099ba78ff Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Tue, 16 Sep 2025 14:33:09 +0300 Subject: [PATCH 4/6] flashers: add tests Signed-off-by: Benny Zlotnik (cherry picked from commit 2e92d93b0ef80c09727b7fc77996eae04b08ae08) --- .../jumpstarter_driver_flashers/client.py | 77 +++++++++++++------ .../client_test.py | 39 ++++++++++ 2 files changed, 92 insertions(+), 24 deletions(-) create mode 100644 packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index f092442eb..2e2fce544 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -90,6 +90,7 @@ def flash( """Flash image to DUT""" should_download_to_httpd = True image_url = "" + original_http_url = None operator_scheme = None # initrmafs cannot handle https yet, fallback to using the exporter's http server if path.startswith(("http://", "https://")) and not force_exporter_http: @@ -102,11 +103,12 @@ def flash( if path.startswith(("http://", "https://")) and bearer_token: parsed = urlparse(path) self.logger.info(f"Using Bearer token authentication for {parsed.netloc}") + original_http_url = path operator = Operator( "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token ) operator_scheme = "http" - path = Path(urlparse(path).path) + path = Path(parsed.path) else: path, operator, operator_scheme = operator_for_path(path) image_url = self.http.get_url() + "/" + path.name @@ -121,7 +123,16 @@ def flash( # Start the storage write operation in the background storage_thread = threading.Thread( target=self._transfer_bg_thread, - args=(path, operator, operator_scheme, os_image_checksum, self.http.storage, error_queue, image_url), + args=( + path, + operator, + operator_scheme, + os_image_checksum, + self.http.storage, + error_queue, + original_http_url, + headers, + ), name="storage_transfer", ) storage_thread.start() @@ -247,7 +258,11 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str: def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str: all_headers = headers.copy() if headers else {} if bearer_token: - all_headers["Authorization"] = f"Bearer {bearer_token}" + if any(k.lower() == "authorization" for k in all_headers.keys()): + self.logger.warning("Authorization header provided - ignoring bearer token") + else: + all_headers["Authorization"] = f"Bearer {bearer_token}" + return self._curl_header_args(all_headers) def _curl_header_args(self, headers: dict[str, str] | None) -> str: @@ -417,6 +432,7 @@ def _transfer_bg_thread( to_storage: OpendalClient, error_queue, original_url: str | None = None, + headers: dict[str, str] | None = None, ): """Transfer image to exporter storage in the background Args: @@ -426,6 +442,7 @@ def _transfer_bg_thread( error_queue: Queue to put exceptions in if any known_hash: Known hash of the image original_url: Original URL for HTTP fallback + headers: HTTP headers for requests """ self.logger.info(f"Writing image to storage in the background: {src_path}") try: @@ -451,7 +468,9 @@ def _transfer_bg_thread( self.logger.info(f"Uploading image to storage: {filename}") to_storage.write_from_path(filename, src_path, src_operator) - metadata, metadata_json = self._create_metadata_and_json(src_operator, src_path, file_hash, original_url) + metadata, metadata_json = self._create_metadata_and_json( + src_operator, src_path, file_hash, original_url, headers + ) metadata_file = filename + ".metadata" to_storage.write_bytes(metadata_file, metadata_json.encode(errors="ignore")) @@ -682,7 +701,9 @@ def _parse_headers(self, headers: list[str]) -> dict[str, str]: Raises: click.ClickException: If header format is invalid """ - header_map = {} + token_re = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") + header_map: dict[str, str] = {} + seen: set[str] = set() for h in headers: if ":" not in h: raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.") @@ -694,9 +715,36 @@ def _parse_headers(self, headers: list[str]) -> dict[str, str]: if not key: raise click.ClickException(f"Invalid header key in: {h!r}") + if not token_re.match(key): + raise click.ClickException(f"Invalid header name '{key}': must be an HTTP token (RFC7230)") + if any(c in ("\r", "\n") for c in key) or any(c in ("\r", "\n") for c in value): + raise click.ClickException("Header names/values must not contain CR/LF") + kl = key.lower() + if kl in seen: + raise click.ClickException(f"Duplicate header '{key}'") + seen.add(kl) header_map[key] = value + return header_map + def _validate_bearer_token(self, token: str | None) -> str | None: + if token is None: + return None + + token = token.strip() + if not token: + raise click.ClickException("Bearer token cannot be empty") + + # RFC 6750 allows token68 format (base64url-encoded) or other token formats + # Basic validation: printable ASCII excluding whitespace and special chars that could cause issues + if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token): + raise click.ClickException("Bearer token contains invalid characters") + + if len(token) > 4096: + raise click.ClickException("Bearer token is too long (max 4096 characters)") + + return token + def cli(self): @driver_click_group(self) def base(): @@ -806,22 +854,3 @@ def _get_decompression_command(filename_or_url) -> str: elif filename.endswith(".xz"): return "xzcat |" return "" - - -def _validate_bearer_token(self, token: str | None) -> str | None: - if token is None: - return None - - token = token.strip() - if not token: - raise click.ClickException("Bearer token cannot be empty") - - # RFC 6750 allows token68 format (base64url-encoded) or other token formats - # Basic validation: printable ASCII excluding whitespace and special chars that could cause issues - if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token): - raise click.ClickException("Bearer token contains invalid characters") - - if len(token) > 4096: - raise click.ClickException("Bearer token is too long (max 4096 characters)") - - return token diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py new file mode 100644 index 000000000..cfd16c147 --- /dev/null +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py @@ -0,0 +1,39 @@ +import click +import pytest + +from .client import BaseFlasherClient + + +class MockFlasherClient(BaseFlasherClient): + """Mock client for testing without full initialization""" + + def __init__(self): + self._manifest = None + self._console_debug = False + self.logger = type( + "MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None} + )() + + +def test_validate_bearer_token_fails_invalid(): + """Test bearer token validation fails with invalid tokens""" + client = MockFlasherClient() + + with pytest.raises(click.ClickException, match="Bearer token cannot be empty"): + client._validate_bearer_token("") + + with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"): + client._validate_bearer_token("token with spaces") + + with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"): + client._validate_bearer_token('token"with"quotes') + + +def test_curl_header_args_handles_quotes(): + """Test curl header formatting safely handles quotes""" + client = MockFlasherClient() + + result = client._curl_header_args({"Authorization": "Bearer abc'def"}) + assert "'\"'\"'" in result + assert result.startswith("-H '") + assert result.endswith("'") From 2e9031d07b6c295f9b1ba255cabed0da91112f0f Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Tue, 16 Sep 2025 16:17:21 +0300 Subject: [PATCH 5/6] flashers: unify header parsing Signed-off-by: Benny Zlotnik (cherry picked from commit 8e97288e5f7ffec47226b273883745ac29182162) --- .../jumpstarter_driver_flashers/client.py | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 2e2fce544..f9b284e12 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -255,16 +255,6 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str: tls_args += f"--cacert {stored_cacert} " return tls_args.strip() - def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str: - all_headers = headers.copy() if headers else {} - if bearer_token: - if any(k.lower() == "authorization" for k in all_headers.keys()): - self.logger.warning("Authorization header provided - ignoring bearer token") - else: - all_headers["Authorization"] = f"Bearer {bearer_token}" - - return self._curl_header_args(all_headers) - def _curl_header_args(self, headers: dict[str, str] | None) -> str: """Generate header arguments for curl command""" if not headers: @@ -692,28 +682,14 @@ def manifest(self): self._manifest = FlasherBundleManifestV1Alpha1.from_string(yaml_str) return self._manifest - def _parse_headers(self, headers: list[str]) -> dict[str, str]: - """Parse header strings into a dict - Args: - headers: List of header strings in 'Key: Value' format - Returns: - Dictionary mapping header keys to values - Raises: - click.ClickException: If header format is invalid - """ + def _validate_header_dict(self, header_map: dict[str, str]) -> dict[str, str]: token_re = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") - header_map: dict[str, str] = {} seen: set[str] = set() - for h in headers: - if ":" not in h: - raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.") - - key, value = h.split(":", 1) + for key, value in header_map.items(): key = key.strip() value = value.strip() - if not key: - raise click.ClickException(f"Invalid header key in: {h!r}") + raise click.ClickException(f"Invalid header key: '{key}'") if not token_re.match(key): raise click.ClickException(f"Invalid header name '{key}': must be an HTTP token (RFC7230)") @@ -723,10 +699,30 @@ def _parse_headers(self, headers: list[str]) -> dict[str, str]: if kl in seen: raise click.ClickException(f"Duplicate header '{key}'") seen.add(kl) - header_map[key] = value - return header_map + def _parse_headers(self, headers: list[str]) -> dict[str, str]: + header_map: dict[str, str] = {} + for h in headers: + if ":" not in h: + raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.") + + key, value = h.split(":", 1) + header_map[key.strip()] = value.strip() + + return self._validate_header_dict(header_map) + + def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str: + all_headers = headers.copy() if headers else {} + if bearer_token: + if any(k.lower() == "authorization" for k in all_headers.keys()): + self.logger.warning("Authorization header provided - ignoring bearer token") + else: + all_headers["Authorization"] = f"Bearer {bearer_token}" + + validated_headers = self._validate_header_dict(all_headers) + return self._curl_header_args(validated_headers) + def _validate_bearer_token(self, token: str | None) -> str | None: if token is None: return None From 55cbc007c4ede063e6821d561c8e6c91ed47edbc Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Wed, 17 Sep 2025 11:54:59 +0300 Subject: [PATCH 6/6] flashers: fail early on header parsing validation Signed-off-by: Benny Zlotnik (cherry picked from commit f381274c518a09d9a91c646e311f9bffee75adc4) --- .../jumpstarter_driver_flashers/client.py | 25 +++++++++++++------ .../client_test.py | 12 +++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index f9b284e12..e9d9117b3 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -70,7 +70,7 @@ def bootloader_shell(self): pass yield self.serial - def flash( + def flash( # noqa: C901 self, path: PathBuf, *, @@ -87,6 +87,9 @@ def flash( if bearer_token: bearer_token = self._validate_bearer_token(bearer_token) + if headers: + headers = self._validate_header_dict(headers) + """Flash image to DUT""" should_download_to_httpd = True image_url = "" @@ -689,15 +692,15 @@ def _validate_header_dict(self, header_map: dict[str, str]) -> dict[str, str]: key = key.strip() value = value.strip() if not key: - raise click.ClickException(f"Invalid header key: '{key}'") + raise ArgumentError(f"Invalid header key: '{key}'") if not token_re.match(key): - raise click.ClickException(f"Invalid header name '{key}': must be an HTTP token (RFC7230)") + raise ArgumentError(f"Invalid header name '{key}': must be an HTTP token (RFC7230)") if any(c in ("\r", "\n") for c in key) or any(c in ("\r", "\n") for c in value): - raise click.ClickException("Header names/values must not contain CR/LF") + raise ArgumentError("Header names/values must not contain CR/LF") kl = key.lower() if kl in seen: - raise click.ClickException(f"Duplicate header '{key}'") + raise ArgumentError(f"Duplicate header '{key}'") seen.add(kl) return header_map @@ -710,7 +713,10 @@ def _parse_headers(self, headers: list[str]) -> dict[str, str]: key, value = h.split(":", 1) header_map[key.strip()] = value.strip() - return self._validate_header_dict(header_map) + try: + return self._validate_header_dict(header_map) + except ArgumentError as e: + raise click.ClickException(str(e)) from e def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str: all_headers = headers.copy() if headers else {} @@ -720,8 +726,11 @@ def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | N else: all_headers["Authorization"] = f"Bearer {bearer_token}" - validated_headers = self._validate_header_dict(all_headers) - return self._curl_header_args(validated_headers) + if bearer_token and "Authorization" not in (headers or {}): + auth_header = {"Authorization": all_headers["Authorization"]} + self._validate_header_dict(auth_header) + + return self._curl_header_args(all_headers) def _validate_bearer_token(self, token: str | None) -> str | None: if token is None: diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py index cfd16c147..67282b55d 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py @@ -2,6 +2,7 @@ import pytest from .client import BaseFlasherClient +from jumpstarter.common.exceptions import ArgumentError class MockFlasherClient(BaseFlasherClient): @@ -14,6 +15,9 @@ def __init__(self): "MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None} )() + def close(self): + pass + def test_validate_bearer_token_fails_invalid(): """Test bearer token validation fails with invalid tokens""" @@ -37,3 +41,11 @@ def test_curl_header_args_handles_quotes(): assert "'\"'\"'" in result assert result.startswith("-H '") assert result.endswith("'") + + +def test_flash_fails_with_invalid_headers(): + """Test flash method fails early with invalid headers""" + client = MockFlasherClient() + + with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"): + client.flash("test.raw", headers={"Invalid Header": "value"})