diff --git a/src/bastion/azext_bastion/_help.py b/src/bastion/azext_bastion/_help.py index fe1380a6981..72a29681fbf 100644 --- a/src/bastion/azext_bastion/_help.py +++ b/src/bastion/azext_bastion/_help.py @@ -51,4 +51,7 @@ - name: Open a tunnel through Azure Bastion to a target virtual machine using its IP address. text: | az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 --resource-port 22 --port 50022 + - name: Open a tunnel through Azure Bastion to a managed cluster (resource-port defaults to 443). + text: | + az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-resource-id managedClusterResourceId --port 50443 """ diff --git a/src/bastion/azext_bastion/_params.py b/src/bastion/azext_bastion/_params.py index faa2379860b..586eb3e9f3c 100644 --- a/src/bastion/azext_bastion/_params.py +++ b/src/bastion/azext_bastion/_params.py @@ -23,8 +23,6 @@ def load_arguments(self, _): # pylint: disable=unused-argument with self.argument_context("network bastion") as c: c.argument("bastion_host_name", bastion_host_name_type, options_list=["--name", "-n"]) - c.argument("resource_port", help="Resource port of the target VM to which the bastion will connect.", - options_list=["--resource-port"]) c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", required=False, options_list=["--target-resource-id"]) c.argument("target_ip_address", help="IP address of target Virtual Machine.", required=False, @@ -46,5 +44,7 @@ def load_arguments(self, _): # pylint: disable=unused-argument "Available on Windows 10 20H2+, Windows 11 21H2+, WS 2022.", arg_type=get_three_state_flag()) with self.argument_context("network bastion tunnel") as c: + c.argument("resource_port", help="Resource port of the target resource to which the bastion will connect. Defaults to 443 for managed clusters.", + options_list=["--resource-port"], required=False) c.argument("port", help="Local port to use for the tunneling.", options_list=["--port"]) c.argument("timeout", help="Timeout for connection to bastion host tunnel.", options_list=["--timeout"]) diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index 82a5fd3520e..33b38376db7 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -31,6 +31,9 @@ logger = get_logger(__name__) +# Default port for managed cluster tunnel connections +DEFAULT_MANAGED_CLUSTER_PORT = 443 + class BastionCreate(_BastionCreate): @classmethod @@ -413,6 +416,13 @@ def _validate_resourceid(target_resource_id): raise InvalidArgumentValueError(err_msg) +def _is_managed_cluster(target_resource_id): + """Check if the target resource is a managed cluster (AKS).""" + if not target_resource_id: + return False + return "microsoft.containerservice/managedclusters" in target_resource_id.lower() + + def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): if bastion['sku']['name'] == BastionSku.QuickConnect.value or bastion['sku']['name'] == BastionSku.Developer.value: from .developer_sku_helper import (_get_data_pod) @@ -466,10 +476,33 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g if ip_connect: target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/" \ f"{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" - - if ip_connect and int(resource_port) not in [22, 3389]: - raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is \ - 22, 3389.") + + # For IP connect, validate resource_port is provided and is valid + if not resource_port: + raise RequiredArgumentMissingError("--resource-port is required for IP connect.") + + try: + port_int = int(resource_port) + except (TypeError, ValueError): + raise InvalidArgumentValueError(f"Invalid resource port: {resource_port}. Must be a valid integer.") + + if port_int not in [22, 3389]: + raise UnrecognizedArgumentError( + "Custom ports are not allowed. Allowed ports for Tunnel with IP connect are 22, 3389.") + else: + # Default resource_port to DEFAULT_MANAGED_CLUSTER_PORT for managed clusters if not provided + if not resource_port and _is_managed_cluster(target_resource_id): + resource_port = DEFAULT_MANAGED_CLUSTER_PORT + + # Validate that resource_port is provided for non-managed cluster targets + if not resource_port: + raise RequiredArgumentMissingError("--resource-port is required for non-managed cluster targets.") + + # Validate that resource_port is a valid integer + try: + int(resource_port) + except (TypeError, ValueError): + raise InvalidArgumentValueError(f"Invalid resource port: {resource_port}. Must be a valid integer.") _validate_resourceid(target_resource_id) bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) diff --git a/src/bastion/azext_bastion/tests/latest/test_bastion.py b/src/bastion/azext_bastion/tests/latest/test_bastion.py index fb0f1829118..69f13a9c8e7 100644 --- a/src/bastion/azext_bastion/tests/latest/test_bastion.py +++ b/src/bastion/azext_bastion/tests/latest/test_bastion.py @@ -7,10 +7,31 @@ # pylint: disable=line-too-long +import unittest from azure.cli.testsdk import * from azure.cli.testsdk.scenario_tests import AllowLargeResponse +class BastionUnitTests(unittest.TestCase): + def test_is_managed_cluster(self): + """Test the _is_managed_cluster helper function""" + from azext_bastion.custom import _is_managed_cluster + + # Test managed cluster resource ID + managed_cluster_id = "/subscriptions/12345678-1234-1234-1234-123456789012/resourceGroups/myRG/providers/Microsoft.ContainerService/managedClusters/myAKS" + self.assertTrue(_is_managed_cluster(managed_cluster_id)) + + # Test VM resource ID + vm_id = "/subscriptions/12345678-1234-1234-1234-123456789012/resourceGroups/myRG/providers/Microsoft.Compute/virtualMachines/myVM" + self.assertFalse(_is_managed_cluster(vm_id)) + + # Test None + self.assertFalse(_is_managed_cluster(None)) + + # Test empty string + self.assertFalse(_is_managed_cluster("")) + + class BastionScenario(ScenarioTest): @AllowLargeResponse(size_kb=9999) @ResourceGroupPreparer(name_prefix="cli_test_bastion_host_", location="eastus")