Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,30 @@

formatter = eachSystem (pkgs: treefmtEval.${pkgs.system}.config.build.wrapper);

checks =
genAttrs linuxSystems (system: {
checks = genAttrs supportedSystems (
system:
let
pkgs = nixpkgs.legacyPackages.${system};
isLinux = pkgs.stdenv.hostPlatform.isLinux;
testSrc = ./upload-ami;
testPython = pkgs.python3.withPackages (
_: self.packages.${system}.upload-ami.propagatedBuildInputs
);
in
{
upload-ami-tests = pkgs.runCommand "upload-ami-tests" { nativeBuildInputs = [ testPython ]; } ''
PYTHONPATH=${testSrc}/src python ${testSrc}/tests/test_register_image.py -v
touch $out
'';
}
// lib.optionalAttrs isLinux {
inherit (self.packages.${system}) upload-ami;
formatting = treefmtEval.${system}.config.build.check self;
})
// {
x86_64-linux.system = self.nixosConfigurations.x86_64-linux.config.system.build.images.amazon;
};
}
// lib.optionalAttrs (system == "x86_64-linux") {
system = self.nixosConfigurations.x86_64-linux.config.system.build.images.amazon;
}
);

devShells = genAttrs supportedSystems (system: {
default = self.packages.${system}.upload-ami;
Expand Down
4 changes: 3 additions & 1 deletion upload-ami/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ delete-images-by-name = "upload_ami.delete_images_by_name:main"
delete-deprecated-images = "upload_ami.delete_deprecated_images:main"
delete-orphaned-snapshots = "upload_ami.delete_orphaned_snapshots:main"
[tool.mypy]
strict=true
strict = true
mypy_path = "src"
explicit_package_bases = true
187 changes: 130 additions & 57 deletions upload-ami/src/upload_ami/upload_ami.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
import logging
from pathlib import Path
from typing import Iterable, Literal, TypedDict
from typing import Any, Iterable, NotRequired, TypedDict, cast
import boto3
import boto3.ec2
import boto3.ec2.createtags
Expand All @@ -18,12 +18,130 @@
from concurrent.futures import ThreadPoolExecutor


class EbsInfo(TypedDict):
VolumeType: str


class BlockDeviceMappingInfo(TypedDict):
DeviceName: str
Ebs: EbsInfo


class RegisterImageInfo(TypedDict):
Architecture: str
BootMode: str
RootDeviceName: str
VirtualizationType: str
EnaSupport: bool
ImdsSupport: str
SriovNetSupport: str
TpmSupport: str | None
BlockDeviceMappings: list[BlockDeviceMappingInfo]


class ImageInfo(TypedDict):
file: str
label: str
system: str
boot_mode: BootModeValuesType
format: str
format: NotRequired[str]
registerImage: RegisterImageInfo


def build_register_image_request(
image_name: str,
register_image: RegisterImageInfo,
snapshot_id: str,
) -> RegisterImageRequestTypeDef:
"""
Build the RegisterImage request body from image-info.json's registerImage object.

Injects Name, SnapshotId, and TagSpecifications (upload-time concerns).
Rejects any attempt to set those fields from the JSON side.
"""
# Validate required keys explicitly so bad JSON yields clear errors
for required in (
"Architecture",
"BootMode",
"RootDeviceName",
"BlockDeviceMappings",
"VirtualizationType",
"EnaSupport",
"ImdsSupport",
"SriovNetSupport",
):
if required not in register_image:
raise ValueError(f"registerImage is missing required field {required!r}")

# Reject upload-time fields that must not come from the NixOS module
for forbidden in ("Name", "TagSpecifications"):
if forbidden in register_image:
raise ValueError(
f"registerImage must not contain {forbidden!r}; "
"that is an upload-time concern owned by amis"
)

mappings = register_image["BlockDeviceMappings"]
if len(mappings) != 1:
raise ValueError(
f"Expected exactly one BlockDeviceMapping, got {len(mappings)}"
)
mapping = mappings[0]
if "DeviceName" not in mapping:
raise ValueError(
"BlockDeviceMappings[0] is missing required field 'DeviceName'"
)
if "Ebs" not in mapping:
raise ValueError("BlockDeviceMappings[0] is missing required field 'Ebs'")
if "VolumeType" not in mapping["Ebs"]:
raise ValueError(
"BlockDeviceMappings[0].Ebs is missing required field 'VolumeType'"
)
if mapping["DeviceName"] != register_image["RootDeviceName"]:
raise ValueError(
f"BlockDeviceMapping DeviceName {mapping['DeviceName']!r} "
f"does not match RootDeviceName {register_image['RootDeviceName']!r}"
)
if "SnapshotId" in mapping["Ebs"]:
raise ValueError(
"registerImage must not contain Ebs.SnapshotId; "
"that is set at upload time"
)

# Build the request explicitly from validated fields
kwargs: dict[str, Any] = {
"Name": image_name,
"Architecture": register_image["Architecture"],
"BootMode": register_image["BootMode"],
"RootDeviceName": register_image["RootDeviceName"],
"VirtualizationType": register_image["VirtualizationType"],
"EnaSupport": register_image["EnaSupport"],
"ImdsSupport": register_image["ImdsSupport"],
"SriovNetSupport": register_image["SriovNetSupport"],
"BlockDeviceMappings": [
{
"DeviceName": mapping["DeviceName"],
"Ebs": {
"VolumeType": mapping["Ebs"]["VolumeType"],
"SnapshotId": snapshot_id,
},
}
],
"TagSpecifications": [
{
"ResourceType": "image",
"Tags": [
{"Key": "Name", "Value": image_name},
{"Key": "ManagedBy", "Value": "NixOS/amis"},
],
}
],
}

if register_image.get("TpmSupport") is not None:
kwargs["TpmSupport"] = register_image["TpmSupport"]

return cast(RegisterImageRequestTypeDef, kwargs)


def upload_to_s3_if_not_exists(
Expand Down Expand Up @@ -129,7 +247,6 @@ def register_image_if_not_exists(
image_info: ImageInfo,
snapshot_id: str,
public: bool,
enable_tpm: bool,
) -> str:
"""
Register image if it doesn't exist yet
Expand All @@ -144,50 +261,15 @@ def register_image_if_not_exists(
assert "ImageId" in describe_images["Images"][0]
image_id = describe_images["Images"][0]["ImageId"]
else:
architecture: Literal["x86_64", "arm64"]
assert "system" in image_info
if image_info["system"] == "x86_64-linux":
architecture = "x86_64"
elif image_info["system"] == "aarch64-linux":
architecture = "arm64"
else:
raise Exception("Unknown system: " + image_info["system"])

register_image_kwargs: RegisterImageRequestTypeDef = {
"Name": image_name,
"Architecture": architecture,
"BootMode": image_info["boot_mode"],
"BlockDeviceMappings": [
{
"DeviceName": "/dev/xvda",
"Ebs": {
"SnapshotId": snapshot_id,
"VolumeType": "gp3",
},
}
],
"RootDeviceName": "/dev/xvda",
"VirtualizationType": "hvm",
"EnaSupport": True,
"ImdsSupport": "v2.0",
"SriovNetSupport": "simple",
"TagSpecifications": [
{
"ResourceType": "image",
"Tags": [
{"Key": "Name", "Value": image_name},
{"Key": "ManagedBy", "Value": "NixOS/amis"},
],
}
],
}

if (
enable_tpm
and architecture == "x86_64"
and image_info["boot_mode"] == "uefi"
):
register_image_kwargs["TpmSupport"] = "v2.0"
if "registerImage" not in image_info:
raise ValueError(
"image-info.json is missing required key 'registerImage'; "
"the image was likely built before amazonImage.registerImage "
"was added to amazon-image.nix"
)
register_image_kwargs = build_register_image_request(
image_name, image_info["registerImage"], snapshot_id
)

logging.info(f"Registering image {image_name} with snapshot {snapshot_id}")

Expand Down Expand Up @@ -322,7 +404,6 @@ def upload_ami(
run_id: str,
public: bool,
dest_regions: list[str],
enable_tpm: bool,
import_role_name: str,
best_effort_regions: list[str] = [],
) -> dict[str, str]:
Expand All @@ -346,7 +427,7 @@ def upload_ami(
)

image_id = register_image_if_not_exists(
ec2, image_name, image_info, snapshot_id, public, enable_tpm
ec2, image_name, image_info, snapshot_id, public
)

image_ids: dict[str, str] = {}
Expand Down Expand Up @@ -392,13 +473,6 @@ def main() -> None:
action="append",
default=[],
)
parser.add_argument(
"--enable-tpm",
action="store_true",
default=False,
help="Enable TPM 2.0 support for UEFI x86_64 images",
)

parser.add_argument(
"--import-role-name",
default="vmimport",
Expand Down Expand Up @@ -428,7 +502,6 @@ def main() -> None:
args.run_id,
args.public,
args.dest_region,
args.enable_tpm,
args.import_role_name,
args.best_effort_region,
)
Expand Down
Loading