-
Notifications
You must be signed in to change notification settings - Fork 20
qemu: add OCI flashing to qemu driver #555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,28 @@ | ||
| import sys | ||
| from contextlib import contextmanager | ||
|
|
||
| import click | ||
| from jumpstarter_driver_composite.client import CompositeClient | ||
| from jumpstarter_driver_network.adapters import FabricAdapter, NovncAdapter | ||
| from jumpstarter_driver_opendal.client import FlasherClient | ||
|
|
||
|
|
||
| class QemuFlasherClient(FlasherClient): | ||
| """Flasher client for QEMU with OCI support via fls.""" | ||
|
|
||
| def flash(self, path, *, target=None, operator=None, compression=None): | ||
| if isinstance(path, str) and path.startswith("oci://"): | ||
| returncode = 0 | ||
| for stdout, stderr, code in self.streamingcall("flash_oci", path, target): | ||
| if stdout: | ||
| print(stdout, end="", flush=True) | ||
| if stderr: | ||
| print(stderr, end="", file=sys.stderr, flush=True) | ||
| if code is not None: | ||
| returncode = code | ||
| return returncode | ||
|
|
||
| return super().flash(path, target=target, operator=operator, compression=compression) | ||
|
Comment on lines
+13
to
+25
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we do the bifurcation on the driver side it will be easier with a plan that @kirkbrauer was proposing to allow multi-language clients :) |
||
|
|
||
|
|
||
| class QemuClient(CompositeClient): | ||
|
|
@@ -26,6 +46,17 @@ def set_memory_size(self, size: str) -> None: | |
| """Set the memory size for next boot.""" | ||
| self.call("set_memory_size", size) | ||
|
|
||
| def flash_oci(self, oci_url: str, partition: str | None = None): | ||
| """Flash an OCI image to the specified partition using fls. | ||
|
|
||
| Convenience method that delegates to self.flasher.flash(). | ||
|
|
||
| Args: | ||
| oci_url: OCI image reference (must start with oci://) | ||
| partition: Target partition name (default: root) | ||
| """ | ||
| return self.flasher.flash(oci_url, target=partition) | ||
|
|
||
| @contextmanager | ||
| def novnc(self): | ||
| with NovncAdapter(client=self.vnc) as url: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import json | ||
| import logging | ||
| import os | ||
|
|
@@ -25,6 +26,7 @@ | |
| from qemu.qmp import QMPClient | ||
| from qemu.qmp.protocol import ConnectError, Runstate | ||
|
|
||
| from jumpstarter.common.fls import get_fls_binary | ||
| from jumpstarter.driver import Driver, export | ||
| from jumpstarter.streams.encoding import AutoDecompressIterator | ||
|
|
||
|
|
@@ -44,23 +46,158 @@ def filter(self, record): | |
| return False | ||
|
|
||
|
|
||
| async def _read_pipe(stream: asyncio.StreamReader, name: str, queue: asyncio.Queue): | ||
| while True: | ||
| line = await stream.readline() | ||
| if not line: | ||
| break | ||
| await queue.put((name, line.decode("utf-8", errors="replace"))) | ||
| await queue.put((name, None)) | ||
|
|
||
|
|
||
| @dataclass(kw_only=True) | ||
| class QemuFlasher(FlasherInterface, Driver): | ||
| parent: Qemu | ||
|
|
||
| @classmethod | ||
| def client(cls) -> str: | ||
| return "jumpstarter_driver_qemu.client.QemuFlasherClient" | ||
|
|
||
| @export | ||
| async def flash(self, source, partition: str | None = None): | ||
| """Flash an image to the specified partition. | ||
|
|
||
| Accepts OCI image references (oci://...) or streamed image data. | ||
| Supports transparent decompression of gzip, xz, bz2, and zstd compressed images. | ||
| Compression format is auto-detected from file signature. | ||
| """ | ||
| if isinstance(source, str) and source.startswith("oci://"): | ||
| async for _ in self.flash_oci(source, partition): | ||
| pass | ||
| return | ||
|
|
||
| async with await FileWriteStream.from_path(self.parent.validate_partition(partition)) as stream: | ||
| async with self.resource(source) as res: | ||
| # Wrap with auto-decompression to handle .gz, .xz, .bz2, .zstd files | ||
| async for chunk in AutoDecompressIterator(source=res): | ||
| await stream.send(chunk) | ||
|
|
||
| @export | ||
| async def flash_oci( | ||
| self, | ||
| oci_url: str, | ||
| partition: str | None = None, | ||
| oci_username: str | None = None, | ||
| oci_password: str | None = None, | ||
| ) -> AsyncGenerator[tuple[str, str, int | None], None]: | ||
| """Flash an OCI image to the specified partition using fls. | ||
|
|
||
| Streams subprocess output back to the caller as it arrives. | ||
| Yields (stdout_chunk, stderr_chunk, returncode) tuples. | ||
| returncode is None until the process completes. | ||
|
|
||
| Args: | ||
| oci_url: OCI image reference (must start with oci://) | ||
| partition: Target partition name (default: root) | ||
| oci_username: Registry username for OCI authentication | ||
| oci_password: Registry password for OCI authentication | ||
| """ | ||
| if not oci_url.startswith("oci://"): | ||
| raise ValueError(f"OCI URL must start with oci://, got: {oci_url}") | ||
|
|
||
| # Fall back to environment variables for credentials | ||
| if not oci_username: | ||
| oci_username = os.environ.get("OCI_USERNAME") | ||
| if not oci_password: | ||
| oci_password = os.environ.get("OCI_PASSWORD") | ||
|
|
||
| if bool(oci_username) != bool(oci_password): | ||
| raise ValueError("OCI authentication requires both username and password") | ||
|
bennyz marked this conversation as resolved.
|
||
|
|
||
| target_path = str(self.parent.validate_partition(partition)) | ||
|
|
||
| fls_binary = get_fls_binary( | ||
| fls_version=self.parent.fls_version, | ||
| fls_binary_url=self.parent.fls_custom_binary_url, | ||
| allow_custom_binaries=self.parent.fls_allow_custom_binaries, | ||
| ) | ||
|
Comment on lines
+119
to
+123
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would go ahead and start installing the fls binary on the jumpstarter container, I was exactly looking for using fls... :)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ahh sorry I didn't see it :D |
||
|
|
||
| fls_cmd = [fls_binary, "from-url", oci_url, target_path] | ||
|
|
||
| fls_env = None | ||
| if oci_username and oci_password: | ||
| fls_env = os.environ.copy() | ||
| fls_env["FLS_REGISTRY_USERNAME"] = oci_username | ||
| fls_env["FLS_REGISTRY_PASSWORD"] = oci_password | ||
|
|
||
| self.logger.info(f"Running fls: {' '.join(fls_cmd)}") | ||
|
|
||
| try: | ||
| async for chunk in self._stream_subprocess(fls_cmd, fls_env): | ||
| yield chunk | ||
| except FileNotFoundError: | ||
| raise RuntimeError("fls command not found. Install fls or configure fls_version in the driver.") from None | ||
|
|
||
| async def _stream_subprocess( | ||
| self, cmd: list[str], env: dict[str, str] | None | ||
| ) -> AsyncGenerator[tuple[str, str, int | None], None]: | ||
| """Run a subprocess and yield (stdout, stderr, returncode) tuples as output arrives.""" | ||
| process = await asyncio.create_subprocess_exec( # ty: ignore[missing-argument] | ||
| *cmd, | ||
| stdout=asyncio.subprocess.PIPE, # ty: ignore[unresolved-attribute] | ||
| stderr=asyncio.subprocess.PIPE, # ty: ignore[unresolved-attribute] | ||
| env=env, | ||
| ) | ||
|
|
||
| output_queue: asyncio.Queue[tuple[str, str | None]] = asyncio.Queue() | ||
|
|
||
| tasks = [ | ||
| asyncio.create_task(_read_pipe(process.stdout, "stdout", output_queue)), | ||
| asyncio.create_task(_read_pipe(process.stderr, "stderr", output_queue)), | ||
| ] | ||
|
|
||
| finished_streams = 0 | ||
| start_time = asyncio.get_running_loop().time() | ||
|
|
||
| try: | ||
| while finished_streams < 2: | ||
| elapsed = asyncio.get_running_loop().time() - start_time | ||
| if elapsed >= self.parent.flash_timeout: | ||
| process.kill() | ||
| await process.wait() | ||
| raise RuntimeError(f"fls flash timed out after {self.parent.flash_timeout}s") | ||
|
|
||
| remaining = self.parent.flash_timeout - elapsed | ||
| try: | ||
| name, text = await asyncio.wait_for(output_queue.get(), timeout=min(remaining, 30)) | ||
| except asyncio.TimeoutError: | ||
| continue | ||
|
|
||
| if text is None: | ||
| finished_streams += 1 | ||
| continue | ||
|
|
||
| stdout_chunk = text if name == "stdout" else "" | ||
| stderr_chunk = text if name == "stderr" else "" | ||
| yield stdout_chunk, stderr_chunk, None | ||
|
|
||
| await process.wait() | ||
| returncode = process.returncode | ||
|
|
||
| if returncode != 0: | ||
| self.logger.error(f"fls failed - return code: {returncode}") | ||
| raise RuntimeError(f"fls flash failed (return code {returncode})") | ||
|
|
||
| self.logger.info("OCI flash completed successfully") | ||
| yield "", "", returncode | ||
| finally: | ||
| for task in tasks: | ||
| task.cancel() | ||
| await asyncio.gather(*tasks, return_exceptions=True) | ||
| if process.returncode is None: | ||
| process.kill() | ||
| await process.wait() | ||
|
|
||
| @export | ||
| async def dump(self, target, partition: str | None = None): | ||
| async with await FileReadStream.from_path( | ||
|
|
@@ -300,6 +437,12 @@ class Qemu(Driver): | |
|
|
||
| hostfwd: dict[str, Hostfwd] = field(default_factory=dict) | ||
|
|
||
| # FLS configuration for OCI flashing | ||
| fls_version: str | None = field(default=None) | ||
| fls_allow_custom_binaries: bool = field(default=False) | ||
| fls_custom_binary_url: str | None = field(default=None) | ||
| flash_timeout: int = field(default=30 * 60) # 30 minutes | ||
|
|
||
| _tmp_dir: TemporaryDirectory = field(init=False, default_factory=TemporaryDirectory) | ||
|
|
||
| @classmethod | ||
|
|
@@ -357,7 +500,7 @@ def validate_partition( | |
| case "bios": | ||
| path = Path(self._tmp_dir.name) / "bios" | ||
| case _: | ||
| raise ValueError(f"invalida partition name: {partition}") | ||
| raise ValueError(f"invalid partition name: {partition}") | ||
|
|
||
| if not path.exists() and partition in self.default_partitions and use_default_partitions: | ||
| return self.default_partitions[partition] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.