diff --git a/README.md b/README.md index 69b584a..5df9544 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,16 @@ uv run sam-render template.yaml --config samconfig.toml --env dev --env2 stag By default, `Fn::ImportValue` and `{{resolve:secretsmanager:...}}` return mock strings. Provide an AWS profile to fetch real values from your AWS account. ```bash +# Single Env, single acccout uv run sam-render template.yaml --config samconfig.toml --env dev --profile my-aws-profile + +# Compare Envs in a single account +uv run sam-render template.yaml --config samconfig.toml --env dev --env2 dev1 --profile my-aws-profile + +# Compare Envs in different accounts +uv run sam-render template.yaml --config samconfig.toml --env dev --env2 dev1 --profile my-aws-profile --profile2 my-aws-profile2 + + ``` ## **Supported Functions** diff --git a/src/samrenderer/main.py b/src/samrenderer/main.py index d673ab7..4d9a227 100644 --- a/src/samrenderer/main.py +++ b/src/samrenderer/main.py @@ -4,12 +4,26 @@ import sys import argparse import difflib +import json +import asyncio +import subprocess +from botocore.exceptions import ClientError, BotoCoreError try: import tomllib as toml # Python 3.11+ except ImportError: # pragma: no cover import tomli as toml # pip install tomli +# Constants for log levels +LOG_LEVELS = { + "DEBUG": 10, + "INFO": 20, + "WARN": 30, + "WARNING": 30, + "ERROR": 40, + "CRITICAL": 50, +} + # --- 1. YAML Tag Handling --- class CFNLoader(yaml.SafeLoader): @@ -46,8 +60,11 @@ def parse_sam_overrides(override_string): if not override_string: return {} - pattern = re.compile(r"([a-zA-Z0-9]+)=\"([^\"]*)\"") - return dict(pattern.findall(override_string)) + pattern = re.compile(r"([a-zA-Z0-9\-_]+)=(?:\"([^\"]*)\"|([^\s\"]+))") + + matches = pattern.findall(override_string) + + return {m[0]: m[1] or m[2] for m in matches} def load_sam_config(config_path, environment="default"): @@ -73,15 +90,59 @@ def load_sam_config(config_path, environment="default"): return {} -# --- 3. Resolution Logic --- +# --- 3. AWS Login Helper --- +def ensure_sso_login(profile): + """ + Checks if the session for the given profile is valid. + If not, triggers 'aws sso login'. + """ + if not profile: + return + + print(f"Checking credentials for profile '{profile}'...", file=sys.stderr) + try: + session = boto3.Session(profile_name=profile) + sts = session.client("sts") + sts.get_caller_identity() + except (ClientError, BotoCoreError): + print( + f"Credentials expired/invalid for '{profile}'. Running 'aws sso login'...", + file=sys.stderr, + ) + try: + subprocess.check_call(["aws", "sso", "login", "--profile", profile]) + print(f"Login successful for '{profile}'.", file=sys.stderr) + except subprocess.CalledProcessError: + print( + f"Error: Failed to login to AWS SSO for profile '{profile}'.", + file=sys.stderr, + ) + # We don't exit here; we let the renderer fail naturally later if it needs creds + + +# --- 4. Resolution Logic --- class TemplateRenderer: - def __init__(self, template_path, profile=None, region="us-east-1"): + def __init__( + self, + template_path, + profile=None, + region="us-east-1", + env_name="default", + log_level="WARN", + ): with open(template_path, "r") as f: self.t = yaml.load(f, Loader=CFNLoader) self.mappings = self.t.get("Mappings", {}) self.conditions = self.t.get("Conditions", {}) self.resources = self.t.get("Resources", {}) + self.env_name = env_name + self.profile = profile + self.log_level_int = LOG_LEVELS.get(log_level.upper(), 30) + + # Track current resource context for logging + self.current_resource_id = None + self.current_resource_type = None self.context = { "AWS::Region": region, @@ -96,12 +157,60 @@ def __init__(self, template_path, profile=None, region="us-east-1"): if "Default" in p: self.context[name] = p["Default"] + # Initialize clients directly (assumes login handled externally) self.boto_session = ( - boto3.Session(profile_name=profile, region_name=region) if profile else None + boto3.Session(profile_name=self.profile, region_name=region) + if self.profile + else None ) self.cfn_client = ( self.boto_session.client("cloudformation") if self.boto_session else None ) + self.sm_client = ( + self.boto_session.client("secretsmanager") if self.boto_session else None + ) + + def _log(self, operation, key, message=None, level="INFO"): + msg_level_int = LOG_LEVELS.get(level.upper(), 20) + if msg_level_int < self.log_level_int: + return + + entry = { + "level": level, + "operation": operation, + "env": self.env_name, + "profile": self.profile, + "key": key, + } + + if self.current_resource_id: + entry["resource_id"] = self.current_resource_id + if self.current_resource_type: + entry["resource_type"] = self.current_resource_type + + if message: + entry["message"] = message + print(json.dumps(entry), file=sys.stderr) + + def resolve_resources(self): + """Special resolver for the Resources block to track context.""" + resolved = {} + for logical_id, res_def in self.resources.items(): + self.current_resource_id = logical_id + self.current_resource_type = ( + res_def.get("Type") if isinstance(res_def, dict) else None + ) + + # Resolve the resource definition + resolved_val = self.resolve(res_def) + + if resolved_val is not None: + resolved[logical_id] = resolved_val + + # Reset context + self.current_resource_id = None + self.current_resource_type = None + return resolved def resolve(self, node): if isinstance(node, dict): @@ -172,7 +281,6 @@ def resolve(self, node): def _handle_ref(self, ref_key): if ref_key in self.context: result = self.context[ref_key] - # Recursively resolve in case the parameter contains a dynamic reference if isinstance(result, str): return self._resolve_dynamic_reference(result) return result @@ -185,7 +293,6 @@ def _resolve_dynamic_reference(self, text): if not isinstance(text, str): return text - # Pattern for {{resolve:service:...}} pattern = r"\{\{resolve:([^:]+):([^}]+)\}\}" match = re.search(pattern, text) @@ -198,62 +305,78 @@ def _resolve_dynamic_reference(self, text): if service == "secretsmanager": return self._resolve_secretsmanager(reference) - # Unsupported service - return as-is return text def _resolve_secretsmanager(self, reference): """Resolve a Secrets Manager reference.""" - # Parse the reference: secret-id:json-key:version-stage:version-id parts = reference.split(":") secret_id = parts[0] json_key = parts[1] if len(parts) > 1 else None - # Try to get the secret value if we have a boto session if self.boto_session: try: sm_client = self.boto_session.client("secretsmanager") response = sm_client.get_secret_value(SecretId=secret_id) - # Handle binary secrets if "SecretBinary" in response: + self._log("Resolve:SecretsManager", reference, level="INFO") return str(response["SecretBinary"]) - # Handle string secrets secret_string = response.get("SecretString", "") - # If a JSON key is specified, parse and extract if json_key: try: import json secret_data = json.loads(secret_string) if json_key not in secret_data: + self._log( + "Resolve:SecretsManager", + reference, + f"Key {json_key} not found", + level="ERROR", + ) return f"{{Error: Key {json_key} not found in secret {secret_id}}}" + self._log("Resolve:SecretsManager", reference, level="INFO") return secret_data[json_key] except json.JSONDecodeError: + self._log( + "Resolve:SecretsManager", + reference, + "Invalid JSON", + level="ERROR", + ) return f"{{Error: Secret is not valid JSON: {secret_id}}}" + self._log("Resolve:SecretsManager", reference, level="INFO") return secret_string - except Exception: - # Fall through to mock value + except (ClientError, BotoCoreError) as e: + self._log("Resolve:SecretsManager", reference, str(e), level="ERROR") pass - # Return mock value if we can't resolve return f"mock-secret-{secret_id}" def _handle_map(self, args): m_name = self.resolve(args[0]) top = self.resolve(args[1]) sec = self.resolve(args[2]) + + key_str = f"{m_name}.{top}.{sec}" try: - return self.mappings[m_name][top][sec] + val = self.mappings[m_name][top][sec] + self._log("FindInMap", key_str, level="INFO") + return self.resolve(val) except (KeyError, TypeError): if len(args) > 3: default_arg = args[3] if isinstance(default_arg, dict) and "DefaultValue" in default_arg: + self._log("FindInMap", key_str, "Used Default Value", level="WARN") return self.resolve(default_arg["DefaultValue"]) + self._log("FindInMap", key_str, "Used Default Value", level="WARN") return self.resolve(default_arg) + + self._log("FindInMap", key_str, "Key not found", level="ERROR") return f"{{Error: Could not resolve Map {m_name}.{top}.{sec}}}" def _handle_sub(self, args): @@ -263,11 +386,17 @@ def _handle_sub(self, args): def repl(match): var = match.group(1) if var in vars_map: - return str(self.resolve(vars_map[var])) + val = str(self.resolve(vars_map[var])) + self._log("Sub", var, f"Resolved to: {val}", level="DEBUG") + return val if var in self.context: - return str(self.context[var]) + val = str(self.resolve(self.context[var])) + self._log("Sub", var, f"Resolved to: {val}", level="DEBUG") + return val if var in self.resources: - return f"mock-{var.lower()}-id" + val = f"mock-{var.lower()}-id" + self._log("Sub", var, f"Resolved to Mock: {val}", level="DEBUG") + return val return match.group(0) return re.sub(r"\${([^!][^}]*)}", repl, text) @@ -279,8 +408,10 @@ def _handle_import(self, val): exports = self.cfn_client.list_exports() for exp in exports["Exports"]: if exp["Name"] == import_name: + self._log("ImportValue", import_name, level="INFO") return exp["Value"] - except Exception: + except (ClientError, BotoCoreError) as e: + self._log("ImportValue", import_name, str(e), level="ERROR") pass return f"mock-import-{import_name}" @@ -311,7 +442,6 @@ def _handle_split(self, args): return string.split(delim) def _handle_base64(self, val): - # Render as plain text for visibility, or real base64 if preferred resolved = self.resolve(val) return f"[Base64: {resolved}]" @@ -321,7 +451,6 @@ def _handle_length(self, val): def _handle_getazs(self, val): region = self.resolve(val) - # Return mock AZs based on the region string if not region: region = self.context["AWS::Region"] return [f"{region}a", f"{region}b", f"{region}c"] @@ -358,14 +487,18 @@ def _handle_if(self, args): return self.resolve(result_node) -def process(config, env, template, profile): +# --- Processing Logic --- +def process(config, env, template, profile, log_level="WARN"): sam_params = load_sam_config(config, env) region = sam_params.get("AWS::Region", "us-east-1") - renderer = TemplateRenderer(template, profile=profile, region=region) + renderer = TemplateRenderer( + template, profile=profile, region=region, env_name=env, log_level=log_level + ) renderer.context.update(sam_params) - resolved_resources = renderer.resolve(renderer.resources) + # Use resolve_resources to track Logical ID context + resolved_resources = renderer.resolve_resources() output = { "Resources": resolved_resources, @@ -375,8 +508,6 @@ def process(config, env, template, profile): def compare(a, b): - # Convert dictionaries to YAML strings for text comparison - # sort_keys=True is crucial to prevent false diffs from random dict ordering a_lines = yaml.dump(a[1], sort_keys=True).splitlines() b_lines = yaml.dump(b[1], sort_keys=True).splitlines() @@ -388,7 +519,6 @@ def compare(a, b): lineterm="", ) - # ANSI Color Codes RED = "\033[31m" GREEN = "\033[32m" CYAN = "\033[36m" @@ -410,7 +540,7 @@ def compare(a, b): return "\n".join(colored_output) -def main(): +async def async_main(): parser = argparse.ArgumentParser( description="Render CloudFormation/SAM templates.", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -438,18 +568,53 @@ def main(): default=None, ) parser.add_argument("--profile", help="AWS CLI Profile", default=None) + parser.add_argument( + "--profile2", + help="AWS CLI Profile for the second environment (optional). Defaults to --profile.", + default=None, + ) + parser.add_argument( + "--log-level", + help="Set logging level (DEBUG, INFO, WARN, ERROR)", + default="WARN", + type=str.upper, + choices=["DEBUG", "INFO", "WARN", "WARNING", "ERROR"], + ) args = parser.parse_args() - output = process(args.config, args.env, args.template, args.profile) + # Determine profiles for both envs + prof1 = args.profile + prof2 = args.profile2 if args.profile2 else args.profile + + # Check login for any profile that is set + profiles_to_check = set(p for p in [prof1, prof2] if p) + for p in profiles_to_check: + ensure_sso_login(p) if args.env2 is not None: - output2 = process(args.config, args.env2, args.template, args.profile) - diff = compare([args.env, output], [args.env2, output2]) + # Run both process calls in parallel threads + task1 = asyncio.to_thread( + process, args.config, args.env, args.template, prof1, args.log_level + ) + task2 = asyncio.to_thread( + process, args.config, args.env2, args.template, prof2, args.log_level + ) + + output1, output2 = await asyncio.gather(task1, task2) + + diff = compare([args.env, output1], [args.env2, output2]) print(diff) else: + output = await asyncio.to_thread( + process, args.config, args.env, args.template, prof1, args.log_level + ) print(yaml.dump(output)) +def main(): + asyncio.run(async_main()) + + if __name__ == "__main__": main() diff --git a/tests/test_main.py b/tests/test_main.py index 29a9f6d..9d7aa82 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,6 +2,7 @@ import yaml import boto3 from unittest.mock import patch, MagicMock +from botocore.exceptions import ClientError from samrenderer.main import ( TemplateRenderer, parse_sam_overrides, @@ -98,6 +99,62 @@ def test_main_cli_diff(capsys, simple_template, tmp_path): assert "IsProd: true" in captured.out +@patch("samrenderer.main.process") +@patch("samrenderer.main.ensure_sso_login") +def test_cli_profile_fallback(mock_login, mock_process, simple_template, tmp_path): + """Verify that if only --profile is set, it is used for both envs.""" + # Create dummy config + config_file = tmp_path / "samconfig.toml" + config_file.write_text("", encoding="utf-8") + + # Case 1: Only --profile set + args = [ + "sam-render", + simple_template, + "--config", + str(config_file), + "--env", + "dev", + "--env2", + "prod", + "--profile", + "my-profile", + ] + + # Mock process to return simple dicts to avoid diff errors + mock_process.return_value = {"Resources": {}} + + with patch("sys.argv", args): + main() + + # Expect 2 calls to process. Both should have 'my-profile' as the 4th arg + assert mock_process.call_count == 2 + + # Check args of all calls + # call_args_list is [call(args...), call(args...)] + # call args: (config, env, template, profile, log_level) + profiles_used = [c.args[3] for c in mock_process.call_args_list] + assert profiles_used == ["my-profile", "my-profile"] + + # Expect login called once + mock_login.assert_called_once_with("my-profile") + + # Reset mocks for Case 2 + mock_process.reset_mock() + mock_login.reset_mock() + + # Case 2: Both profiles set + args_two = args + ["--profile2", "prod-profile"] + with patch("sys.argv", args_two): + main() + + profiles_used_2 = sorted([c.args[3] for c in mock_process.call_args_list]) + assert profiles_used_2 == ["my-profile", "prod-profile"] + + # Expect login called for both + assert mock_login.call_count == 2 + + def test_compare_function(): """Test the compare logic and ANSI coloring.""" # Setup data with some shared lines (context) and some diffs @@ -242,7 +299,14 @@ def test_import_value_mock_aws(simple_template): assert r.resolve({"Fn::ImportValue": "MyExport"}) == "RealValue" assert r.resolve({"Fn::ImportValue": "Missing"}) == "mock-import-Missing" - mock_client.list_exports.side_effect = Exception("AWS Down") + # Raise a proper ClientError to test the exception handling + error_response = { + "Error": {"Code": "ServiceUnavailable", "Message": "AWS Down"} + } + mock_client.list_exports.side_effect = ClientError( + error_response, "ListExports" + ) + assert r.resolve({"Fn::ImportValue": "MyExport"}) == "mock-import-MyExport"