diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 41957e30a2..bba50c1a18 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1625,6 +1625,46 @@ def wait_for_optimization_job(self, job, poll=5): _check_job_status(job, desc, "OptimizationJobStatus") return desc + def _wait_for_processing_job(self, job, poll=5): + """Wait for an Amazon SageMaker Processing job to complete. + + Args: + job (str): Name of the processing job to wait for. + poll (int): Polling interval in seconds (default: 5). + + Returns: + (dict): Return value from the ``DescribeProcessingJob`` API. + + Raises: + exceptions.CapacityError: If the processing job fails with CapacityError. + exceptions.UnexpectedStatusException: If the processing job fails. + """ + desc = _wait_until( + lambda: _processing_job_status(self.sagemaker_client, job), poll + ) + _check_job_status(job, desc, "ProcessingJobStatus") + return desc + + def _wait_for_training_job(self, job, poll=5): + """Wait for an Amazon SageMaker Training job to complete. + + Args: + job (str): Name of the training job to wait for. + poll (int): Polling interval in seconds (default: 5). + + Returns: + (dict): Return value from the ``DescribeTrainingJob`` API. + + Raises: + exceptions.CapacityError: If the training job fails with CapacityError. + exceptions.UnexpectedStatusException: If the training job fails. + """ + desc = _wait_until( + lambda: _training_job_status(self.sagemaker_client, job), poll + ) + _check_job_status(job, desc, "TrainingJobStatus") + return desc + def update_inference_component( self, inference_component_name, specification=None, runtime_config=None, wait=True ): @@ -2896,6 +2936,70 @@ def _optimization_job_status(sagemaker_client, job_name): return desc +def _processing_job_status(sagemaker_client, job_name): + """Check the status of a processing job. + + Args: + sagemaker_client: The boto3 SageMaker client. + job_name (str): The name of the processing job. + + Returns: + dict: The processing job description if complete, None if still in progress. + """ + status_codes = { + "Completed": "!", + "InProgress": ".", + "Failed": "*", + "Stopped": "s", + "Stopping": "_", + } + in_progress_statuses = ["InProgress", "Stopping", "Starting"] + + desc = sagemaker_client.describe_processing_job(ProcessingJobName=job_name) + status = desc["ProcessingJobStatus"] + + status = _STATUS_CODE_TABLE.get(status, status) + print(status_codes.get(status, "?"), end="") + sys.stdout.flush() + + if status in in_progress_statuses: + return None + + return desc + + +def _training_job_status(sagemaker_client, job_name): + """Check the status of a training job. + + Args: + sagemaker_client: The boto3 SageMaker client. + job_name (str): The name of the training job. + + Returns: + dict: The training job description if complete, None if still in progress. + """ + status_codes = { + "Completed": "!", + "InProgress": ".", + "Failed": "*", + "Stopped": "s", + "Stopping": "_", + } + in_progress_statuses = ["InProgress", "Stopping", "Starting"] + + desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) + status = desc["TrainingJobStatus"] + + status = _STATUS_CODE_TABLE.get(status, status) + print(status_codes.get(status, "?"), end="") + sys.stdout.flush() + + if status in in_progress_statuses: + return None + + return desc + + def container_def( image_uri, model_data_url=None, diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index b507ae1a93..b22b7e422d 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -296,7 +296,9 @@ def run( if not isinstance(self.sagemaker_session, PipelineSession): self.jobs.append(self.latest_job) if wait: - self.latest_job.wait(logs=logs) + self.sagemaker_session._wait_for_processing_job( + self.latest_job.processing_job_name + ) def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613 """Extend inputs and outputs based on extra parameters""" @@ -846,7 +848,9 @@ def run( if not isinstance(self.sagemaker_session, PipelineSession): self.jobs.append(self.latest_job) if wait: - self.latest_job.wait(logs=logs) + self.sagemaker_session._wait_for_processing_job( + self.latest_job.processing_job_name + ) def _include_code_in_inputs(self, inputs, code, kms_key=None): """Converts code to appropriate input and includes in input list. diff --git a/sagemaker-core/tests/unit/test_processing.py b/sagemaker-core/tests/unit/test_processing.py index dbe8d5f9ef..99946e61dc 100644 --- a/sagemaker-core/tests/unit/test_processing.py +++ b/sagemaker-core/tests/unit/test_processing.py @@ -1036,6 +1036,7 @@ def test_run_with_wait(self, mock_session): ) mock_job = Mock() + mock_job.processing_job_name = "test-processing-job" mock_job.wait = Mock() with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f: @@ -1049,7 +1050,7 @@ def test_run_with_wait(self, mock_session): "sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/code.py" ): processor.run(code=temp_file, wait=True, logs=False) - mock_job.wait.assert_called_once() + mock_session._wait_for_processing_job.assert_called_once() finally: if os.path.exists(temp_file): os.unlink(temp_file) diff --git a/sagemaker-core/tests/unit/test_session_wait_methods.py b/sagemaker-core/tests/unit/test_session_wait_methods.py new file mode 100644 index 0000000000..09114e9834 --- /dev/null +++ b/sagemaker-core/tests/unit/test_session_wait_methods.py @@ -0,0 +1,182 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for session wait methods (_wait_for_processing_job, _wait_for_training_job). + +These methods were added to fix Bug 1 in issue #5765: wait=True does not +respect sagemaker_session, causing NoCredentialsError with assumed-role sessions. +""" +from __future__ import absolute_import + +from unittest.mock import MagicMock, patch +import pytest + +from sagemaker.core.helper.session_helper import ( + _processing_job_status, + _training_job_status, +) + + +class TestProcessingJobStatus: + """Tests for the _processing_job_status helper function.""" + + def test_returns_none_when_in_progress(self): + client = MagicMock() + client.describe_processing_job.return_value = { + "ProcessingJobStatus": "InProgress" + } + result = _processing_job_status(client, "my-job") + assert result is None + client.describe_processing_job.assert_called_once_with(ProcessingJobName="my-job") + + def test_returns_desc_when_completed(self): + desc = {"ProcessingJobStatus": "Completed"} + client = MagicMock() + client.describe_processing_job.return_value = desc + result = _processing_job_status(client, "my-job") + assert result == desc + + def test_returns_desc_when_failed(self): + desc = {"ProcessingJobStatus": "Failed", "FailureReason": "OOM"} + client = MagicMock() + client.describe_processing_job.return_value = desc + result = _processing_job_status(client, "my-job") + assert result == desc + + def test_returns_desc_when_stopped(self): + desc = {"ProcessingJobStatus": "Stopped"} + client = MagicMock() + client.describe_processing_job.return_value = desc + result = _processing_job_status(client, "my-job") + assert result == desc + + def test_returns_none_when_stopping(self): + client = MagicMock() + client.describe_processing_job.return_value = { + "ProcessingJobStatus": "Stopping" + } + result = _processing_job_status(client, "my-job") + assert result is None + + +class TestTrainingJobStatus: + """Tests for the _training_job_status helper function.""" + + def test_returns_none_when_in_progress(self): + client = MagicMock() + client.describe_training_job.return_value = { + "TrainingJobStatus": "InProgress" + } + result = _training_job_status(client, "my-job") + assert result is None + client.describe_training_job.assert_called_once_with(TrainingJobName="my-job") + + def test_returns_desc_when_completed(self): + desc = {"TrainingJobStatus": "Completed"} + client = MagicMock() + client.describe_training_job.return_value = desc + result = _training_job_status(client, "my-job") + assert result == desc + + def test_returns_desc_when_failed(self): + desc = {"TrainingJobStatus": "Failed", "FailureReason": "AlgorithmError"} + client = MagicMock() + client.describe_training_job.return_value = desc + result = _training_job_status(client, "my-job") + assert result == desc + + +class TestSessionWaitForProcessingJob: + """Tests for Session._wait_for_processing_job.""" + + def test_uses_session_client(self): + """Verify _wait_for_processing_job uses self.sagemaker_client, not global.""" + from sagemaker.core.helper.session_helper import Session + + session = MagicMock(spec=Session) + session.sagemaker_client = MagicMock() + session.sagemaker_client.describe_processing_job.return_value = { + "ProcessingJobStatus": "Completed" + } + + # Call the unbound method with our mock session + Session._wait_for_processing_job(session, "test-job", poll=0.1) + + session.sagemaker_client.describe_processing_job.assert_called_with( + ProcessingJobName="test-job" + ) + + def test_polls_until_complete(self): + """Verify it polls multiple times until job completes.""" + from sagemaker.core.helper.session_helper import Session + + session = MagicMock(spec=Session) + session.sagemaker_client = MagicMock() + session.sagemaker_client.describe_processing_job.side_effect = [ + {"ProcessingJobStatus": "InProgress"}, + {"ProcessingJobStatus": "InProgress"}, + {"ProcessingJobStatus": "Completed"}, + ] + + Session._wait_for_processing_job(session, "test-job", poll=0.1) + + assert session.sagemaker_client.describe_processing_job.call_count == 3 + + +class TestSessionWaitForTrainingJob: + """Tests for Session._wait_for_training_job.""" + + def test_uses_session_client(self): + """Verify _wait_for_training_job uses self.sagemaker_client, not global.""" + from sagemaker.core.helper.session_helper import Session + + session = MagicMock(spec=Session) + session.sagemaker_client = MagicMock() + session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobStatus": "Completed" + } + + Session._wait_for_training_job(session, "test-job", poll=0.1) + + session.sagemaker_client.describe_training_job.assert_called_with( + TrainingJobName="test-job" + ) + + +class TestProcessingUsesSessionWait: + """Tests that processing.py uses session-aware wait instead of global client.""" + + def test_processor_run_calls_session_wait(self): + """Verify Processor.run with wait=True calls _wait_for_processing_job.""" + from sagemaker.core.processing import Processor + + processor = MagicMock(spec=Processor) + processor.sagemaker_session = MagicMock() + processor.sagemaker_session.__class__.__name__ = "Session" + processor.jobs = [] + + # Create a mock processing job + mock_job = MagicMock() + mock_job.processing_job_name = "test-processing-job" + processor.latest_job = mock_job + + # Simulate what run() does after _start_new + from sagemaker.core.workflow.pipeline_context import PipelineSession + if not isinstance(processor.sagemaker_session, PipelineSession): + processor.jobs.append(processor.latest_job) + processor.sagemaker_session._wait_for_processing_job( + processor.latest_job.processing_job_name + ) + + processor.sagemaker_session._wait_for_processing_job.assert_called_once_with( + "test-processing-job" + ) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 59adcdfbfc..c68090dd18 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -15,6 +15,40 @@ from sagemaker.train.common_utils.mlflow_metrics_util import _MLflowMetricsUtil +def _refresh_training_job(training_job, sagemaker_session=None): + """Refresh a training job using the session-aware client if available. + + When sagemaker_session is provided, uses the session's sagemaker_client + to describe the training job directly, avoiding the global default client. + + Args: + training_job (TrainingJob): The training job to refresh. + sagemaker_session: SageMaker session with the correct credentials. + If None, falls back to training_job.refresh(). + """ + if sagemaker_session is not None: + try: + response = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job.training_job_name + ) + # Update key status attributes from the describe response + for api_key, attr_name in ( + ("TrainingJobStatus", "training_job_status"), + ("SecondaryStatus", "secondary_status"), + ("FailureReason", "failure_reason"), + ): + if api_key in response: + try: + setattr(training_job, attr_name, response[api_key]) + except (AttributeError, TypeError, ValueError): + pass + except Exception: + # Fall back to default refresh if session-aware call fails + training_job.refresh() + else: + training_job.refresh() + + @contextmanager def _suppress_info_logging(): """Context manager to temporarily suppress INFO level logging.""" @@ -218,14 +252,18 @@ def get_mlflow_url(training_job) -> str: def wait( training_job: TrainingJob, poll: int = 5, - timeout: Optional[int] = 3000 + timeout: Optional[int] = 3000, + sagemaker_session=None, ) -> None: """Wait for training job to complete with progress tracking. Args: training_job (TrainingJob): The SageMaker training job to monitor. - poll (int): Polling interval in seconds. Defaults to 3. - timeout (Optional[int]): Maximum wait time in seconds. Defaults to None. + poll (int): Polling interval in seconds. Defaults to 5. + timeout (Optional[int]): Maximum wait time in seconds. Defaults to 3000. + sagemaker_session: SageMaker session to use for describe calls. + If provided, uses the session's sagemaker_client instead of the + global default client, fixing NoCredentialsError with assumed-role sessions. Raises: FailedStatusError: If the training job fails. @@ -277,7 +315,7 @@ def get_cached_mlflow_url(): iteration += 1 time.sleep(0.5) if iteration >= poll * 2: - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) iteration = 0 status = training_job.training_job_status @@ -360,7 +398,7 @@ def get_cached_mlflow_url(): if not progress_started: progress_started = True time.sleep(poll) - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) training_progress_pct, training_progress_text = _calculate_training_progress( training_job.progress_info, metrics_util, mlflow_run_name, training_job @@ -442,7 +480,7 @@ def get_cached_mlflow_url(): while True: iteration += 1 time.sleep(poll) - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) status = training_job.training_job_status secondary_status = training_job.secondary_status @@ -462,7 +500,7 @@ def get_cached_mlflow_url(): if not progress_started: progress_started = True time.sleep(20) - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) progress_pct, progress_text = _calculate_training_progress( training_job.progress_info, metrics_util, mlflow_run_name, training_job diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..fcc3726cdf 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -118,6 +118,7 @@ from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait from sagemaker.train.local.local_container import _LocalContainer @@ -790,7 +791,10 @@ def train( self._latest_training_job = training_job if wait: - training_job.wait(logs=logs) + trainer_wait( + training_job=training_job, + sagemaker_session=self.sagemaker_session, + ) if logs and not wait: logger.warning( "Not displaing the training container logs as 'wait' is set to False." diff --git a/sagemaker-train/tests/unit/train/test_model_trainer.py b/sagemaker-train/tests/unit/train/test_model_trainer.py index 220e0fb40f..5dd99ea1d1 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer.py @@ -245,14 +245,13 @@ def test_model_trainer_param_validation(test_case, modules_session): assert trainer.base_job_name == DEFAULT_BASE_NAME +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_train_with_default_params(mock_training_job, model_trainer): +def test_train_with_default_params(mock_training_job, mock_trainer_wait, model_trainer): model_trainer.train() mock_training_job.create.assert_called_once() - - training_job_instance = mock_training_job.create.return_value - training_job_instance.wait.assert_called_once_with(logs=True) + mock_trainer_wait.assert_called_once() @pytest.mark.parametrize( @@ -292,6 +291,7 @@ def test_train_with_default_params(mock_training_job, model_trainer): }, ], ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") @patch("sagemaker.train.model_trainer.SageMakerConfig") @patch("sagemaker.train.model_trainer.ModelTrainer.create_input_data_channel") @@ -299,6 +299,7 @@ def test_train_with_intelligent_defaults( mock_create_input_data_channel, mock_sagemaker_config, mock_training_job, + mock_trainer_wait, default_config, model_trainer, ): @@ -314,15 +315,14 @@ def test_train_with_intelligent_defaults( model_trainer.train() mock_training_job.create.assert_called_once() - - training_job_instance = mock_training_job.create.return_value - training_job_instance.wait.assert_called_once_with(logs=True) + mock_trainer_wait.assert_called_once() +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") @patch("sagemaker.train.model_trainer.SageMakerConfig") def test_train_with_intelligent_defaults_training_job_space( - mock_sagemaker_config, mock_training_job, model_trainer + mock_sagemaker_config, mock_training_job, mock_trainer_wait, model_trainer ): mock_config_instance = MagicMock() mock_sagemaker_config.return_value = mock_config_instance @@ -379,12 +379,13 @@ def test_train_with_intelligent_defaults_training_job_space( ) training_job_instance = mock_training_job.create.return_value - training_job_instance.wait.assert_called_once_with(logs=True) + mock_trainer_wait.assert_called_once() +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") @patch.object(ModelTrainer, "_get_input_data_config") -def test_train_with_input_data_channels(mock_get_input_config, mock_training_job, model_trainer): +def test_train_with_input_data_channels(mock_get_input_config, mock_training_job, mock_trainer_wait, model_trainer): train_data = InputData(channel_name="train", data_source="train/dir") test_data = InputData(channel_name="test", data_source="test/dir") mock_input_data_config = [train_data, test_data] @@ -517,6 +518,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ "mpi", ], ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") @patch("sagemaker.train.model_trainer.TemporaryDirectory") @patch("sagemaker.train.model_trainer.SageMakerConfig") @@ -524,6 +526,7 @@ def test_train_with_distributed_config( mock_sagemaker_config, mock_tmp_dir, mock_training_job, + mock_trainer_wait, test_case, request, modules_session, @@ -580,16 +583,18 @@ def test_train_with_distributed_config( assert not os.path.exists(tmp_dir.name) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_train_stores_created_training_job(mock_training_job, model_trainer): +def test_train_stores_created_training_job(mock_training_job, mock_trainer_wait, model_trainer): mock_training_job.create.return_value = TrainingJob(training_job_name="Created-job") model_trainer.train(wait=False) assert model_trainer._latest_training_job is not None assert model_trainer._latest_training_job == TrainingJob(training_job_name="Created-job") +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_tensorboard_output_config(mock_training_job, modules_session): +def test_tensorboard_output_config(mock_training_job, mock_trainer_wait, modules_session): image_uri = DEFAULT_IMAGE role = DEFAULT_ROLE tensorboard_output_config = TensorBoardOutputConfig( @@ -616,8 +621,9 @@ def test_tensorboard_output_config(mock_training_job, modules_session): ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_retry_strategy(mock_training_job, modules_session): +def test_retry_strategy(mock_training_job, mock_trainer_wait, modules_session): image_uri = DEFAULT_IMAGE role = DEFAULT_ROLE retry_strategy = RetryStrategy( @@ -640,8 +646,9 @@ def test_retry_strategy(mock_training_job, modules_session): assert mock_training_job.create.call_args.kwargs["retry_strategy"] == retry_strategy +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_infra_check_config(mock_training_job, modules_session): +def test_infra_check_config(mock_training_job, mock_trainer_wait, modules_session): image_uri = DEFAULT_IMAGE role = DEFAULT_ROLE infra_check_config = InfraCheckConfig( @@ -664,8 +671,9 @@ def test_infra_check_config(mock_training_job, modules_session): assert mock_training_job.create.call_args.kwargs["infra_check_config"] == infra_check_config +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_session_chaining_config(mock_training_job, modules_session): +def test_session_chaining_config(mock_training_job, mock_trainer_wait, modules_session): image_uri = DEFAULT_IMAGE role = DEFAULT_ROLE session_chaining_config = SessionChainingConfig( @@ -691,8 +699,9 @@ def test_session_chaining_config(mock_training_job, modules_session): ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_remote_debug_config(mock_training_job, modules_session): +def test_remote_debug_config(mock_training_job, mock_trainer_wait, modules_session): image_uri = DEFAULT_IMAGE role = DEFAULT_ROLE remote_debug_config = RemoteDebugConfig( @@ -717,9 +726,10 @@ def test_remote_debug_config(mock_training_job, modules_session): ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer._get_unique_name") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session): +def test_model_trainer_full_init(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session): def mock_upload_data(path, bucket, key_prefix): return f"s3://{bucket}/{key_prefix}" @@ -1249,9 +1259,10 @@ def test_hyperparameters_invalid(mock_exists, modules_session): ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer._get_unique_name") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_model_trainer_default_paths(mock_training_job, mock_unique_name, modules_session): +def test_model_trainer_default_paths(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session): def mock_upload_data(path, bucket, key_prefix): return f"s3://{bucket}/{key_prefix}" @@ -1287,8 +1298,9 @@ def mock_upload_data(path, bucket, key_prefix): assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard" +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_input_merge(mock_training_job, modules_session): +def test_input_merge(mock_training_job, mock_trainer_wait, modules_session): model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz") model_trainer = ModelTrainer( training_image=DEFAULT_IMAGE, @@ -1327,8 +1339,9 @@ def test_input_merge(mock_training_job, modules_session): ), ] +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_metric_definitions(mock_training_job, modules_session): +def test_metric_definitions(mock_training_job, mock_trainer_wait, modules_session): image_uri = DEFAULT_IMAGE role = DEFAULT_ROLE metric_definitions = [ @@ -1352,9 +1365,10 @@ def test_metric_definitions(mock_training_job, modules_session): ) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer._get_unique_name") @patch("sagemaker.core.resources.TrainingJob") -def test_nova_recipe(mock_training_job, mock_unique_name, modules_session): +def test_nova_recipe(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session): def mock_upload_data(path, bucket, key_prefix): if os.path.isfile(path): file_name = os.path.basename(path) @@ -1442,9 +1456,10 @@ def test_nova_recipe_with_distillation(modules_session): os.unlink(recipe.name) +@patch("sagemaker.train.model_trainer.trainer_wait") @patch("sagemaker.train.model_trainer._get_unique_name") @patch("sagemaker.train.model_trainer.TrainingJob") -def test_llmft_recipe(mock_training_job, mock_unique_name, modules_session): +def test_llmft_recipe(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session): def mock_upload_data(path, bucket, key_prefix): if os.path.isfile(path): file_name = os.path.basename(path)