diff --git a/flake.nix b/flake.nix index 8d443ac0..09cfb89d 100644 --- a/flake.nix +++ b/flake.nix @@ -83,14 +83,30 @@ formatter = eachSystem (pkgs: treefmtEval.${pkgs.system}.config.build.wrapper); - checks = - genAttrs linuxSystems (system: { + checks = genAttrs supportedSystems ( + system: + let + pkgs = nixpkgs.legacyPackages.${system}; + isLinux = pkgs.stdenv.hostPlatform.isLinux; + testSrc = ./upload-ami; + testPython = pkgs.python3.withPackages ( + _: self.packages.${system}.upload-ami.propagatedBuildInputs + ); + in + { + upload-ami-tests = pkgs.runCommand "upload-ami-tests" { nativeBuildInputs = [ testPython ]; } '' + PYTHONPATH=${testSrc}/src python ${testSrc}/tests/test_register_image.py -v + touch $out + ''; + } + // lib.optionalAttrs isLinux { inherit (self.packages.${system}) upload-ami; formatting = treefmtEval.${system}.config.build.check self; - }) - // { - x86_64-linux.system = self.nixosConfigurations.x86_64-linux.config.system.build.images.amazon; - }; + } + // lib.optionalAttrs (system == "x86_64-linux") { + system = self.nixosConfigurations.x86_64-linux.config.system.build.images.amazon; + } + ); devShells = genAttrs supportedSystems (system: { default = self.packages.${system}.upload-ami; diff --git a/upload-ami/pyproject.toml b/upload-ami/pyproject.toml index 84f29718..1b80be37 100644 --- a/upload-ami/pyproject.toml +++ b/upload-ami/pyproject.toml @@ -21,4 +21,6 @@ delete-images-by-name = "upload_ami.delete_images_by_name:main" delete-deprecated-images = "upload_ami.delete_deprecated_images:main" delete-orphaned-snapshots = "upload_ami.delete_orphaned_snapshots:main" [tool.mypy] -strict=true +strict = true +mypy_path = "src" +explicit_package_bases = true diff --git a/upload-ami/src/upload_ami/upload_ami.py b/upload-ami/src/upload_ami/upload_ami.py index b4299a9d..637ab453 100644 --- a/upload-ami/src/upload_ami/upload_ami.py +++ b/upload-ami/src/upload_ami/upload_ami.py @@ -2,7 +2,7 @@ import hashlib import logging from pathlib import Path -from typing import Iterable, Literal, TypedDict +from typing import Any, Iterable, NotRequired, TypedDict, cast import boto3 import boto3.ec2 import boto3.ec2.createtags @@ -18,12 +18,130 @@ from concurrent.futures import ThreadPoolExecutor +class EbsInfo(TypedDict): + VolumeType: str + + +class BlockDeviceMappingInfo(TypedDict): + DeviceName: str + Ebs: EbsInfo + + +class RegisterImageInfo(TypedDict): + Architecture: str + BootMode: str + RootDeviceName: str + VirtualizationType: str + EnaSupport: bool + ImdsSupport: str + SriovNetSupport: str + TpmSupport: str | None + BlockDeviceMappings: list[BlockDeviceMappingInfo] + + class ImageInfo(TypedDict): file: str label: str system: str boot_mode: BootModeValuesType - format: str + format: NotRequired[str] + registerImage: RegisterImageInfo + + +def build_register_image_request( + image_name: str, + register_image: RegisterImageInfo, + snapshot_id: str, +) -> RegisterImageRequestTypeDef: + """ + Build the RegisterImage request body from image-info.json's registerImage object. + + Injects Name, SnapshotId, and TagSpecifications (upload-time concerns). + Rejects any attempt to set those fields from the JSON side. + """ + # Validate required keys explicitly so bad JSON yields clear errors + for required in ( + "Architecture", + "BootMode", + "RootDeviceName", + "BlockDeviceMappings", + "VirtualizationType", + "EnaSupport", + "ImdsSupport", + "SriovNetSupport", + ): + if required not in register_image: + raise ValueError(f"registerImage is missing required field {required!r}") + + # Reject upload-time fields that must not come from the NixOS module + for forbidden in ("Name", "TagSpecifications"): + if forbidden in register_image: + raise ValueError( + f"registerImage must not contain {forbidden!r}; " + "that is an upload-time concern owned by amis" + ) + + mappings = register_image["BlockDeviceMappings"] + if len(mappings) != 1: + raise ValueError( + f"Expected exactly one BlockDeviceMapping, got {len(mappings)}" + ) + mapping = mappings[0] + if "DeviceName" not in mapping: + raise ValueError( + "BlockDeviceMappings[0] is missing required field 'DeviceName'" + ) + if "Ebs" not in mapping: + raise ValueError("BlockDeviceMappings[0] is missing required field 'Ebs'") + if "VolumeType" not in mapping["Ebs"]: + raise ValueError( + "BlockDeviceMappings[0].Ebs is missing required field 'VolumeType'" + ) + if mapping["DeviceName"] != register_image["RootDeviceName"]: + raise ValueError( + f"BlockDeviceMapping DeviceName {mapping['DeviceName']!r} " + f"does not match RootDeviceName {register_image['RootDeviceName']!r}" + ) + if "SnapshotId" in mapping["Ebs"]: + raise ValueError( + "registerImage must not contain Ebs.SnapshotId; " + "that is set at upload time" + ) + + # Build the request explicitly from validated fields + kwargs: dict[str, Any] = { + "Name": image_name, + "Architecture": register_image["Architecture"], + "BootMode": register_image["BootMode"], + "RootDeviceName": register_image["RootDeviceName"], + "VirtualizationType": register_image["VirtualizationType"], + "EnaSupport": register_image["EnaSupport"], + "ImdsSupport": register_image["ImdsSupport"], + "SriovNetSupport": register_image["SriovNetSupport"], + "BlockDeviceMappings": [ + { + "DeviceName": mapping["DeviceName"], + "Ebs": { + "VolumeType": mapping["Ebs"]["VolumeType"], + "SnapshotId": snapshot_id, + }, + } + ], + "TagSpecifications": [ + { + "ResourceType": "image", + "Tags": [ + {"Key": "Name", "Value": image_name}, + {"Key": "ManagedBy", "Value": "NixOS/amis"}, + ], + } + ], + } + + if register_image.get("TpmSupport") is not None: + kwargs["TpmSupport"] = register_image["TpmSupport"] + + return cast(RegisterImageRequestTypeDef, kwargs) def upload_to_s3_if_not_exists( @@ -129,7 +247,6 @@ def register_image_if_not_exists( image_info: ImageInfo, snapshot_id: str, public: bool, - enable_tpm: bool, ) -> str: """ Register image if it doesn't exist yet @@ -144,50 +261,15 @@ def register_image_if_not_exists( assert "ImageId" in describe_images["Images"][0] image_id = describe_images["Images"][0]["ImageId"] else: - architecture: Literal["x86_64", "arm64"] - assert "system" in image_info - if image_info["system"] == "x86_64-linux": - architecture = "x86_64" - elif image_info["system"] == "aarch64-linux": - architecture = "arm64" - else: - raise Exception("Unknown system: " + image_info["system"]) - - register_image_kwargs: RegisterImageRequestTypeDef = { - "Name": image_name, - "Architecture": architecture, - "BootMode": image_info["boot_mode"], - "BlockDeviceMappings": [ - { - "DeviceName": "/dev/xvda", - "Ebs": { - "SnapshotId": snapshot_id, - "VolumeType": "gp3", - }, - } - ], - "RootDeviceName": "/dev/xvda", - "VirtualizationType": "hvm", - "EnaSupport": True, - "ImdsSupport": "v2.0", - "SriovNetSupport": "simple", - "TagSpecifications": [ - { - "ResourceType": "image", - "Tags": [ - {"Key": "Name", "Value": image_name}, - {"Key": "ManagedBy", "Value": "NixOS/amis"}, - ], - } - ], - } - - if ( - enable_tpm - and architecture == "x86_64" - and image_info["boot_mode"] == "uefi" - ): - register_image_kwargs["TpmSupport"] = "v2.0" + if "registerImage" not in image_info: + raise ValueError( + "image-info.json is missing required key 'registerImage'; " + "the image was likely built before amazonImage.registerImage " + "was added to amazon-image.nix" + ) + register_image_kwargs = build_register_image_request( + image_name, image_info["registerImage"], snapshot_id + ) logging.info(f"Registering image {image_name} with snapshot {snapshot_id}") @@ -322,7 +404,6 @@ def upload_ami( run_id: str, public: bool, dest_regions: list[str], - enable_tpm: bool, import_role_name: str, best_effort_regions: list[str] = [], ) -> dict[str, str]: @@ -346,7 +427,7 @@ def upload_ami( ) image_id = register_image_if_not_exists( - ec2, image_name, image_info, snapshot_id, public, enable_tpm + ec2, image_name, image_info, snapshot_id, public ) image_ids: dict[str, str] = {} @@ -392,13 +473,6 @@ def main() -> None: action="append", default=[], ) - parser.add_argument( - "--enable-tpm", - action="store_true", - default=False, - help="Enable TPM 2.0 support for UEFI x86_64 images", - ) - parser.add_argument( "--import-role-name", default="vmimport", @@ -428,7 +502,6 @@ def main() -> None: args.run_id, args.public, args.dest_region, - args.enable_tpm, args.import_role_name, args.best_effort_region, ) diff --git a/upload-ami/tests/test_register_image.py b/upload-ami/tests/test_register_image.py new file mode 100644 index 00000000..b94ec71b --- /dev/null +++ b/upload-ami/tests/test_register_image.py @@ -0,0 +1,141 @@ +"""Tests for build_register_image_request().""" + +import unittest +from typing import Any, cast + +from upload_ami.upload_ami import RegisterImageInfo, build_register_image_request + + +def _valid_register_image() -> dict[str, Any]: + """Return a minimal valid registerImage as a plain dict for easy mutation in tests.""" + return { + "Architecture": "x86_64", + "BootMode": "legacy-bios", + "RootDeviceName": "/dev/xvda", + "VirtualizationType": "hvm", + "EnaSupport": True, + "ImdsSupport": "v2.0", + "SriovNetSupport": "simple", + "TpmSupport": None, + "BlockDeviceMappings": [ + { + "DeviceName": "/dev/xvda", + "Ebs": {"VolumeType": "gp3"}, + } + ], + } + + +def _build(image_name: str, ri: dict[str, Any], snapshot_id: str) -> dict[str, Any]: + """Wrapper that casts the plain dict to RegisterImageInfo for mypy.""" + return dict( + build_register_image_request( + image_name, cast(RegisterImageInfo, ri), snapshot_id + ) + ) + + +class TestBuildRegisterImageRequest(unittest.TestCase): + def test_valid_input(self) -> None: + result = _build("test-image", _valid_register_image(), "snap-123") + self.assertEqual(result["Name"], "test-image") + self.assertEqual(result["Architecture"], "x86_64") + self.assertEqual(result["BootMode"], "legacy-bios") + self.assertEqual( + result["BlockDeviceMappings"][0]["Ebs"]["SnapshotId"], "snap-123" + ) + self.assertEqual(result["BlockDeviceMappings"][0]["Ebs"]["VolumeType"], "gp3") + self.assertNotIn("TpmSupport", result) + + def test_tpm_included_when_set(self) -> None: + ri = _valid_register_image() + ri["TpmSupport"] = "v2.0" + result = _build("test-image", ri, "snap-123") + self.assertEqual(result["TpmSupport"], "v2.0") + + def test_missing_required_field(self) -> None: + for field in ( + "Architecture", + "BootMode", + "RootDeviceName", + "BlockDeviceMappings", + "VirtualizationType", + "EnaSupport", + "ImdsSupport", + "SriovNetSupport", + ): + ri = _valid_register_image() + del ri[field] + with self.assertRaises(ValueError, msg=f"missing {field}") as ctx: + _build("img", ri, "snap-1") + self.assertIn(field, str(ctx.exception)) + + def test_forbidden_name(self) -> None: + ri = _valid_register_image() + ri["Name"] = "sneaky" + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("Name", str(ctx.exception)) + + def test_forbidden_tag_specifications(self) -> None: + ri = _valid_register_image() + ri["TagSpecifications"] = [] + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("TagSpecifications", str(ctx.exception)) + + def test_multiple_mappings(self) -> None: + ri = _valid_register_image() + ri["BlockDeviceMappings"].append( + {"DeviceName": "/dev/sdb", "Ebs": {"VolumeType": "gp3"}} + ) + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("exactly one", str(ctx.exception)) + + def test_missing_device_name(self) -> None: + ri = _valid_register_image() + del ri["BlockDeviceMappings"][0]["DeviceName"] + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("DeviceName", str(ctx.exception)) + + def test_missing_ebs(self) -> None: + ri = _valid_register_image() + del ri["BlockDeviceMappings"][0]["Ebs"] + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("Ebs", str(ctx.exception)) + + def test_missing_volume_type(self) -> None: + ri = _valid_register_image() + del ri["BlockDeviceMappings"][0]["Ebs"]["VolumeType"] + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("VolumeType", str(ctx.exception)) + + def test_device_name_mismatch(self) -> None: + ri = _valid_register_image() + ri["BlockDeviceMappings"][0]["DeviceName"] = "/dev/sda1" + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("does not match", str(ctx.exception)) + + def test_preset_snapshot_id(self) -> None: + ri = _valid_register_image() + ri["BlockDeviceMappings"][0]["Ebs"]["SnapshotId"] = "snap-bad" + with self.assertRaises(ValueError) as ctx: + _build("img", ri, "snap-1") + self.assertIn("SnapshotId", str(ctx.exception)) + + def test_arm64_uefi(self) -> None: + ri = _valid_register_image() + ri["Architecture"] = "arm64" + ri["BootMode"] = "uefi" + result = _build("arm-img", ri, "snap-456") + self.assertEqual(result["Architecture"], "arm64") + self.assertEqual(result["BootMode"], "uefi") + + +if __name__ == "__main__": + unittest.main()