diff --git a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_commands.py b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_commands.py index 84edd991aac..04e902ef1dc 100644 --- a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_commands.py +++ b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_commands.py @@ -5,6 +5,7 @@ import json import os +import random import subprocess import tempfile import time @@ -50,6 +51,92 @@ def __init__(self, method_name): method_name, recording_processors=[KeyReplacer()] ) + def cmd(self, command, checks=None, expect_failure=False): + if (checks and self.is_live and + os.environ.get('AZURE_CLI_TEST_RETRY_PROVISIONING_CHECK') == 'true'): + return self._cmd_with_retry(command, checks, expect_failure) + return super().cmd(command, checks=checks, expect_failure=expect_failure) + + def _is_provisioning_state_check(self, check): + from azure.cli.testsdk.checkers import JMESPathCheck + if not isinstance(check, JMESPathCheck): + return False + return check._query == 'provisioningState' and check._expected_result == 'Succeeded' + + def _should_retry_for_provisioning_state(self, result): + if not hasattr(result, 'get_output_in_json'): + return False, None + data = result.get_output_in_json() + if not isinstance(data, dict) or 'id' not in data: + return False, None + provisioning_state = data.get('provisioningState') + if not provisioning_state: + return False, None + terminal_states = {'Succeeded', 'Failed', 'Canceled'} + if provisioning_state in terminal_states: + return False, None + return True, data['id'] + + def _cmd_with_retry(self, command, checks, expect_failure): + from azure.cli.testsdk.base import execute + import logging + + # Apply kwargs substitution (e.g. {resource_group}) before executing, + # matching what ScenarioTest.cmd() does internally. + command = self._apply_kwargs(command) + result = execute(self.cli_ctx, command, expect_failure=expect_failure) + + # Split checks into provisioning vs everything else + provisioning_checks = [c for c in (checks or []) if self._is_provisioning_state_check(c)] + other_checks = [c for c in (checks or []) if not self._is_provisioning_state_check(c)] + + if provisioning_checks: + should_retry, resource_id = self._should_retry_for_provisioning_state(result) + if should_retry: + initial_data = result.get_output_in_json() + initial_etag = initial_data.get('etag') + last_seen_etag = initial_etag + max_retries = int(os.environ.get('AZURE_CLI_TEST_PROVISIONING_MAX_RETRIES', '10')) + base_delay = float(os.environ.get('AZURE_CLI_TEST_PROVISIONING_BASE_DELAY', '2.0')) + + # Poll with exponential backoff + jitter until terminal state + for attempt in range(max_retries): + delay = base_delay * (2 ** attempt) + random.uniform(0, 1) + time.sleep(delay) + poll_result = execute(self.cli_ctx, f'resource show --ids {resource_id}', expect_failure=False) + poll_data = poll_result.get_output_in_json() + current_provisioning_state = poll_data.get('provisioningState') + current_etag = poll_data.get('etag') + + # Track etag changes to detect external modifications during polling + if current_etag and last_seen_etag and current_etag != last_seen_etag: + logging.warning(f"ETag changed during polling (external modification detected)") + last_seen_etag = current_etag + + if current_provisioning_state == 'Succeeded': + break + elif current_provisioning_state in {'Failed', 'Canceled'}: + raise AssertionError( + f"provisioningState reached terminal failure: {current_provisioning_state}" + ) + else: + # for/else: ran all retries without breaking + final_etag_msg = "" + if initial_etag and last_seen_etag: + final_etag_msg = f" (initial etag: {initial_etag}, final: {last_seen_etag})" + raise TimeoutError( + f"provisioningState did not reach 'Succeeded' after {max_retries} retries. " + f"Final state: {current_provisioning_state}{final_etag_msg}" + ) + + # Provisioning checks already verified via polling, skip re-checking stale result + + # Run all non-provisioning checks against the original result + for check in other_checks: + check(result) + + return result + @classmethod def generate_ssh_keys(cls): """ diff --git a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_provisioning_retry.py b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_provisioning_retry.py new file mode 100644 index 00000000000..ecc39166adc --- /dev/null +++ b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_aks_provisioning_retry.py @@ -0,0 +1,123 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import json +import os +import unittest +from unittest.mock import MagicMock, patch + +from azure.cli.testsdk.checkers import JMESPathCheck + + +class MockExecutionResult: + def __init__(self, output_json): + self._json = output_json + self.output = json.dumps(output_json) + self.json_value = None + + def get_output_in_json(self): + return self._json + + +class TestShouldRetryForProvisioningState(unittest.TestCase): + + def _make_instance(self): + from azure.cli.command_modules.acs.tests.latest.test_aks_commands import ( + AzureKubernetesServiceScenarioTest, + ) + return object.__new__(AzureKubernetesServiceScenarioTest) + + def test_non_terminal_state_returns_true(self): + inst = self._make_instance() + result = MockExecutionResult({ + 'id': '/subscriptions/xxx/resourceGroups/rg/providers/Microsoft.ContainerService/managedClusters/mc', + 'provisioningState': 'Updating', + }) + should_retry, resource_id = inst._should_retry_for_provisioning_state(result) + self.assertTrue(should_retry) + self.assertIn('managedClusters/mc', resource_id) + + def test_succeeded_returns_false(self): + inst = self._make_instance() + result = MockExecutionResult({'id': '/subscriptions/xxx/rg/mc', 'provisioningState': 'Succeeded'}) + should_retry, _ = inst._should_retry_for_provisioning_state(result) + self.assertFalse(should_retry) + + def test_no_id_returns_false(self): + inst = self._make_instance() + result = MockExecutionResult({'provisioningState': 'Updating'}) + should_retry, _ = inst._should_retry_for_provisioning_state(result) + self.assertFalse(should_retry) + + def test_list_response_returns_false(self): + inst = self._make_instance() + result = MockExecutionResult([{'id': '/some/id', 'provisioningState': 'Updating'}]) + should_retry, _ = inst._should_retry_for_provisioning_state(result) + self.assertFalse(should_retry) + + +class TestCmdWithRetry(unittest.TestCase): + + def _make_instance(self): + from azure.cli.command_modules.acs.tests.latest.test_aks_commands import ( + AzureKubernetesServiceScenarioTest, + ) + instance = object.__new__(AzureKubernetesServiceScenarioTest) + instance.kwargs = {} + instance._apply_kwargs = lambda cmd: cmd + instance.cli_ctx = MagicMock() + return instance + + def _result(self, data): + return MockExecutionResult(data) + + @patch.dict(os.environ, {'AZURE_CLI_TEST_PROVISIONING_MAX_RETRIES': '3', 'AZURE_CLI_TEST_PROVISIONING_BASE_DELAY': '0.01'}) + @patch('azure.cli.testsdk.base.execute') + def test_no_retry_when_already_succeeded(self, mock_execute): + mock_execute.return_value = self._result({'id': '/rg/mc', 'provisioningState': 'Succeeded'}) + self._make_instance()._cmd_with_retry('aks show', [JMESPathCheck('provisioningState', 'Succeeded')], False) + mock_execute.assert_called_once() + + @patch.dict(os.environ, {'AZURE_CLI_TEST_PROVISIONING_MAX_RETRIES': '3', 'AZURE_CLI_TEST_PROVISIONING_BASE_DELAY': '0.01'}) + @patch('azure.cli.testsdk.base.execute') + def test_retries_until_succeeded(self, mock_execute): + resource_id = '/subscriptions/xxx/resourceGroups/rg/providers/Microsoft.ContainerService/managedClusters/mc' + mock_execute.side_effect = [ + self._result({'id': resource_id, 'provisioningState': 'Updating'}), + self._result({'provisioningState': 'Updating'}), + self._result({'provisioningState': 'Succeeded'}), + ] + self._make_instance()._cmd_with_retry('aks show', [JMESPathCheck('provisioningState', 'Succeeded')], False) + self.assertEqual(mock_execute.call_count, 3) + + @patch.dict(os.environ, {'AZURE_CLI_TEST_PROVISIONING_MAX_RETRIES': '2', 'AZURE_CLI_TEST_PROVISIONING_BASE_DELAY': '0.01'}) + @patch('azure.cli.testsdk.base.execute') + def test_raises_on_failed_state(self, mock_execute): + mock_execute.side_effect = [ + self._result({'id': '/rg/mc', 'provisioningState': 'Updating'}), + self._result({'provisioningState': 'Failed'}), + ] + with self.assertRaises(AssertionError): + self._make_instance()._cmd_with_retry('aks show', [JMESPathCheck('provisioningState', 'Succeeded')], False) + + @patch.dict(os.environ, {'AZURE_CLI_TEST_PROVISIONING_MAX_RETRIES': '2', 'AZURE_CLI_TEST_PROVISIONING_BASE_DELAY': '0.01'}) + @patch('azure.cli.testsdk.base.execute') + def test_raises_timeout_after_max_retries(self, mock_execute): + poll = self._result({'provisioningState': 'Updating'}) + mock_execute.side_effect = [self._result({'id': '/rg/mc', 'provisioningState': 'Updating'}), poll, poll] + with self.assertRaises(TimeoutError): + self._make_instance()._cmd_with_retry('aks show', [JMESPathCheck('provisioningState', 'Succeeded')], False) + + @patch.dict(os.environ, {'AZURE_CLI_TEST_PROVISIONING_MAX_RETRIES': '3', 'AZURE_CLI_TEST_PROVISIONING_BASE_DELAY': '0.01'}) + @patch('azure.cli.testsdk.base.execute') + def test_non_provisioning_checks_still_run(self, mock_execute): + mock_execute.return_value = self._result({'id': '/rg/mc', 'name': 'mc', 'provisioningState': 'Succeeded'}) + name_check = MagicMock() + self._make_instance()._cmd_with_retry('aks show', [JMESPathCheck('provisioningState', 'Succeeded'), name_check], False) + name_check.assert_called_once() + + +if __name__ == '__main__': + unittest.main()