From f7e8a6e482877f4271f9cc0d8436f3ba1c52eb4c Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 22 Apr 2026 11:52:44 -0700 Subject: [PATCH 1/2] fix: store session on resource instances so instance methods use correct credentials --- .../src/sagemaker/core/config_schema.py | 1729 ++++++++++++----- .../src/sagemaker/core/resources.py | 1532 +++++++++------ .../src/sagemaker/core/shapes/shapes.py | 172 +- .../src/sagemaker/core/tools/codegen.py | 18 +- .../sagemaker/core/tools/resources_codegen.py | 38 +- .../sagemaker/core/tools/shapes_codegen.py | 51 +- .../src/sagemaker/core/tools/templates.py | 23 +- .../tests/integ/test_session_wait_e2e.py | 708 +++++++ .../generated/test_session_propagation.py | 335 ++++ 9 files changed, 3522 insertions(+), 1084 deletions(-) create mode 100644 sagemaker-core/tests/integ/test_session_wait_e2e.py create mode 100644 sagemaker-core/tests/unit/generated/test_session_propagation.py diff --git a/sagemaker-core/src/sagemaker/core/config_schema.py b/sagemaker-core/src/sagemaker/core/config_schema.py index c87ba3d02b..72dc24930f 100644 --- a/sagemaker-core/src/sagemaker/core/config_schema.py +++ b/sagemaker-core/src/sagemaker/core/config_schema.py @@ -4,8 +4,10 @@ "properties": { "SchemaVersion": { "type": "string", - "enum": ["1.0"], - "description": "The schema version of the document.", + "enum": [ + "1.0" + ], + "description": "The schema version of the document." }, "SageMaker": { "type": "object", @@ -21,79 +23,119 @@ "properties": { "training_specification": { "additional_s3_data_source": { - "s3_data_type": {"type": "string"}, - "s3_uri": {"type": "string"}, - "manifest_s3_uri": {"type": "string"}, + "s3_data_type": { + "type": "string" + }, + "s3_uri": { + "type": "string" + }, + "manifest_s3_uri": { + "type": "string" + } } }, "validation_specification": { - "validation_role": {"type": "string"} - }, - }, + "validation_role": { + "type": "string" + } + } + } }, "AutoMLJob": { "type": "object", "properties": { "output_data_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, "auto_ml_job_config": { "security_config": { - "volume_kms_key_id": {"type": "string"}, + "volume_kms_key_id": { + "type": "string" + }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, + "items": { + "type": "string" + } + } + } }, "candidate_generation_config": { - "feature_specification_s3_uri": {"type": "string"} - }, - }, - }, + "feature_specification_s3_uri": { + "type": "string" + } + } + } + } }, "AutoMLJobV2": { "type": "object", "properties": { "output_data_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, "auto_ml_problem_type_config": { "time_series_forecasting_job_config": { - "feature_specification_s3_uri": {"type": "string"} + "feature_specification_s3_uri": { + "type": "string" + } }, "tabular_job_config": { - "feature_specification_s3_uri": {"type": "string"} - }, + "feature_specification_s3_uri": { + "type": "string" + } + } }, "security_config": { - "volume_kms_key_id": {"type": "string"}, + "volume_kms_key_id": { + "type": "string" + }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, + "items": { + "type": "string" + } + } + } }, "auto_ml_compute_config": { "emr_serverless_compute_config": { - "execution_role_arn": {"type": "string"} + "execution_role_arn": { + "type": "string" + } } - }, - }, + } + } }, "Cluster": { "type": "object", @@ -101,158 +143,254 @@ "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } }, - "cluster_role": {"type": "string"}, - }, + "cluster_role": { + "type": "string" + } + } }, "CompilationJob": { "type": "object", "properties": { "model_artifacts": { - "s3_model_artifacts": {"type": "string"} + "s3_model_artifacts": { + "type": "string" + } + }, + "role_arn": { + "type": "string" + }, + "input_config": { + "s3_uri": { + "type": "string" + } }, - "role_arn": {"type": "string"}, - "input_config": {"s3_uri": {"type": "string"}}, "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_location": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } }, "resource_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, - }, + "items": { + "type": "string" + } + } + } + } }, "CustomMonitoringJobDefinition": { "type": "object", "properties": { "custom_monitoring_job_input": { "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, - "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + "ground_truth_s3_input": { + "s3_uri": { + "type": "string" + } + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "custom_monitoring_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, + } + } }, "DataQualityJobDefinition": { "type": "object", "properties": { "data_quality_job_input": { "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "data_quality_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "data_quality_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}}, - "statistics_resource": {"s3_uri": {"type": "string"}}, + "constraints_resource": { + "s3_uri": { + "type": "string" + } + }, + "statistics_resource": { + "s3_uri": { + "type": "string" + } + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, + } + } }, "DeviceFleet": { "type": "object", "properties": { "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_location": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, - "iot_role_alias": {"type": "string"}, - }, + "iot_role_alias": { + "type": "string" + } + } }, "Domain": { "type": "object", "properties": { - "security_group_id_for_domain_boundary": {"type": "string"}, + "security_group_id_for_domain_boundary": { + "type": "string" + }, "default_user_settings": { - "execution_role": {"type": "string"}, + "execution_role": { + "type": "string" + }, "environment_settings": { - "default_s3_artifact_path": {"type": "string"}, - "default_s3_kms_key_id": {"type": "string"}, + "default_s3_artifact_path": { + "type": "string" + }, + "default_s3_kms_key_id": { + "type": "string" + } }, "security_groups": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "sharing_settings": { - "s3_output_path": {"type": "string"}, - "s3_kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "s3_kms_key_id": { + "type": "string" + } }, "canvas_app_settings": { "time_series_forecasting_settings": { - "amazon_forecast_role_arn": {"type": "string"} + "amazon_forecast_role_arn": { + "type": "string" + } }, "model_register_settings": { "cross_account_model_register_role_arn": { @@ -260,278 +398,450 @@ } }, "workspace_settings": { - "s3_artifact_path": {"type": "string"}, - "s3_kms_key_id": {"type": "string"}, + "s3_artifact_path": { + "type": "string" + }, + "s3_kms_key_id": { + "type": "string" + } }, "generative_ai_settings": { - "amazon_bedrock_role_arn": {"type": "string"} + "amazon_bedrock_role_arn": { + "type": "string" + } }, "emr_serverless_settings": { - "execution_role_arn": {"type": "string"} - }, + "execution_role_arn": { + "type": "string" + } + } }, "jupyter_lab_app_settings": { "emr_settings": { "assumable_role_arns": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "execution_role_arns": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } }, "emr_settings": { "assumable_role_arns": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "execution_role_arns": { "type": "array", - "items": {"type": "string"}, - }, - }, + "items": { + "type": "string" + } + } + } }, "domain_settings": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "r_studio_server_pro_domain_settings": { - "domain_execution_role_arn": {"type": "string"} + "domain_execution_role_arn": { + "type": "string" + } }, - "execution_role_identity_config": {"type": "string"}, - "unified_studio_settings": { - "project_s3_path": {"type": "string"} + "execution_role_identity_config": { + "type": "string" }, + "unified_studio_settings": { + "project_s3_path": { + "type": "string" + } + } + }, + "home_efs_file_system_kms_key_id": { + "type": "string" }, - "home_efs_file_system_kms_key_id": {"type": "string"}, "subnet_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } + }, + "kms_key_id": { + "type": "string" + }, + "app_security_group_management": { + "type": "string" }, - "kms_key_id": {"type": "string"}, - "app_security_group_management": {"type": "string"}, "default_space_settings": { - "execution_role": {"type": "string"}, + "execution_role": { + "type": "string" + }, "security_groups": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "jupyter_lab_app_settings": { "emr_settings": { "assumable_role_arns": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "execution_role_arns": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, - }, + } + } + } }, "EdgePackagingJob": { "type": "object", "properties": { - "role_arn": {"type": "string"}, - "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "role_arn": { + "type": "string" }, - }, + "output_config": { + "s3_output_location": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + } + } }, "Endpoint": { "type": "object", "properties": { "data_capture_config": { - "destination_s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "destination_s3_uri": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } }, "async_inference_config": { "output_config": { - "kms_key_id": {"type": "string"}, - "s3_output_path": {"type": "string"}, - "s3_failure_path": {"type": "string"}, + "kms_key_id": { + "type": "string" + }, + "s3_output_path": { + "type": "string" + }, + "s3_failure_path": { + "type": "string" + } } - }, - }, + } + } }, "EndpointConfig": { "type": "object", "properties": { "data_capture_config": { - "destination_s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "destination_s3_uri": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "kms_key_id": { + "type": "string" }, - "kms_key_id": {"type": "string"}, "async_inference_config": { "output_config": { - "kms_key_id": {"type": "string"}, - "s3_output_path": {"type": "string"}, - "s3_failure_path": {"type": "string"}, + "kms_key_id": { + "type": "string" + }, + "s3_output_path": { + "type": "string" + }, + "s3_failure_path": { + "type": "string" + } } }, - "execution_role_arn": {"type": "string"}, + "execution_role_arn": { + "type": "string" + }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, - }, + "items": { + "type": "string" + } + } + } + } }, "EvaluationJob": { "type": "object", "properties": { "output_data_config": { - "s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_uri": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, "upstream_platform_config": { "upstream_platform_customer_output_data_config": { - "s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, - "s3_kms_encryption_context": {"type": "string"}, + "s3_uri": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + }, + "s3_kms_encryption_context": { + "type": "string" + } }, "upstream_platform_customer_execution_role": { "type": "string" - }, - }, - }, + } + } + } }, "FeatureGroup": { "type": "object", "properties": { "online_store_config": { - "security_config": {"kms_key_id": {"type": "string"}} + "security_config": { + "kms_key_id": { + "type": "string" + } + } }, "offline_store_config": { "s3_storage_config": { - "s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, - "resolved_output_s3_uri": {"type": "string"}, + "s3_uri": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + }, + "resolved_output_s3_uri": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, - }, + "role_arn": { + "type": "string" + } + } }, "FlowDefinition": { "type": "object", "properties": { "output_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } }, - "role_arn": {"type": "string"}, - "task_rendering_role_arn": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, + "role_arn": { + "type": "string" + }, + "task_rendering_role_arn": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + } }, "GroundTruthJob": { "type": "object", "properties": { "input_config": { "data_source": { - "s3_data_source": {"s3_uri": {"type": "string"}} + "s3_data_source": { + "s3_uri": { + "type": "string" + } + } } }, - "output_config": {"s3_output_path": {"type": "string"}}, - }, + "output_config": { + "s3_output_path": { + "type": "string" + } + } + } }, "GroundTruthWorkflow": { "type": "object", - "properties": {"execution_role_arn": {"type": "string"}}, + "properties": { + "execution_role_arn": { + "type": "string" + } + } }, "Hub": { "type": "object", "properties": { - "s3_storage_config": {"s3_output_path": {"type": "string"}} - }, + "s3_storage_config": { + "s3_output_path": { + "type": "string" + } + } + } }, "HumanTaskUi": { "type": "object", - "properties": {"kms_key_id": {"type": "string"}}, + "properties": { + "kms_key_id": { + "type": "string" + } + } }, "HyperParameterTuningJob": { "type": "object", "properties": { "training_job_definition": { - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "output_data_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + }, "remove_job_name_from_s3_output_path": { "type": "boolean" - }, + } }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } }, "resource_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } }, "hyper_parameter_tuning_resource_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } }, - "checkpoint_config": {"s3_uri": {"type": "string"}}, + "checkpoint_config": { + "s3_uri": { + "type": "string" + } + } } - }, + } }, "Image": { "type": "object", - "properties": {"role_arn": {"type": "string"}}, + "properties": { + "role_arn": { + "type": "string" + } + } }, "InferenceExperiment": { "type": "object", "properties": { - "role_arn": {"type": "string"}, - "data_storage_config": {"kms_key": {"type": "string"}}, - "kms_key": {"type": "string"}, - }, + "role_arn": { + "type": "string" + }, + "data_storage_config": { + "kms_key": { + "type": "string" + } + }, + "kms_key": { + "type": "string" + } + } }, "InferenceRecommendationsJob": { "type": "object", "properties": { - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "input_config": { - "volume_kms_key_id": {"type": "string"}, + "volume_kms_key_id": { + "type": "string" + }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, + "items": { + "type": "string" + } + } + } }, "output_config": { - "kms_key_id": {"type": "string"}, + "kms_key_id": { + "type": "string" + }, "compiled_output_config": { - "s3_output_uri": {"type": "string"} + "s3_output_uri": { + "type": "string" + } }, "benchmark_results_output_config": { - "s3_output_uri": {"type": "string"} - }, - }, - }, + "s3_output_uri": { + "type": "string" + } + } + } + } }, "LabelingJob": { "type": "object", @@ -539,47 +849,79 @@ "input_config": { "data_source": { "s3_data_source": { - "manifest_s3_uri": {"type": "string"} + "manifest_s3_uri": { + "type": "string" + } } } }, "output_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, "human_task_config": { - "ui_config": {"ui_template_s3_uri": {"type": "string"}} + "ui_config": { + "ui_template_s3_uri": { + "type": "string" + } + } + }, + "task_rendering_role_arn": { + "type": "string" + }, + "label_category_config_s3_uri": { + "type": "string" }, - "task_rendering_role_arn": {"type": "string"}, - "label_category_config_s3_uri": {"type": "string"}, "labeling_job_algorithms_config": { "labeling_job_resource_config": { - "volume_kms_key_id": {"type": "string"}, + "volume_kms_key_id": { + "type": "string" + }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, + "items": { + "type": "string" + } + } + } } }, "labeling_job_output": { - "output_dataset_s3_uri": {"type": "string"} - }, - }, + "output_dataset_s3_uri": { + "type": "string" + } + } + } }, "MlflowApp": { "type": "object", - "properties": {"role_arn": {"type": "string"}}, + "properties": { + "role_arn": { + "type": "string" + } + } }, "MlflowTrackingServer": { "type": "object", - "properties": {"role_arn": {"type": "string"}}, + "properties": { + "role_arn": { + "type": "string" + } + } }, "Model": { "type": "object", @@ -587,234 +929,400 @@ "primary_container": { "model_data_source": { "s3_data_source": { - "s3_uri": {"type": "string"}, - "s3_data_type": {"type": "string"}, - "manifest_s3_uri": {"type": "string"}, + "s3_uri": { + "type": "string" + }, + "s3_data_type": { + "type": "string" + }, + "manifest_s3_uri": { + "type": "string" + } } } }, - "execution_role_arn": {"type": "string"}, + "execution_role_arn": { + "type": "string" + }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, - }, + "items": { + "type": "string" + } + } + } + } }, "ModelBiasJobDefinition": { "type": "object", "properties": { "model_bias_job_input": { - "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + "ground_truth_s3_input": { + "s3_uri": { + "type": "string" + } + }, "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "model_bias_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "model_bias_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}} + "constraints_resource": { + "s3_uri": { + "type": "string" + } + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, + } + } }, "ModelCard": { "type": "object", "properties": { - "security_config": {"kms_key_id": {"type": "string"}} - }, + "security_config": { + "kms_key_id": { + "type": "string" + } + } + } }, "ModelCardExportJob": { "type": "object", "properties": { - "output_config": {"s3_output_path": {"type": "string"}}, - "export_artifacts": { - "s3_export_artifacts": {"type": "string"} + "output_config": { + "s3_output_path": { + "type": "string" + } }, - }, + "export_artifacts": { + "s3_export_artifacts": { + "type": "string" + } + } + } }, "ModelExplainabilityJobDefinition": { "type": "object", "properties": { "model_explainability_job_input": { "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "model_explainability_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "model_explainability_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}} + "constraints_resource": { + "s3_uri": { + "type": "string" + } + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, + } + } }, "ModelPackage": { "type": "object", "properties": { "validation_specification": { - "validation_role": {"type": "string"} + "validation_role": { + "type": "string" + } }, "model_metrics": { "model_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, + "statistics": { + "s3_uri": { + "type": "string" + } + }, + "constraints": { + "s3_uri": { + "type": "string" + } + } }, "model_data_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, + "statistics": { + "s3_uri": { + "type": "string" + } + }, + "constraints": { + "s3_uri": { + "type": "string" + } + } }, "bias": { - "report": {"s3_uri": {"type": "string"}}, + "report": { + "s3_uri": { + "type": "string" + } + }, "pre_training_report": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } }, "post_training_report": { - "s3_uri": {"type": "string"} - }, + "s3_uri": { + "type": "string" + } + } }, "explainability": { - "report": {"s3_uri": {"type": "string"}} - }, + "report": { + "s3_uri": { + "type": "string" + } + } + } }, "deployment_specification": { "test_input": { "data_source": { "s3_data_source": { - "s3_data_type": {"type": "string"}, - "s3_uri": {"type": "string"}, - "s3_data_distribution_type": { + "s3_data_type": { + "type": "string" + }, + "s3_uri": { "type": "string" }, + "s3_data_distribution_type": { + "type": "string" + } } } } }, "drift_check_baselines": { "bias": { - "config_file": {"s3_uri": {"type": "string"}}, + "config_file": { + "s3_uri": { + "type": "string" + } + }, "pre_training_constraints": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } }, "post_training_constraints": { - "s3_uri": {"type": "string"} - }, + "s3_uri": { + "type": "string" + } + } }, "explainability": { - "constraints": {"s3_uri": {"type": "string"}}, - "config_file": {"s3_uri": {"type": "string"}}, + "constraints": { + "s3_uri": { + "type": "string" + } + }, + "config_file": { + "s3_uri": { + "type": "string" + } + } }, "model_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, + "statistics": { + "s3_uri": { + "type": "string" + } + }, + "constraints": { + "s3_uri": { + "type": "string" + } + } }, "model_data_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, - }, + "statistics": { + "s3_uri": { + "type": "string" + } + }, + "constraints": { + "s3_uri": { + "type": "string" + } + } + } }, - "security_config": {"kms_key_id": {"type": "string"}}, - }, + "security_config": { + "kms_key_id": { + "type": "string" + } + } + } }, "ModelQualityJobDefinition": { "type": "object", "properties": { "model_quality_job_input": { - "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + "ground_truth_s3_input": { + "s3_uri": { + "type": "string" + } + }, "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "model_quality_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "model_quality_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}} + "constraints_resource": { + "s3_uri": { + "type": "string" + } + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, + } + } }, "MonitoringSchedule": { "type": "object", @@ -822,362 +1330,555 @@ "monitoring_schedule_config": { "monitoring_job_definition": { "monitoring_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "monitoring_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "baseline_config": { "constraints_resource": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } }, "statistics_resource": { - "s3_uri": {"type": "string"} - }, + "s3_uri": { + "type": "string" + } + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, + } } }, "custom_monitoring_job_definition": { "custom_monitoring_job_input": { "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "ground_truth_s3_input": { - "s3_uri": {"type": "string"} - }, + "s3_uri": { + "type": "string" + } + } }, "custom_monitoring_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, + } }, "data_quality_job_definition": { "data_quality_job_input": { "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "data_quality_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "data_quality_baseline_config": { "constraints_resource": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } }, "statistics_resource": { - "s3_uri": {"type": "string"} - }, + "s3_uri": { + "type": "string" + } + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, + } }, "model_quality_job_definition": { "model_quality_job_input": { "ground_truth_s3_input": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } }, "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "model_quality_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "model_quality_baseline_config": { "constraints_resource": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, + } }, "model_bias_job_definition": { "model_bias_job_input": { "ground_truth_s3_input": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } }, "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "model_bias_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "model_bias_baseline_config": { "constraints_resource": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, + } }, "model_explainability_job_definition": { "model_explainability_job_input": { "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } }, "batch_transform_input": { "data_captured_destination_s3_uri": { "type": "string" }, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, + "s3_input_mode": { + "type": "string" + }, + "s3_data_distribution_type": { + "type": "string" + } + } }, "model_explainability_job_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "job_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, - "role_arn": {"type": "string"}, + "role_arn": { + "type": "string" + }, "model_explainability_baseline_config": { "constraints_resource": { - "s3_uri": {"type": "string"} + "s3_uri": { + "type": "string" + } } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } - }, - }, - }, + } + } + } }, "NotebookInstance": { "type": "object", "properties": { - "subnet_id": {"type": "string"}, + "subnet_id": { + "type": "string" + }, "security_groups": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, + "kms_key_id": { + "type": "string" + } + } }, "OptimizationJob": { "type": "object", "properties": { - "model_source": {"s3": {"s3_uri": {"type": "string"}}}, + "model_source": { + "s3": { + "s3_uri": { + "type": "string" + } + } + }, "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_location": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, - }, - }, + "items": { + "type": "string" + } + } + } + } }, "PartnerApp": { "type": "object", "properties": { - "execution_role_arn": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, + "execution_role_arn": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + } }, "Pipeline": { "type": "object", - "properties": {"role_arn": {"type": "string"}}, + "properties": { + "role_arn": { + "type": "string" + } + } }, "ProcessingJob": { "type": "object", "properties": { "processing_resources": { "cluster_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } } }, "processing_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "network_config": { "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } }, - "role_arn": {"type": "string"}, - }, + "role_arn": { + "type": "string" + } + } }, "QuotaAllocation": { "type": "object", "properties": { "quota_allocation_target": { - "roles": {"type": "array", "items": {"type": "string"}} + "roles": { + "type": "array", + "items": { + "type": "string" + } + } } - }, + } }, "TrainingJob": { "type": "object", "properties": { "model_artifacts": { - "s3_model_artifacts": {"type": "string"} + "s3_model_artifacts": { + "type": "string" + } }, "training_job_output": { - "s3_training_job_output": {"type": "string"} + "s3_training_job_output": { + "type": "string" + } + }, + "role_arn": { + "type": "string" }, - "role_arn": {"type": "string"}, "output_data_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + }, "remove_job_name_from_s3_output_path": { "type": "boolean" - }, + } }, "resource_config": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } + }, + "checkpoint_config": { + "s3_uri": { + "type": "string" + } + }, + "debug_hook_config": { + "s3_output_path": { + "type": "string" + } }, - "checkpoint_config": {"s3_uri": {"type": "string"}}, - "debug_hook_config": {"s3_output_path": {"type": "string"}}, "tensor_board_output_config": { - "s3_output_path": {"type": "string"} + "s3_output_path": { + "type": "string" + } }, "upstream_platform_config": { "credential_proxy_config": { @@ -1186,37 +1887,57 @@ }, "platform_credential_provider_kms_key_id": { "type": "string" - }, + } }, "vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } }, "output_data_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } + }, + "checkpoint_config": { + "s3_uri": { + "type": "string" + } }, - "checkpoint_config": {"s3_uri": {"type": "string"}}, "enable_s3_context_keys_on_input_data": { "type": "boolean" }, - "execution_role": {"type": "string"}, + "execution_role": { + "type": "string" + } + }, + "profiler_config": { + "s3_output_path": { + "type": "string" + } }, - "profiler_config": {"s3_output_path": {"type": "string"}}, "processing_job_config": { "processing_output_config": { - "kms_key_id": {"type": "string"} + "kms_key_id": { + "type": "string" + } }, "upstream_processing_output_config": { - "kms_key_id": {"type": "string"} - }, - }, - }, + "kms_key_id": { + "type": "string" + } + } + } + } }, "TransformJob": { "type": "object", @@ -1224,44 +1945,72 @@ "transform_input": { "data_source": { "s3_data_source": { - "s3_data_type": {"type": "string"}, - "s3_uri": {"type": "string"}, + "s3_data_type": { + "type": "string" + }, + "s3_uri": { + "type": "string" + } } } }, "transform_resources": { - "volume_kms_key_id": {"type": "string"} + "volume_kms_key_id": { + "type": "string" + } }, "transform_output": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } }, "data_capture_config": { - "destination_s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, - }, + "destination_s3_uri": { + "type": "string" + }, + "kms_key_id": { + "type": "string" + } + } + } }, "UserProfile": { "type": "object", "properties": { "user_settings": { - "execution_role": {"type": "string"}, + "execution_role": { + "type": "string" + }, "environment_settings": { - "default_s3_artifact_path": {"type": "string"}, - "default_s3_kms_key_id": {"type": "string"}, + "default_s3_artifact_path": { + "type": "string" + }, + "default_s3_kms_key_id": { + "type": "string" + } }, "security_groups": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "sharing_settings": { - "s3_output_path": {"type": "string"}, - "s3_kms_key_id": {"type": "string"}, + "s3_output_path": { + "type": "string" + }, + "s3_kms_key_id": { + "type": "string" + } }, "canvas_app_settings": { "time_series_forecasting_settings": { - "amazon_forecast_role_arn": {"type": "string"} + "amazon_forecast_role_arn": { + "type": "string" + } }, "model_register_settings": { "cross_account_model_register_role_arn": { @@ -1269,40 +2018,56 @@ } }, "workspace_settings": { - "s3_artifact_path": {"type": "string"}, - "s3_kms_key_id": {"type": "string"}, + "s3_artifact_path": { + "type": "string" + }, + "s3_kms_key_id": { + "type": "string" + } }, "generative_ai_settings": { - "amazon_bedrock_role_arn": {"type": "string"} + "amazon_bedrock_role_arn": { + "type": "string" + } }, "emr_serverless_settings": { - "execution_role_arn": {"type": "string"} - }, + "execution_role_arn": { + "type": "string" + } + } }, "jupyter_lab_app_settings": { "emr_settings": { "assumable_role_arns": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "execution_role_arns": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } }, "emr_settings": { "assumable_role_arns": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "execution_role_arns": { "type": "array", - "items": {"type": "string"}, - }, - }, + "items": { + "type": "string" + } + } + } } - }, + } }, "Workforce": { "type": "object", @@ -1311,24 +2076,34 @@ "workforce_vpc_config": { "security_group_ids": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string" + } }, "subnets": { "type": "array", - "items": {"type": "string"}, - }, + "items": { + "type": "string" + } + } } } - }, - }, - }, + } + } + } } }, - "required": ["Resources"], + "required": [ + "Resources" + ] } }, - "required": ["PythonSDK"], - }, + "required": [ + "PythonSDK" + ] + } }, - "required": ["SageMaker"], -} + "required": [ + "SageMaker" + ] +} \ No newline at end of file diff --git a/sagemaker-core/src/sagemaker/core/resources.py b/sagemaker-core/src/sagemaker/core/resources.py index 61e0f9c677..ab68a34686 100644 --- a/sagemaker-core/src/sagemaker/core/resources.py +++ b/sagemaker-core/src/sagemaker/core/resources.py @@ -16,7 +16,7 @@ import functools from pydantic import validate_call from typing import Dict, List, Literal, Optional, Union, Any -from boto3.session import Session +from boto3.session import Session as Boto3Session from rich.console import Group from rich.live import Live from rich.panel import Panel @@ -46,7 +46,6 @@ from sagemaker.core.serializers.base import BaseSerializer from sagemaker.core.deserializers.base import BaseDeserializer - logger = get_textual_rich_logger(__name__) @@ -208,7 +207,7 @@ def create( properties: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Action"]: """ @@ -281,7 +280,7 @@ def create( def get( cls, action_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Action"]: """ @@ -325,6 +324,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeActionResponse") action = cls(**transformed_response) + action._session = session return action @Base.add_validate_call @@ -357,7 +357,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_action(**operation_input_args) # deserialize response and update self @@ -396,7 +396,7 @@ def update( """ logger.info("Updating action resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ActionName": self.action_name, @@ -437,7 +437,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ActionName": self.action_name, @@ -460,7 +460,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Action"]: """ @@ -582,7 +582,7 @@ def create( properties: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["ActionInternal"]: """ @@ -704,6 +704,7 @@ def wrapper(*args, **kwargs): "additional_s3_data_source": { "s3_data_type": {"type": "string"}, "s3_uri": {"type": "string"}, + "manifest_s3_uri": {"type": "string"}, } }, "validation_specification": {"validation_role": {"type": "string"}}, @@ -731,7 +732,7 @@ def create( require_image_scan: Optional[bool] = Unassigned(), workflow_disabled: Optional[bool] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Algorithm"]: """ @@ -805,7 +806,7 @@ def create( def get( cls, algorithm_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Algorithm"]: """ @@ -848,6 +849,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeAlgorithmOutput") algorithm = cls(**transformed_response) + algorithm._session = session return algorithm @Base.add_validate_call @@ -879,7 +881,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_algorithm(**operation_input_args) # deserialize response and update self @@ -906,7 +908,7 @@ def delete( ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "AlgorithmName": self.algorithm_name, @@ -1045,7 +1047,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Algorithm"]: """ @@ -1176,7 +1178,7 @@ def create( persistent_volume_names: Optional[List[StrPipeVar]] = Unassigned(), app_launch_configuration: Optional[AppLaunchConfiguration] = Unassigned(), recovery_mode: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["App"]: """ @@ -1264,7 +1266,7 @@ def get( app_name: StrPipeVar, user_profile_name: Optional[StrPipeVar] = Unassigned(), space_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["App"]: """ @@ -1316,6 +1318,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeAppResponse") app = cls(**transformed_response) + app._session = session return app @Base.add_validate_call @@ -1352,7 +1355,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_app(**operation_input_args) # deserialize response and update self @@ -1387,7 +1390,7 @@ def update( """ logger.info("Updating app resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -1429,7 +1432,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -1576,7 +1579,7 @@ def get_all( domain_id_equals: Optional[StrPipeVar] = Unassigned(), user_profile_name_equals: Optional[StrPipeVar] = Unassigned(), space_name_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["App"]: """ @@ -1685,7 +1688,7 @@ def create( savitur_app_image_config: Optional[SaviturAppImageConfig] = Unassigned(), jupyter_lab_app_image_config: Optional[JupyterLabAppImageConfig] = Unassigned(), code_editor_app_image_config: Optional[CodeEditorAppImageConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AppImageConfig"]: """ @@ -1754,7 +1757,7 @@ def create( def get( cls, app_image_config_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AppImageConfig"]: """ @@ -1798,6 +1801,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeAppImageConfigResponse") app_image_config = cls(**transformed_response) + app_image_config._session = session return app_image_config @Base.add_validate_call @@ -1830,7 +1834,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_app_image_config(**operation_input_args) # deserialize response and update self @@ -1865,7 +1869,7 @@ def update( """ logger.info("Updating app_image_config resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "AppImageConfigName": self.app_image_config_name, @@ -1906,7 +1910,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "AppImageConfigName": self.app_image_config_name, @@ -1930,7 +1934,7 @@ def get_all( modified_time_after: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["AppImageConfig"]: """ @@ -2049,7 +2053,7 @@ def create( properties: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Artifact"]: """ @@ -2118,7 +2122,7 @@ def create( def get( cls, artifact_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Artifact"]: """ @@ -2162,6 +2166,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeArtifactResponse") artifact = cls(**transformed_response) + artifact._session = session return artifact @Base.add_validate_call @@ -2194,7 +2199,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_artifact(**operation_input_args) # deserialize response and update self @@ -2232,7 +2237,7 @@ def update( """ logger.info("Updating artifact resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ArtifactArn": self.artifact_arn, @@ -2272,7 +2277,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ArtifactArn": self.artifact_arn, @@ -2296,7 +2301,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Artifact"]: """ @@ -2412,7 +2417,7 @@ def create( properties: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["ArtifactInternal"]: """ @@ -2538,7 +2543,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "SourceArn": self.source_arn, @@ -2565,7 +2570,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Association"]: """ @@ -2638,7 +2643,7 @@ def add( source_arn: StrPipeVar, destination_arn: StrPipeVar, association_type: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -2683,6 +2688,13 @@ def add( logger.debug(f"Response: {response}") +class AssociationInternal(Base): + """ + Class representing resource AssociationInternal + + """ + + class AutoMLJob(Base): """ Class representing resource AutoMLJob @@ -2799,7 +2811,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), image_url_overrides: Optional[ImageUrlOverrides] = Unassigned(), model_deploy_config: Optional[ModelDeployConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AutoMLJob"]: """ @@ -2879,7 +2891,7 @@ def create( def get( cls, auto_ml_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AutoMLJob"]: """ @@ -2923,6 +2935,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeAutoMLJobResponse") auto_ml_job = cls(**transformed_response) + auto_ml_job._session = session return auto_ml_job @Base.add_validate_call @@ -2955,7 +2968,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_auto_ml_job(**operation_input_args) # deserialize response and update self @@ -2984,7 +2997,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "AutoMLJobName": self.auto_ml_job_name, @@ -3015,7 +3028,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "AutoMLJobName": self.auto_ml_job_name, @@ -3099,7 +3112,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["AutoMLJob"]: """ @@ -3169,7 +3182,7 @@ def get_all_candidates( candidate_name_equals: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[AutoMLCandidate]: """ @@ -3355,7 +3368,7 @@ def create( auto_ml_execution_mode: Optional[StrPipeVar] = Unassigned(), external_feature_transformers: Optional[AutoMLExternalFeatureTransformers] = Unassigned(), auto_ml_compute_config: Optional[AutoMLComputeConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AutoMLJobV2"]: """ @@ -3441,7 +3454,7 @@ def create( def get( cls, auto_ml_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AutoMLJobV2"]: """ @@ -3485,6 +3498,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeAutoMLJobV2Response") auto_ml_job_v2 = cls(**transformed_response) + auto_ml_job_v2._session = session return auto_ml_job_v2 @Base.add_validate_call @@ -3517,7 +3531,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_auto_ml_job_v2(**operation_input_args) # deserialize response and update self @@ -3636,7 +3650,7 @@ def create( auto_ml_job_name: StrPipeVar, auto_ml_task_context: AutoMLTaskContext, auto_ml_task_type: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AutoMLTask"]: """ @@ -3700,7 +3714,7 @@ def create( def get( cls, auto_ml_task_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["AutoMLTask"]: """ @@ -3744,6 +3758,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeAutoMLTaskResponse") auto_ml_task = cls(**transformed_response) + auto_ml_task._session = session return auto_ml_task @Base.add_validate_call @@ -3776,7 +3791,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_auto_ml_task(**operation_input_args) # deserialize response and update self @@ -3913,7 +3928,7 @@ def create( capacity_schedule_offering_id: StrPipeVar, target_services: Optional[List[StrPipeVar]] = Unassigned(), max_wait_time_in_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CapacitySchedule"]: """ @@ -3982,7 +3997,7 @@ def create( def get( cls, capacity_schedule_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CapacitySchedule"]: """ @@ -4026,6 +4041,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeCapacityScheduleResponse") capacity_schedule = cls(**transformed_response) + capacity_schedule._session = session return capacity_schedule @Base.add_validate_call @@ -4059,7 +4075,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_capacity_schedule(**operation_input_args) # deserialize response and update self @@ -4101,7 +4117,7 @@ def update( """ logger.info("Updating capacity_schedule resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "CapacityScheduleName": capacity_schedule_name, @@ -4140,7 +4156,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "CapacityScheduleName": self.capacity_schedule_name, @@ -4222,7 +4238,7 @@ def load( capacity_schedule_name: StrPipeVar, capacity_resource_arn: StrPipeVar, target_resources: List[StrPipeVar], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CapacitySchedule"]: """ @@ -4343,7 +4359,8 @@ def wrapper(*args, **kwargs): "vpc_config": { "security_group_ids": {"type": "array", "items": {"type": "string"}}, "subnets": {"type": "array", "items": {"type": "string"}}, - } + }, + "cluster_role": {"type": "string"}, } return create_func( *args, @@ -4375,7 +4392,7 @@ def create( cluster_role: Optional[StrPipeVar] = Unassigned(), auto_scaling: Optional[ClusterAutoScalingConfig] = Unassigned(), custom_metadata: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Cluster"]: """ @@ -4462,7 +4479,7 @@ def create( def get( cls, cluster_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Cluster"]: """ @@ -4506,6 +4523,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeClusterResponse") cluster = cls(**transformed_response) + cluster._session = session return cluster @Base.add_validate_call @@ -4538,7 +4556,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_cluster(**operation_input_args) # deserialize response and update self @@ -4590,7 +4608,7 @@ def update( """ logger.info("Updating cluster resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ClusterName": self.cluster_name, @@ -4641,7 +4659,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ClusterName": self.cluster_name, @@ -4790,7 +4808,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), training_plan_arn: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Cluster"]: """ @@ -4854,7 +4872,7 @@ def get_node( self, node_id: Optional[StrPipeVar] = Unassigned(), node_logical_id: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[ClusterNodeDetails]: """ @@ -4911,7 +4929,7 @@ def get_all_nodes( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), include_node_logical_ids: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[ClusterNodeDetails]: """ @@ -4977,7 +4995,7 @@ def update_software( deployment_config: Optional[DeploymentConfiguration] = Unassigned(), dry_run: Optional[bool] = Unassigned(), image_id: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -5030,7 +5048,7 @@ def batch_delete_nodes( node_ids: Optional[List[StrPipeVar]] = Unassigned(), node_logical_ids: Optional[List[StrPipeVar]] = Unassigned(), dry_run: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[BatchDeleteClusterNodesResponse]: """ @@ -5157,7 +5175,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), dry_run: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ClusterSchedulerConfig"]: """ @@ -5233,7 +5251,7 @@ def get( cls, cluster_scheduler_config_id: StrPipeVar, cluster_scheduler_config_version: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ClusterSchedulerConfig"]: """ @@ -5279,6 +5297,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeClusterSchedulerConfigResponse") cluster_scheduler_config = cls(**transformed_response) + cluster_scheduler_config._session = session return cluster_scheduler_config @Base.add_validate_call @@ -5312,7 +5331,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_cluster_scheduler_config(**operation_input_args) # deserialize response and update self @@ -5354,7 +5373,7 @@ def update( """ logger.info("Updating cluster_scheduler_config resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ClusterSchedulerConfigId": self.cluster_scheduler_config_id, @@ -5397,7 +5416,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ClusterSchedulerConfigId": self.cluster_scheduler_config_id, @@ -5564,7 +5583,7 @@ def get_all( status: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ClusterSchedulerConfig"]: """ @@ -5668,7 +5687,7 @@ def create( code_repository_name: StrPipeVar, git_config: GitConfig, tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CodeRepository"]: """ @@ -5730,7 +5749,7 @@ def create( def get( cls, code_repository_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CodeRepository"]: """ @@ -5773,6 +5792,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeCodeRepositoryOutput") code_repository = cls(**transformed_response) + code_repository._session = session return code_repository @Base.add_validate_call @@ -5804,7 +5824,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_code_repository(**operation_input_args) # deserialize response and update self @@ -5836,7 +5856,7 @@ def update( """ logger.info("Updating code_repository resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "CodeRepositoryName": self.code_repository_name, @@ -5873,7 +5893,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "CodeRepositoryName": self.code_repository_name, @@ -5897,7 +5917,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator["CodeRepository"]: """ @@ -6032,6 +6052,7 @@ def wrapper(*args, **kwargs): "s3_output_location": {"type": "string"}, "kms_key_id": {"type": "string"}, }, + "resource_config": {"volume_kms_key_id": {"type": "string"}}, "vpc_config": { "security_group_ids": {"type": "array", "items": {"type": "string"}}, "subnets": {"type": "array", "items": {"type": "string"}}, @@ -6060,7 +6081,7 @@ def create( resource_config: Optional[NeoResourceConfig] = Unassigned(), vpc_config: Optional[NeoVpcConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CompilationJob"]: """ @@ -6136,7 +6157,7 @@ def create( def get( cls, compilation_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CompilationJob"]: """ @@ -6180,6 +6201,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeCompilationJobResponse") compilation_job = cls(**transformed_response) + compilation_job._session = session return compilation_job @Base.add_validate_call @@ -6212,7 +6234,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_compilation_job(**operation_input_args) # deserialize response and update self @@ -6239,7 +6261,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "CompilationJobName": self.compilation_job_name, @@ -6270,7 +6292,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "CompilationJobName": self.compilation_job_name, @@ -6354,7 +6376,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["CompilationJob"]: """ @@ -6485,7 +6507,7 @@ def create( activation_state: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), dry_run: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ComputeQuota"]: """ @@ -6561,7 +6583,7 @@ def get( cls, compute_quota_id: StrPipeVar, compute_quota_version: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ComputeQuota"]: """ @@ -6607,6 +6629,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeComputeQuotaResponse") compute_quota = cls(**transformed_response) + compute_quota._session = session return compute_quota @Base.add_validate_call @@ -6640,7 +6663,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_compute_quota(**operation_input_args) # deserialize response and update self @@ -6684,7 +6707,7 @@ def update( """ logger.info("Updating compute_quota resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ComputeQuotaId": self.compute_quota_id, @@ -6729,7 +6752,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ComputeQuotaId": self.compute_quota_id, @@ -6892,7 +6915,7 @@ def get_all( cluster_arn: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ComputeQuota"]: """ @@ -7011,7 +7034,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), properties: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Context"]: """ @@ -7080,7 +7103,7 @@ def create( def get( cls, context_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Context"]: """ @@ -7124,6 +7147,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeContextResponse") context = cls(**transformed_response) + context._session = session return context @Base.add_validate_call @@ -7156,7 +7180,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_context(**operation_input_args) # deserialize response and update self @@ -7194,7 +7218,7 @@ def update( """ logger.info("Updating context resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ContextName": self.context_name, @@ -7234,7 +7258,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ContextName": self.context_name, @@ -7257,7 +7281,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Context"]: """ @@ -7373,7 +7397,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), properties: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["ContextInternal"]: """ @@ -7506,7 +7530,7 @@ def create( environment: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), source_arn: Optional[StrPipeVar] = Unassigned(), source_account: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["CrossAccountTrainingJob"]: """ @@ -7625,7 +7649,43 @@ def get_name(self) -> str: logger.error("Name attribute not found for object custom_monitoring_job_definition") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "custom_monitoring_job_input": { + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + }, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + "role_arn": {"type": "string"}, + "custom_monitoring_job_output_config": {"kms_key_id": {"type": "string"}}, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "CustomMonitoringJobDefinition", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -7638,7 +7698,7 @@ def create( network_config: Optional[MonitoringNetworkConfig] = Unassigned(), stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CustomMonitoringJobDefinition"]: """ @@ -7714,7 +7774,7 @@ def create( def get( cls, job_definition_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["CustomMonitoringJobDefinition"]: """ @@ -7758,6 +7818,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeCustomMonitoringJobDefinitionResponse") custom_monitoring_job_definition = cls(**transformed_response) + custom_monitoring_job_definition._session = session return custom_monitoring_job_definition @Base.add_validate_call @@ -7790,7 +7851,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_custom_monitoring_job_definition(**operation_input_args) # deserialize response and update self @@ -7817,7 +7878,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobDefinitionName": self.job_definition_name, @@ -7840,7 +7901,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["CustomMonitoringJobDefinition"]: """ @@ -8004,7 +8065,7 @@ def create( network_config: Optional[MonitoringNetworkConfig] = Unassigned(), stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["DataQualityJobDefinition"]: """ @@ -8082,7 +8143,7 @@ def create( def get( cls, job_definition_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["DataQualityJobDefinition"]: """ @@ -8126,6 +8187,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeDataQualityJobDefinitionResponse") data_quality_job_definition = cls(**transformed_response) + data_quality_job_definition._session = session return data_quality_job_definition @Base.add_validate_call @@ -8158,7 +8220,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_data_quality_job_definition(**operation_input_args) # deserialize response and update self @@ -8185,7 +8247,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobDefinitionName": self.job_definition_name, @@ -8208,7 +8270,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["DataQualityJobDefinition"]: """ @@ -8326,7 +8388,7 @@ def get( device_name: StrPipeVar, device_fleet_name: StrPipeVar, next_token: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Device"]: """ @@ -8374,6 +8436,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeDeviceResponse") device = cls(**transformed_response) + device._session = session return device @Base.add_validate_call @@ -8408,7 +8471,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_device(**operation_input_args) # deserialize response and update self @@ -8422,7 +8485,7 @@ def get_all( latest_heartbeat_after: Optional[datetime.datetime] = Unassigned(), model_name: Optional[StrPipeVar] = Unassigned(), device_fleet_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Device"]: """ @@ -8548,7 +8611,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), enable_iot_role_alias: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["DeviceFleet"]: """ @@ -8618,7 +8681,7 @@ def create( def get( cls, device_fleet_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["DeviceFleet"]: """ @@ -8662,6 +8725,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeDeviceFleetResponse") device_fleet = cls(**transformed_response) + device_fleet._session = session return device_fleet @Base.add_validate_call @@ -8694,7 +8758,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_device_fleet(**operation_input_args) # deserialize response and update self @@ -8733,7 +8797,7 @@ def update( """ logger.info("Updating device_fleet resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DeviceFleetName": self.device_fleet_name, @@ -8774,7 +8838,7 @@ def delete( ResourceInUse: Resource being accessed is in use. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DeviceFleetName": self.device_fleet_name, @@ -8798,7 +8862,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["DeviceFleet"]: """ @@ -8863,7 +8927,7 @@ def get_all( def deregister_devices( self, device_names: List[StrPipeVar], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -8905,7 +8969,7 @@ def deregister_devices( @Base.add_validate_call def get_report( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[GetDeviceFleetReportResponse]: """ @@ -8953,7 +9017,7 @@ def register_devices( self, devices: List[Device], tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -8999,7 +9063,7 @@ def register_devices( def update_devices( self, devices: List[Device], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -9121,6 +9185,10 @@ def wrapper(*args, **kwargs): "security_group_id_for_domain_boundary": {"type": "string"}, "default_user_settings": { "execution_role": {"type": "string"}, + "environment_settings": { + "default_s3_artifact_path": {"type": "string"}, + "default_s3_kms_key_id": {"type": "string"}, + }, "security_groups": {"type": "array", "items": {"type": "string"}}, "sharing_settings": { "s3_output_path": {"type": "string"}, @@ -9146,6 +9214,10 @@ def wrapper(*args, **kwargs): "execution_role_arns": {"type": "array", "items": {"type": "string"}}, } }, + "emr_settings": { + "assumable_role_arns": {"type": "array", "items": {"type": "string"}}, + "execution_role_arns": {"type": "array", "items": {"type": "string"}}, + }, }, "domain_settings": { "security_group_ids": {"type": "array", "items": {"type": "string"}}, @@ -9153,6 +9225,7 @@ def wrapper(*args, **kwargs): "domain_execution_role_arn": {"type": "string"} }, "execution_role_identity_config": {"type": "string"}, + "unified_studio_settings": {"project_s3_path": {"type": "string"}}, }, "home_efs_file_system_kms_key_id": {"type": "string"}, "subnet_ids": {"type": "array", "items": {"type": "string"}}, @@ -9198,7 +9271,7 @@ def create( app_storage_type: Optional[StrPipeVar] = Unassigned(), tag_propagation: Optional[StrPipeVar] = Unassigned(), default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Domain"]: """ @@ -9286,7 +9359,7 @@ def create( def get( cls, domain_id: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Domain"]: """ @@ -9330,6 +9403,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeDomainResponse") domain = cls(**transformed_response) + domain._session = session return domain @Base.add_validate_call @@ -9362,7 +9436,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_domain(**operation_input_args) # deserialize response and update self @@ -9407,7 +9481,7 @@ def update( """ logger.info("Updating domain resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -9454,7 +9528,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -9605,7 +9679,7 @@ def wait_for_delete( @Base.add_validate_call def get_all( cls, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Domain"]: """ @@ -9688,7 +9762,7 @@ def create( device_fleet_name: Union[StrPipeVar, object], stages: Optional[List[DeploymentStage]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EdgeDeploymentPlan"]: """ @@ -9759,7 +9833,7 @@ def get( edge_deployment_plan_name: StrPipeVar, next_token: Optional[StrPipeVar] = Unassigned(), max_results: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EdgeDeploymentPlan"]: """ @@ -9807,6 +9881,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeEdgeDeploymentPlanResponse") edge_deployment_plan = cls(**transformed_response) + edge_deployment_plan._session = session return edge_deployment_plan @Base.add_validate_call @@ -9842,7 +9917,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_edge_deployment_plan(**operation_input_args) # deserialize response and update self @@ -9869,7 +9944,7 @@ def delete( ResourceInUse: Resource being accessed is in use. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EdgeDeploymentPlanName": self.edge_deployment_plan_name, @@ -9894,7 +9969,7 @@ def get_all( device_fleet_name_contains: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["EdgeDeploymentPlan"]: """ @@ -9960,7 +10035,7 @@ def get_all( @Base.add_validate_call def create_stage( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -10003,7 +10078,7 @@ def create_stage( def delete_stage( self, stage_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -10047,7 +10122,7 @@ def delete_stage( def start_stage( self, stage_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -10090,7 +10165,7 @@ def start_stage( def stop_stage( self, stage_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -10134,7 +10209,7 @@ def get_all_stage_devices( self, stage_name: StrPipeVar, exclude_devices_deployed_in_other_stage: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[DeviceDeploymentSummary]: """ @@ -10272,7 +10347,7 @@ def create( output_config: EdgeOutputConfig, resource_key: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EdgePackagingJob"]: """ @@ -10347,7 +10422,7 @@ def create( def get( cls, edge_packaging_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EdgePackagingJob"]: """ @@ -10391,6 +10466,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeEdgePackagingJobResponse") edge_packaging_job = cls(**transformed_response) + edge_packaging_job._session = session return edge_packaging_job @Base.add_validate_call @@ -10423,7 +10499,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_edge_packaging_job(**operation_input_args) # deserialize response and update self @@ -10447,7 +10523,7 @@ def stop(self) -> None: ``` """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EdgePackagingJobName": self.edge_packaging_job_name, @@ -10534,7 +10610,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["EdgePackagingJob"]: """ @@ -10697,7 +10773,7 @@ def create( deletion_condition: Optional[EndpointDeletionCondition] = Unassigned(), deployment_config: Optional[DeploymentConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Endpoint"]: """ @@ -10766,7 +10842,7 @@ def create( def get( cls, endpoint_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Endpoint"]: """ @@ -10809,6 +10885,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeEndpointOutput") endpoint = cls(**transformed_response) + endpoint._session = session return endpoint @Base.add_validate_call @@ -10840,7 +10917,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_endpoint(**operation_input_args) # deserialize response and update self @@ -10882,7 +10959,7 @@ def update( """ logger.info("Updating endpoint resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EndpointName": self.endpoint_name, @@ -10924,7 +11001,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EndpointName": self.endpoint_name, @@ -10954,7 +11031,6 @@ def wait_for_status( ], poll: int = 5, timeout: Optional[int] = None, - logs: Optional[bool] = False, ) -> None: """ Wait for a Endpoint resource to reach certain status. @@ -10963,7 +11039,6 @@ def wait_for_status( target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. - logs: Whether to print logs while waiting. Raises: TimeoutExceededError: If the resource does not reach a terminal state before the timeout. @@ -10980,21 +11055,6 @@ def wait_for_status( progress.add_task(f"Waiting for Endpoint to reach [bold]{target_status} status...") status = Status("Current status:") - if logs: - instance_count = ( - sum(variant.current_instance_count for variant in self.production_variants) - if self.production_variants and not isinstance(self.production_variants, Unassigned) - else 1 - ) - log_group_name = f"/aws/sagemaker/Endpoints/{self.get_name()}" - logger.info(f"log_group_name") - logger.info(log_group_name) - multi_stream_logger = MultiLogStreamHandler( - log_group_name=f"/aws/sagemaker/Endpoints/{self.get_name()}", - log_stream_name_prefix=self.get_name(), - expected_stream_count=instance_count, - ) - with Live( Panel( Group(progress, status), @@ -11008,11 +11068,6 @@ def wait_for_status( current_status = self.endpoint_status status.update(f"Current status: [bold]{current_status}") - if logs and multi_stream_logger.ready(): - stream_log_events = multi_stream_logger.get_latest_log_events() - for stream_id, event in stream_log_events: - logger.info(f"{stream_id}:\n{event['message']}") - if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return @@ -11099,7 +11154,7 @@ def get_all( last_modified_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), status_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Endpoint"]: """ @@ -11166,7 +11221,7 @@ def get_all( def update_weights_and_capacities( self, desired_weights_and_capacities: List[DesiredWeightAndCapacity], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -11220,7 +11275,7 @@ def invoke( enable_explanations: Optional[StrPipeVar] = Unassigned(), inference_component_name: Optional[StrPipeVar] = Unassigned(), session_id: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[InvokeEndpointOutput]: """ @@ -11319,7 +11374,7 @@ def invoke_async( inference_id: Optional[StrPipeVar] = Unassigned(), request_ttl_seconds: Optional[int] = Unassigned(), invocation_timeout_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[InvokeEndpointAsyncOutput]: """ @@ -11391,7 +11446,7 @@ def invoke_with_response_stream( inference_id: Optional[StrPipeVar] = Unassigned(), inference_component_name: Optional[StrPipeVar] = Unassigned(), session_id: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[InvokeEndpointWithResponseStreamOutput]: """ @@ -11558,7 +11613,7 @@ def create( vpc_config: Optional[VpcConfig] = Unassigned(), enable_network_isolation: Optional[bool] = Unassigned(), metrics_config: Optional[MetricsConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EndpointConfig"]: """ @@ -11639,7 +11694,7 @@ def create( def get( cls, endpoint_config_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EndpointConfig"]: """ @@ -11682,6 +11737,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeEndpointConfigOutput") endpoint_config = cls(**transformed_response) + endpoint_config._session = session return endpoint_config @Base.add_validate_call @@ -11713,7 +11769,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_endpoint_config(**operation_input_args) # deserialize response and update self @@ -11739,7 +11795,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EndpointConfigName": self.endpoint_config_name, @@ -11761,7 +11817,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["EndpointConfig"]: """ @@ -11818,7 +11874,7 @@ def get_all( list_method_kwargs=operation_input_args, ) -''' + class EndpointConfigInternal(Base): """ Class representing resource EndpointConfigInternal @@ -11859,7 +11915,7 @@ def create( endpoint_config_input: CreateEndpointConfigInput, account_id: StrPipeVar, auto_ml_job_arn: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["EndpointConfigInternal"]: """ @@ -11929,7 +11985,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EndpointConfigInput": self.endpoint_config_input, @@ -11994,7 +12050,7 @@ def create( fas_credentials: Optional[StrPipeVar] = Unassigned(), encrypted_fas_credentials: Optional[StrPipeVar] = Unassigned(), billing_mode: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["EndpointInternal"]: """ @@ -12070,7 +12126,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EndpointInput": self.endpoint_input, @@ -12141,7 +12197,35 @@ def get_name(self) -> str: logger.error("Name attribute not found for object evaluation_job") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "output_data_config": { + "s3_uri": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "role_arn": {"type": "string"}, + "upstream_platform_config": { + "upstream_platform_customer_output_data_config": { + "s3_uri": {"type": "string"}, + "kms_key_id": {"type": "string"}, + "s3_kms_encryption_context": {"type": "string"}, + }, + "upstream_platform_customer_execution_role": {"type": "string"}, + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "EvaluationJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -12155,7 +12239,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), model_config: Optional[EvaluationJobModelConfig] = Unassigned(), upstream_platform_config: Optional[EvaluationJobUpstreamPlatformConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EvaluationJob"]: """ @@ -12233,7 +12317,7 @@ def create( def get( cls, evaluation_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["EvaluationJob"]: """ @@ -12277,6 +12361,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeEvaluationJobResponse") evaluation_job = cls(**transformed_response) + evaluation_job._session = session return evaluation_job @Base.add_validate_call @@ -12309,7 +12394,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_evaluation_job(**operation_input_args) # deserialize response and update self @@ -12337,7 +12422,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EvaluationJobName": self.evaluation_job_name, @@ -12368,7 +12453,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "EvaluationJobName": self.evaluation_job_name, @@ -12450,7 +12535,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), status_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["EvaluationJob"]: """ @@ -12508,7 +12593,7 @@ def get_all( resource_cls=EvaluationJob, list_method_kwargs=operation_input_args, ) -''' + class Experiment(Base): """ @@ -12561,7 +12646,7 @@ def create( display_name: Optional[StrPipeVar] = Unassigned(), description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Experiment"]: """ @@ -12626,7 +12711,7 @@ def create( def get( cls, experiment_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Experiment"]: """ @@ -12670,6 +12755,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeExperimentResponse") experiment = cls(**transformed_response) + experiment._session = session return experiment @Base.add_validate_call @@ -12702,7 +12788,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_experiment(**operation_input_args) # deserialize response and update self @@ -12736,7 +12822,7 @@ def update( """ logger.info("Updating experiment resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ExperimentName": self.experiment_name, @@ -12775,7 +12861,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ExperimentName": self.experiment_name, @@ -12796,7 +12882,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Experiment"]: """ @@ -12904,7 +12990,7 @@ def create( source: Optional[InputExperimentSource] = Unassigned(), creation_time: Optional[datetime.datetime] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["ExperimentInternal"]: """ @@ -13075,7 +13161,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), use_pre_prod_offline_store_replicator_lambda: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["FeatureGroup"]: """ @@ -13156,7 +13242,7 @@ def get( cls, feature_group_name: StrPipeVar, next_token: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["FeatureGroup"]: """ @@ -13202,6 +13288,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeFeatureGroupResponse") feature_group = cls(**transformed_response) + feature_group._session = session return feature_group @Base.add_validate_call @@ -13235,7 +13322,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_feature_group(**operation_input_args) # deserialize response and update self @@ -13277,7 +13364,7 @@ def update( """ logger.info("Updating feature_group resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "FeatureGroupName": self.feature_group_name, @@ -13319,7 +13406,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "FeatureGroupName": self.feature_group_name, @@ -13464,7 +13551,7 @@ def get_all( creation_time_before: Optional[datetime.datetime] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["FeatureGroup"]: """ @@ -13531,7 +13618,7 @@ def get_record( record_identifier_value_as_string: StrPipeVar, feature_names: Optional[List[StrPipeVar]] = Unassigned(), expiration_time_response: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[GetRecordResponse]: """ @@ -13591,7 +13678,7 @@ def put_record( record: List[FeatureValue], target_stores: Optional[List[StrPipeVar]] = Unassigned(), ttl_duration: Optional[TtlDuration] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -13645,7 +13732,7 @@ def delete_record( event_time: StrPipeVar, target_stores: Optional[List[StrPipeVar]] = Unassigned(), deletion_mode: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -13699,7 +13786,7 @@ def batch_get_record( self, identifiers: List[BatchGetRecordIdentifier], expiration_time_response: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[BatchGetRecordResponse]: """ @@ -13835,7 +13922,7 @@ def create( storage_account_stage_test_override: Optional[StrPipeVar] = Unassigned(), online_store_metadata: Optional[OnlineStoreMetadata] = Unassigned(), online_store_replica_metadata: Optional[OnlineStoreReplicaMetadata] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["FeatureGroupInternal"]: """ @@ -13968,7 +14055,7 @@ def get( cls, feature_group_name: StrPipeVar, feature_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["FeatureMetadata"]: """ @@ -14014,6 +14101,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeFeatureMetadataResponse") feature_metadata = cls(**transformed_response) + feature_metadata._session = session return feature_metadata @Base.add_validate_call @@ -14047,7 +14135,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_feature_metadata(**operation_input_args) # deserialize response and update self @@ -14085,7 +14173,7 @@ def update( """ logger.info("Updating feature_metadata resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "FeatureGroupName": self.feature_group_name, @@ -14167,6 +14255,8 @@ def wrapper(*args, **kwargs): "kms_key_id": {"type": "string"}, }, "role_arn": {"type": "string"}, + "task_rendering_role_arn": {"type": "string"}, + "kms_key_id": {"type": "string"}, } return create_func( *args, @@ -14192,7 +14282,7 @@ def create( task_rendering_role_arn: Optional[StrPipeVar] = Unassigned(), kms_key_id: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["FlowDefinition"]: """ @@ -14270,7 +14360,7 @@ def create( def get( cls, flow_definition_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["FlowDefinition"]: """ @@ -14314,6 +14404,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeFlowDefinitionResponse") flow_definition = cls(**transformed_response) + flow_definition._session = session return flow_definition @Base.add_validate_call @@ -14346,7 +14437,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_flow_definition(**operation_input_args) # deserialize response and update self @@ -14374,7 +14465,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "FlowDefinitionName": self.flow_definition_name, @@ -14515,7 +14606,7 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["FlowDefinition"]: """ @@ -14614,7 +14705,24 @@ def get_name(self) -> str: logger.error("Name attribute not found for object ground_truth_job") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "input_config": {"data_source": {"s3_data_source": {"s3_uri": {"type": "string"}}}}, + "output_config": {"s3_output_path": {"type": "string"}}, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "GroundTruthJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -14624,7 +14732,7 @@ def create( input_config: GroundTruthJobInputConfig, output_config: GroundTruthJobOutputConfig, ground_truth_job_description: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["GroundTruthJob"]: """ @@ -14702,7 +14810,7 @@ def get( ground_truth_project_name: StrPipeVar, ground_truth_workflow_name: StrPipeVar, ground_truth_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["GroundTruthJob"]: """ @@ -14750,6 +14858,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeGroundTruthJobResponse") ground_truth_job = cls(**transformed_response) + ground_truth_job._session = session return ground_truth_job @Base.add_validate_call @@ -14786,7 +14895,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_ground_truth_job(**operation_input_args) # deserialize response and update self @@ -14897,7 +15006,7 @@ def create( ground_truth_project_name: StrPipeVar, ground_truth_project_description: Optional[StrPipeVar] = Unassigned(), point_of_contact: Optional[GroundTruthProjectPointOfContact] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["GroundTruthProject"]: """ @@ -14962,7 +15071,7 @@ def create( def get( cls, ground_truth_project_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["GroundTruthProject"]: """ @@ -15006,6 +15115,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeGroundTruthProjectResponse") ground_truth_project = cls(**transformed_response) + ground_truth_project._session = session return ground_truth_project @Base.add_validate_call @@ -15038,7 +15148,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_ground_truth_project(**operation_input_args) # deserialize response and update self @@ -15104,7 +15214,7 @@ def wait_for_status( @Base.add_validate_call def get_all( cls, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["GroundTruthProject"]: """ @@ -15168,7 +15278,21 @@ def get_name(self) -> str: logger.error("Name attribute not found for object ground_truth_workflow") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = {"execution_role_arn": {"type": "string"}} + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "GroundTruthWorkflow", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -15176,7 +15300,7 @@ def create( ground_truth_workflow_name: StrPipeVar, ground_truth_workflow_definition_spec: StrPipeVar, execution_role_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["GroundTruthWorkflow"]: """ @@ -15248,7 +15372,7 @@ def get( cls, ground_truth_project_name: StrPipeVar, ground_truth_workflow_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["GroundTruthWorkflow"]: """ @@ -15294,6 +15418,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeGroundTruthWorkflowResponse") ground_truth_workflow = cls(**transformed_response) + ground_truth_workflow._session = session return ground_truth_workflow @Base.add_validate_call @@ -15328,7 +15453,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_ground_truth_workflow(**operation_input_args) # deserialize response and update self @@ -15407,7 +15532,7 @@ def create( hub_search_keywords: Optional[List[StrPipeVar]] = Unassigned(), s3_storage_config: Optional[HubS3StorageConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Hub"]: """ @@ -15477,7 +15602,7 @@ def create( def get( cls, hub_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Hub"]: """ @@ -15521,6 +15646,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeHubResponse") hub = cls(**transformed_response) + hub._session = session return hub @Base.add_validate_call @@ -15553,7 +15679,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_hub(**operation_input_args) # deserialize response and update self @@ -15588,7 +15714,7 @@ def update( """ logger.info("Updating hub resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HubName": self.hub_name, @@ -15629,7 +15755,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HubName": self.hub_name, @@ -15778,7 +15904,7 @@ def get_all( last_modified_time_after: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Hub"]: """ @@ -15914,7 +16040,7 @@ def get( hub_content_type: StrPipeVar, hub_content_name: StrPipeVar, hub_content_version: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["HubContent"]: """ @@ -15964,6 +16090,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeHubContentResponse") hub_content = cls(**transformed_response) + hub_content._session = session return hub_content @Base.add_validate_call @@ -15999,7 +16126,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_hub_content(**operation_input_args) # deserialize response and update self @@ -16038,7 +16165,7 @@ def update( """ logger.info("Updating hub_content resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HubName": self.hub_name, @@ -16084,7 +16211,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HubName": self.hub_name, @@ -16167,7 +16294,7 @@ def load( support_status: Optional[StrPipeVar] = Unassigned(), hub_content_search_keywords: Optional[List[StrPipeVar]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["HubContent"]: """ @@ -16253,7 +16380,7 @@ def get_all_versions( creation_time_after: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator["HubContent"]: """ @@ -16368,7 +16495,7 @@ def create( access_config: Optional[PresignedUrlAccessConfig] = Unassigned(), max_results: Optional[int] = Unassigned(), next_token: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["HubContentPresignedUrls"]: """ @@ -16476,7 +16603,7 @@ def create( hub_content_name: Optional[Union[StrPipeVar, object]] = Unassigned(), min_version: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["HubContentReference"]: """ @@ -16564,7 +16691,7 @@ def update( """ logger.info("Updating hub_content_reference resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HubName": self.hub_name, @@ -16605,7 +16732,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HubName": self.hub_name, @@ -16658,7 +16785,21 @@ def get_name(self) -> str: logger.error("Name attribute not found for object human_task_ui") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = {"kms_key_id": {"type": "string"}} + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "HumanTaskUi", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -16666,7 +16807,7 @@ def create( ui_template: UiTemplate, kms_key_id: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["HumanTaskUi"]: """ @@ -16732,7 +16873,7 @@ def create( def get( cls, human_task_ui_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["HumanTaskUi"]: """ @@ -16776,6 +16917,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeHumanTaskUiResponse") human_task_ui = cls(**transformed_response) + human_task_ui._session = session return human_task_ui @Base.add_validate_call @@ -16808,13 +16950,14 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_human_task_ui(**operation_input_args) # deserialize response and update self transform(response, "DescribeHumanTaskUiResponse", self) return self + @populate_inputs_decorator @Base.add_validate_call def update( self, @@ -16841,7 +16984,7 @@ def update( """ logger.info("Updating human_task_ui resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HumanTaskUiName": self.human_task_ui_name, @@ -16879,7 +17022,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HumanTaskUiName": self.human_task_ui_name, @@ -17013,7 +17156,7 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["HumanTaskUi"]: """ @@ -17139,6 +17282,7 @@ def wrapper(*args, **kwargs): "output_data_config": { "s3_output_path": {"type": "string"}, "kms_key_id": {"type": "string"}, + "remove_job_name_from_s3_output_path": {"type": "boolean"}, }, "vpc_config": { "security_group_ids": {"type": "array", "items": {"type": "string"}}, @@ -17174,7 +17318,7 @@ def create( warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), autotune: Optional[Autotune] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["HyperParameterTuningJob"]: """ @@ -17250,7 +17394,7 @@ def create( def get( cls, hyper_parameter_tuning_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["HyperParameterTuningJob"]: """ @@ -17294,6 +17438,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeHyperParameterTuningJobResponse") hyper_parameter_tuning_job = cls(**transformed_response) + hyper_parameter_tuning_job._session = session return hyper_parameter_tuning_job @Base.add_validate_call @@ -17326,7 +17471,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_hyper_parameter_tuning_job(**operation_input_args) # deserialize response and update self @@ -17352,7 +17497,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, @@ -17383,7 +17528,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, @@ -17532,7 +17677,7 @@ def get_all( last_modified_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), status_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["HyperParameterTuningJob"]: """ @@ -17601,7 +17746,7 @@ def get_all_training_jobs( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[HyperParameterTrainingJobSummary]: """ @@ -17728,7 +17873,7 @@ def create( billing_mode: Optional[StrPipeVar] = Unassigned(), source_identity: Optional[StrPipeVar] = Unassigned(), identity_center_user_token: Optional[IdentityCenterUserToken] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["HyperParameterTuningJobInternal"]: """ @@ -17819,7 +17964,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, @@ -17900,7 +18045,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), display_name: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Image"]: """ @@ -17968,7 +18113,7 @@ def create( def get( cls, image_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Image"]: """ @@ -18012,6 +18157,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeImageResponse") image = cls(**transformed_response) + image._session = session return image @Base.add_validate_call @@ -18044,7 +18190,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_image(**operation_input_args) # deserialize response and update self @@ -18084,7 +18230,7 @@ def update( """ logger.info("Updating image resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DeleteProperties": delete_properties, @@ -18126,7 +18272,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ImageName": self.image_name, @@ -18283,7 +18429,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Image"]: """ @@ -18349,7 +18495,7 @@ def get_all_aliases( self, alias: Optional[StrPipeVar] = Unassigned(), version: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[str]: """ @@ -18480,7 +18626,7 @@ def create( horovod: Optional[bool] = Unassigned(), override_alias_image_version: Optional[bool] = Unassigned(), release_notes: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ImageVersion"]: """ @@ -18565,7 +18711,7 @@ def get( image_name: StrPipeVar, version: Optional[int] = Unassigned(), alias: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ImageVersion"]: """ @@ -18613,6 +18759,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeImageVersionResponse") image_version = cls(**transformed_response) + image_version._session = session return image_version @Base.add_validate_call @@ -18648,7 +18795,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_image_version(**operation_input_args) # deserialize response and update self @@ -18696,7 +18843,7 @@ def update( """ logger.info("Updating image_version resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ImageName": self.image_name, @@ -18746,7 +18893,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ImageName": self.image_name, @@ -18950,7 +19097,7 @@ def create( variant_name: Optional[StrPipeVar] = Unassigned(), runtime_config: Optional[InferenceComponentRuntimeConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["InferenceComponent"]: """ @@ -19021,7 +19168,7 @@ def create( def get( cls, inference_component_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["InferenceComponent"]: """ @@ -19064,6 +19211,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeInferenceComponentOutput") inference_component = cls(**transformed_response) + inference_component._session = session return inference_component @Base.add_validate_call @@ -19095,7 +19243,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_inference_component(**operation_input_args) # deserialize response and update self @@ -19132,7 +19280,7 @@ def update( """ logger.info("Updating inference_component resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "InferenceComponentName": self.inference_component_name, @@ -19171,7 +19319,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "InferenceComponentName": self.inference_component_name, @@ -19323,7 +19471,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), endpoint_name_equals: Optional[StrPipeVar] = Unassigned(), variant_name_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["InferenceComponent"]: """ @@ -19394,7 +19542,7 @@ def get_all( def update_runtime_configs( self, desired_runtime_config: InferenceComponentRuntimeConfig, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -19525,7 +19673,7 @@ def create( data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned(), kms_key: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["InferenceExperiment"]: """ @@ -19605,7 +19753,7 @@ def create( def get( cls, name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["InferenceExperiment"]: """ @@ -19649,6 +19797,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeInferenceExperimentResponse") inference_experiment = cls(**transformed_response) + inference_experiment._session = session return inference_experiment @Base.add_validate_call @@ -19681,7 +19830,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_inference_experiment(**operation_input_args) # deserialize response and update self @@ -19719,7 +19868,7 @@ def update( """ logger.info("Updating inference_experiment resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Name": self.name, @@ -19762,7 +19911,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Name": self.name, @@ -19778,7 +19927,7 @@ def delete( @Base.add_validate_call def start( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -19836,7 +19985,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Name": self.name, @@ -19930,7 +20079,7 @@ def get_all( last_modified_time_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["InferenceExperiment"]: """ @@ -20067,6 +20216,11 @@ def wrapper(*args, **kwargs): "subnets": {"type": "array", "items": {"type": "string"}}, }, }, + "output_config": { + "kms_key_id": {"type": "string"}, + "compiled_output_config": {"s3_output_uri": {"type": "string"}}, + "benchmark_results_output_config": {"s3_output_uri": {"type": "string"}}, + }, } return create_func( *args, @@ -20093,7 +20247,7 @@ def create( ] = Unassigned(), output_config: Optional[RecommendationJobOutputConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["InferenceRecommendationsJob"]: """ @@ -20169,7 +20323,7 @@ def create( def get( cls, job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["InferenceRecommendationsJob"]: """ @@ -20213,6 +20367,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeInferenceRecommendationsJobResponse") inference_recommendations_job = cls(**transformed_response) + inference_recommendations_job._session = session return inference_recommendations_job @Base.add_validate_call @@ -20245,7 +20400,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_inference_recommendations_job(**operation_input_args) # deserialize response and update self @@ -20272,7 +20427,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobName": self.job_name, @@ -20303,7 +20458,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobName": self.job_name, @@ -20458,7 +20613,7 @@ def get_all( sort_order: Optional[StrPipeVar] = Unassigned(), model_name_equals: Optional[StrPipeVar] = Unassigned(), model_package_version_arn_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["InferenceRecommendationsJob"]: """ @@ -20529,7 +20684,7 @@ def get_all( def get_all_steps( self, step_type: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[InferenceRecommendationsJobStep]: """ @@ -20657,6 +20812,7 @@ def wrapper(*args, **kwargs): }, "role_arn": {"type": "string"}, "human_task_config": {"ui_config": {"ui_template_s3_uri": {"type": "string"}}}, + "task_rendering_role_arn": {"type": "string"}, "label_category_config_s3_uri": {"type": "string"}, "labeling_job_algorithms_config": { "labeling_job_resource_config": { @@ -20694,7 +20850,7 @@ def create( stopping_conditions: Optional[LabelingJobStoppingConditions] = Unassigned(), labeling_job_algorithms_config: Optional[LabelingJobAlgorithmsConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["LabelingJob"]: """ @@ -20774,7 +20930,7 @@ def create( def get( cls, labeling_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["LabelingJob"]: """ @@ -20818,6 +20974,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeLabelingJobResponse") labeling_job = cls(**transformed_response) + labeling_job._session = session return labeling_job @Base.add_validate_call @@ -20850,7 +21007,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_labeling_job(**operation_input_args) # deserialize response and update self @@ -20878,7 +21035,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "LabelingJobName": self.labeling_job_name, @@ -20910,7 +21067,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "LabelingJobName": self.labeling_job_name, @@ -20994,7 +21151,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), status_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["LabelingJob"]: """ @@ -21107,7 +21264,7 @@ def create( display_name: Optional[StrPipeVar] = Unassigned(), description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["LineageGroup"]: """ @@ -21172,7 +21329,7 @@ def create( def get( cls, lineage_group_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["LineageGroup"]: """ @@ -21216,6 +21373,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeLineageGroupResponse") lineage_group = cls(**transformed_response) + lineage_group._session = session return lineage_group @Base.add_validate_call @@ -21248,7 +21406,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_lineage_group(**operation_input_args) # deserialize response and update self @@ -21275,7 +21433,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "LineageGroupName": self.lineage_group_name, @@ -21296,7 +21454,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["LineageGroup"]: """ @@ -21354,7 +21512,7 @@ def get_all( @Base.add_validate_call def get_policy( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[GetLineageGroupPolicyResponse]: """ @@ -21448,7 +21606,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), creation_time: Optional[datetime.datetime] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["LineageGroupInternal"]: """ @@ -21564,7 +21722,21 @@ def get_name(self) -> str: logger.error("Name attribute not found for object mlflow_app") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = {"role_arn": {"type": "string"}} + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "MlflowApp", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -21576,7 +21748,7 @@ def create( account_default_status: Optional[StrPipeVar] = Unassigned(), default_domain_id_list: Optional[List[StrPipeVar]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MlflowApp"]: """ @@ -21649,7 +21821,7 @@ def create( def get( cls, arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MlflowApp"]: """ @@ -21693,6 +21865,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeMlflowAppResponse") mlflow_app = cls(**transformed_response) + mlflow_app._session = session return mlflow_app @Base.add_validate_call @@ -21725,13 +21898,14 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_mlflow_app(**operation_input_args) # deserialize response and update self transform(response, "DescribeMlflowAppResponse", self) return self + @populate_inputs_decorator @Base.add_validate_call def update( self, @@ -21763,7 +21937,7 @@ def update( """ logger.info("Updating mlflow_app resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Arn": self.arn, @@ -21806,7 +21980,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Arn": self.arn, @@ -21962,7 +22136,7 @@ def get_all( account_default_status: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["MlflowApp"]: """ @@ -22111,7 +22285,7 @@ def create( automatic_model_registration: Optional[bool] = Unassigned(), weekly_maintenance_window_start: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MlflowTrackingServer"]: """ @@ -22184,7 +22358,7 @@ def create( def get( cls, tracking_server_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MlflowTrackingServer"]: """ @@ -22228,6 +22402,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeMlflowTrackingServerResponse") mlflow_tracking_server = cls(**transformed_response) + mlflow_tracking_server._session = session return mlflow_tracking_server @Base.add_validate_call @@ -22260,7 +22435,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_mlflow_tracking_server(**operation_input_args) # deserialize response and update self @@ -22298,7 +22473,7 @@ def update( """ logger.info("Updating mlflow_tracking_server resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrackingServerName": self.tracking_server_name, @@ -22339,7 +22514,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrackingServerName": self.tracking_server_name, @@ -22355,7 +22530,7 @@ def delete( @Base.add_validate_call def start( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -22413,7 +22588,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrackingServerName": self.tracking_server_name, @@ -22585,7 +22760,7 @@ def get_all( mlflow_version: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["MlflowTrackingServer"]: """ @@ -22731,7 +22906,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), vpc_config: Optional[VpcConfig] = Unassigned(), enable_network_isolation: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Model"]: """ @@ -22804,7 +22979,7 @@ def create( def get( cls, model_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Model"]: """ @@ -22847,6 +23022,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelOutput") model = cls(**transformed_response) + model._session = session return model @Base.add_validate_call @@ -22878,7 +23054,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model(**operation_input_args) # deserialize response and update self @@ -22904,7 +23080,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelName": self.model_name, @@ -22926,7 +23102,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Model"]: """ @@ -22987,7 +23163,7 @@ def get_all( def get_all_metadata( self, search_expression: Optional[ModelMetadataSearchExpression] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[ModelMetadataSummary]: """ @@ -23136,7 +23312,7 @@ def create( network_config: Optional[MonitoringNetworkConfig] = Unassigned(), stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelBiasJobDefinition"]: """ @@ -23214,7 +23390,7 @@ def create( def get( cls, job_definition_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelBiasJobDefinition"]: """ @@ -23258,6 +23434,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelBiasJobDefinitionResponse") model_bias_job_definition = cls(**transformed_response) + model_bias_job_definition._session = session return model_bias_job_definition @Base.add_validate_call @@ -23290,7 +23467,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_bias_job_definition(**operation_input_args) # deserialize response and update self @@ -23317,7 +23494,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobDefinitionName": self.job_definition_name, @@ -23340,7 +23517,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelBiasJobDefinition"]: """ @@ -23474,7 +23651,7 @@ def create( model_card_status: StrPipeVar, security_config: Optional[ModelCardSecurityConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelCard"]: """ @@ -23543,7 +23720,7 @@ def get( cls, model_card_name: StrPipeVar, model_card_version: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelCard"]: """ @@ -23589,6 +23766,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelCardResponse") model_card = cls(**transformed_response) + model_card._session = session return model_card @Base.add_validate_call @@ -23622,7 +23800,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_card(**operation_input_args) # deserialize response and update self @@ -23658,7 +23836,7 @@ def update( """ logger.info("Updating model_card resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelCardName": self.model_card_name, @@ -23698,7 +23876,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelCardName": self.model_card_name, @@ -23772,7 +23950,7 @@ def get_all( model_card_status: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelCard"]: """ @@ -23838,7 +24016,7 @@ def get_all_versions( creation_time_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[ModelCardVersionSummary]: """ @@ -23966,7 +24144,7 @@ def create( model_card_export_job_name: StrPipeVar, output_config: ModelCardExportOutputConfig, model_card_version: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelCardExportJob"]: """ @@ -24037,7 +24215,7 @@ def create( def get( cls, model_card_export_job_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelCardExportJob"]: """ @@ -24081,6 +24259,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelCardExportJobResponse") model_card_export_job = cls(**transformed_response) + model_card_export_job._session = session return model_card_export_job @Base.add_validate_call @@ -24113,7 +24292,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_card_export_job(**operation_input_args) # deserialize response and update self @@ -24193,7 +24372,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelCardExportJob"]: """ @@ -24360,7 +24539,7 @@ def create( network_config: Optional[MonitoringNetworkConfig] = Unassigned(), stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelExplainabilityJobDefinition"]: """ @@ -24439,7 +24618,7 @@ def create( def get( cls, job_definition_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelExplainabilityJobDefinition"]: """ @@ -24485,6 +24664,7 @@ def get( response, "DescribeModelExplainabilityJobDefinitionResponse" ) model_explainability_job_definition = cls(**transformed_response) + model_explainability_job_definition._session = session return model_explainability_job_definition @Base.add_validate_call @@ -24517,7 +24697,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_explainability_job_definition(**operation_input_args) # deserialize response and update self @@ -24544,7 +24724,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobDefinitionName": self.job_definition_name, @@ -24567,7 +24747,7 @@ def get_all( name_contains: Optional[StrPipeVar] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelExplainabilityJobDefinition"]: """ @@ -24630,7 +24810,7 @@ def get_all( list_method_kwargs=operation_input_args, ) -''' + class ModelInternal(Base): """ Class representing resource ModelInternal @@ -24671,7 +24851,7 @@ def create( model_input: CreateModelInput, account_id: Optional[StrPipeVar] = Unassigned(), auto_ml_job_arn: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["ModelInternal"]: """ @@ -24741,7 +24921,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelInput": self.model_input, @@ -24756,7 +24936,7 @@ def delete( logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") -''' + class ModelPackage(Base): """ Class representing resource ModelPackage @@ -24798,7 +24978,7 @@ class ModelPackage(Base): """ - model_package_name: Optional[str] = Unassigned() + model_package_name: StrPipeVar model_package_group_name: Optional[StrPipeVar] = Unassigned() model_package_version: Optional[int] = Unassigned() model_package_registration_type: Optional[StrPipeVar] = Unassigned() @@ -24871,6 +25051,17 @@ def wrapper(*args, **kwargs): }, "explainability": {"report": {"s3_uri": {"type": "string"}}}, }, + "deployment_specification": { + "test_input": { + "data_source": { + "s3_data_source": { + "s3_data_type": {"type": "string"}, + "s3_uri": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + } + } + } + }, "drift_check_baselines": { "bias": { "config_file": {"s3_uri": {"type": "string"}}, @@ -24936,7 +25127,7 @@ def create( security_config: Optional[ModelPackageSecurityConfig] = Unassigned(), model_card: Optional[ModelPackageModelCard] = Unassigned(), model_life_cycle: Optional[ModelLifeCycle] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelPackage"]: """ @@ -25052,7 +25243,7 @@ def create( def get( cls, model_package_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelPackage"]: """ @@ -25095,6 +25286,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelPackageOutput") model_package = cls(**transformed_response) + model_package._session = session return model_package @Base.add_validate_call @@ -25126,7 +25318,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_package(**operation_input_args) # deserialize response and update self @@ -25176,7 +25368,7 @@ def update( """ logger.info("Updating model_package resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelPackageArn": self.model_package_arn, @@ -25224,7 +25416,7 @@ def delete( ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelPackageName": self.model_package_name, @@ -25368,7 +25560,7 @@ def get_all( model_package_type: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelPackage"]: """ @@ -25435,7 +25627,7 @@ def get_all( def batch_get( self, model_package_arn_list: List[StrPipeVar], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[BatchDescribeModelPackageOutput]: """ @@ -25524,7 +25716,7 @@ def create( model_package_group_name: StrPipeVar, model_package_group_description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelPackageGroup"]: """ @@ -25589,7 +25781,7 @@ def create( def get( cls, model_package_group_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelPackageGroup"]: """ @@ -25632,6 +25824,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelPackageGroupOutput") model_package_group = cls(**transformed_response) + model_package_group._session = session return model_package_group @Base.add_validate_call @@ -25663,7 +25856,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_package_group(**operation_input_args) # deserialize response and update self @@ -25690,7 +25883,7 @@ def delete( ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ModelPackageGroupName": self.model_package_group_name, @@ -25836,7 +26029,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), cross_account_filter_option: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelPackageGroup"]: """ @@ -25898,7 +26091,7 @@ def get_all( @Base.add_validate_call def get_policy( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[str]: """ @@ -25944,7 +26137,7 @@ def get_policy( @Base.add_validate_call def delete_policy( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -25986,7 +26179,7 @@ def delete_policy( def put_policy( self, resource_policy: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -26128,7 +26321,7 @@ def create( network_config: Optional[MonitoringNetworkConfig] = Unassigned(), stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelQualityJobDefinition"]: """ @@ -26206,7 +26399,7 @@ def create( def get( cls, job_definition_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ModelQualityJobDefinition"]: """ @@ -26250,6 +26443,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeModelQualityJobDefinitionResponse") model_quality_job_definition = cls(**transformed_response) + model_quality_job_definition._session = session return model_quality_job_definition @Base.add_validate_call @@ -26282,7 +26476,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_model_quality_job_definition(**operation_input_args) # deserialize response and update self @@ -26309,7 +26503,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "JobDefinitionName": self.job_definition_name, @@ -26333,7 +26527,7 @@ def get_all( creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), variant_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ModelQualityJobDefinition"]: """ @@ -26469,7 +26663,7 @@ def update( """ logger.info("Updating monitoring_alert resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "MonitoringScheduleName": monitoring_schedule_name, @@ -26494,7 +26688,7 @@ def update( def get_all( cls, monitoring_schedule_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["MonitoringAlert"]: """ @@ -26555,7 +26749,7 @@ def list_history( creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), status_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[MonitoringAlertHistorySummary]: """ @@ -26668,7 +26862,7 @@ def get_name(self) -> str: def get( cls, monitoring_execution_id: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MonitoringExecution"]: """ @@ -26712,6 +26906,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeMonitoringExecutionResponse") monitoring_execution = cls(**transformed_response) + monitoring_execution._session = session return monitoring_execution @Base.add_validate_call @@ -26744,7 +26939,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_monitoring_execution(**operation_input_args) # deserialize response and update self @@ -26839,7 +27034,7 @@ def get_all( monitoring_job_definition_name: Optional[StrPipeVar] = Unassigned(), monitoring_type_equals: Optional[StrPipeVar] = Unassigned(), variant_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["MonitoringExecution"]: """ @@ -26997,7 +27192,133 @@ def wrapper(*args, **kwargs): } }, } - } + }, + "custom_monitoring_job_definition": { + "custom_monitoring_job_input": { + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + }, + "custom_monitoring_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + "role_arn": {"type": "string"}, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + }, + "data_quality_job_definition": { + "data_quality_job_input": { + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + }, + "data_quality_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + "role_arn": {"type": "string"}, + "data_quality_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}}, + "statistics_resource": {"s3_uri": {"type": "string"}}, + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + }, + "model_quality_job_definition": { + "model_quality_job_input": { + "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + }, + "model_quality_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + "role_arn": {"type": "string"}, + "model_quality_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}} + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + }, + "model_bias_job_definition": { + "model_bias_job_input": { + "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + }, + "model_bias_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + "role_arn": {"type": "string"}, + "model_bias_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}} + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + }, + "model_explainability_job_definition": { + "model_explainability_job_input": { + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + }, + "model_explainability_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + "role_arn": {"type": "string"}, + "model_explainability_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}} + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + }, } return create_func( *args, @@ -27016,7 +27337,7 @@ def create( monitoring_schedule_name: StrPipeVar, monitoring_schedule_config: MonitoringScheduleConfig, tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MonitoringSchedule"]: """ @@ -27082,7 +27403,7 @@ def create( def get( cls, monitoring_schedule_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["MonitoringSchedule"]: """ @@ -27126,6 +27447,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeMonitoringScheduleResponse") monitoring_schedule = cls(**transformed_response) + monitoring_schedule._session = session return monitoring_schedule @Base.add_validate_call @@ -27158,7 +27480,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_monitoring_schedule(**operation_input_args) # deserialize response and update self @@ -27192,7 +27514,7 @@ def update( """ logger.info("Updating monitoring_schedule resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "MonitoringScheduleName": self.monitoring_schedule_name, @@ -27230,7 +27552,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "MonitoringScheduleName": self.monitoring_schedule_name, @@ -27246,7 +27568,7 @@ def delete( @Base.add_validate_call def start( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -27302,7 +27624,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "MonitoringScheduleName": self.monitoring_schedule_name, @@ -27393,7 +27715,7 @@ def get_all( monitoring_job_definition_name: Optional[StrPipeVar] = Unassigned(), monitoring_type_equals: Optional[StrPipeVar] = Unassigned(), variant_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["MonitoringSchedule"]: """ @@ -27580,7 +27902,7 @@ def create( instance_metadata_service_configuration: Optional[ InstanceMetadataServiceConfiguration ] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["NotebookInstance"]: """ @@ -27673,7 +27995,7 @@ def create( def get( cls, notebook_instance_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["NotebookInstance"]: """ @@ -27716,6 +28038,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeNotebookInstanceOutput") notebook_instance = cls(**transformed_response) + notebook_instance._session = session return notebook_instance @Base.add_validate_call @@ -27747,7 +28070,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_notebook_instance(**operation_input_args) # deserialize response and update self @@ -27803,7 +28126,7 @@ def update( """ logger.info("Updating notebook_instance resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "NotebookInstanceName": self.notebook_instance_name, @@ -27854,7 +28177,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "NotebookInstanceName": self.notebook_instance_name, @@ -27870,7 +28193,7 @@ def delete( @Base.add_validate_call def start( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -27925,7 +28248,7 @@ def stop(self) -> None: ``` """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "NotebookInstanceName": self.notebook_instance_name, @@ -28078,7 +28401,7 @@ def get_all( notebook_instance_lifecycle_config_name_contains: Optional[StrPipeVar] = Unassigned(), default_code_repository_contains: Optional[StrPipeVar] = Unassigned(), additional_code_repository_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["NotebookInstance"]: """ @@ -28193,7 +28516,7 @@ def create( on_create: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), on_start: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["NotebookInstanceLifecycleConfig"]: """ @@ -28263,7 +28586,7 @@ def create( def get( cls, notebook_instance_lifecycle_config_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["NotebookInstanceLifecycleConfig"]: """ @@ -28306,6 +28629,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeNotebookInstanceLifecycleConfigOutput") notebook_instance_lifecycle_config = cls(**transformed_response) + notebook_instance_lifecycle_config._session = session return notebook_instance_lifecycle_config @Base.add_validate_call @@ -28337,7 +28661,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_notebook_instance_lifecycle_config(**operation_input_args) # deserialize response and update self @@ -28370,7 +28694,7 @@ def update( """ logger.info("Updating notebook_instance_lifecycle_config resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, @@ -28408,7 +28732,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, @@ -28432,7 +28756,7 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["NotebookInstanceLifecycleConfig"]: """ @@ -28595,7 +28919,7 @@ def create( optimization_environment: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), vpc_config: Optional[OptimizationVpcConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["OptimizationJob"]: """ @@ -28675,7 +28999,7 @@ def create( def get( cls, optimization_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["OptimizationJob"]: """ @@ -28719,6 +29043,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeOptimizationJobResponse") optimization_job = cls(**transformed_response) + optimization_job._session = session return optimization_job @Base.add_validate_call @@ -28751,7 +29076,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_optimization_job(**operation_input_args) # deserialize response and update self @@ -28778,7 +29103,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "OptimizationJobName": self.optimization_job_name, @@ -28809,7 +29134,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "OptimizationJobName": self.optimization_job_name, @@ -28896,7 +29221,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["OptimizationJob"]: """ @@ -29030,7 +29355,10 @@ def get_name(self) -> str: def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): - config_schema_for_resource = {"execution_role_arn": {"type": "string"}} + config_schema_for_resource = { + "execution_role_arn": {"type": "string"}, + "kms_key_id": {"type": "string"}, + } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( @@ -29058,7 +29386,7 @@ def create( enable_auto_minor_version_upgrade: Optional[bool] = Unassigned(), client_token: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["PartnerApp"]: """ @@ -29143,7 +29471,7 @@ def get( cls, arn: StrPipeVar, include_available_upgrade: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["PartnerApp"]: """ @@ -29189,6 +29517,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribePartnerAppResponse") partner_app = cls(**transformed_response) + partner_app._session = session return partner_app @Base.add_validate_call @@ -29223,7 +29552,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_partner_app(**operation_input_args) # deserialize response and update self @@ -29269,7 +29598,7 @@ def update( """ logger.info("Updating partner_app resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Arn": self.arn, @@ -29316,7 +29645,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "Arn": self.arn, @@ -29334,7 +29663,7 @@ def delete( def start( self, partner_app_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -29390,7 +29719,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "PartnerAppArn": self.partner_app_arn, @@ -29530,7 +29859,7 @@ def wait_for_delete( @Base.add_validate_call def get_all( cls, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["PartnerApp"]: """ @@ -29597,7 +29926,7 @@ def create( arn: StrPipeVar, expires_in_seconds: Optional[int] = Unassigned(), session_expiration_duration_in_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["PartnerAppPresignedUrl"]: """ @@ -29703,7 +30032,7 @@ def create( persistent_volume_configuration: PersistentVolumeConfiguration, tags: Optional[List[Tag]] = Unassigned(), owning_entity_arn: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["PersistentVolume"]: """ @@ -29777,7 +30106,7 @@ def get( cls, persistent_volume_name: StrPipeVar, domain_id: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["PersistentVolume"]: """ @@ -29823,6 +30152,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribePersistentVolumeResponse") persistent_volume = cls(**transformed_response) + persistent_volume._session = session return persistent_volume @Base.add_validate_call @@ -29856,7 +30186,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_persistent_volume(**operation_input_args) # deserialize response and update self @@ -29884,7 +30214,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "PersistentVolumeName": self.persistent_volume_name, @@ -30104,7 +30434,7 @@ def create( pipeline_description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), parallelism_configuration: Optional[ParallelismConfiguration] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Pipeline"]: """ @@ -30182,7 +30512,7 @@ def get( cls, pipeline_name: StrPipeVar, pipeline_version_id: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Pipeline"]: """ @@ -30228,6 +30558,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribePipelineResponse") pipeline = cls(**transformed_response) + pipeline._session = session return pipeline @Base.add_validate_call @@ -30262,7 +30593,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_pipeline(**operation_input_args) # deserialize response and update self @@ -30304,7 +30635,7 @@ def update( """ logger.info("Updating pipeline resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "PipelineName": self.pipeline_name, @@ -30349,7 +30680,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "PipelineName": self.pipeline_name, @@ -30484,7 +30815,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Pipeline"]: """ @@ -30602,7 +30933,7 @@ def get_name(self) -> str: def get( cls, pipeline_execution_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["PipelineExecution"]: """ @@ -30646,6 +30977,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribePipelineExecutionResponse") pipeline_execution = cls(**transformed_response) + pipeline_execution._session = session return pipeline_execution @Base.add_validate_call @@ -30678,7 +31010,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_pipeline_execution(**operation_input_args) # deserialize response and update self @@ -30713,7 +31045,7 @@ def update( """ logger.info("Updating pipeline_execution resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "PipelineExecutionArn": self.pipeline_execution_arn, @@ -30740,7 +31072,7 @@ def start( client_request_token: StrPipeVar, pipeline_parameters: Optional[List[Parameter]] = Unassigned(), mlflow_experiment_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -30807,7 +31139,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "PipelineExecutionArn": self.pipeline_execution_arn, @@ -30890,7 +31222,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["PipelineExecution"]: """ @@ -30951,7 +31283,7 @@ def get_all( @Base.add_validate_call def get_pipeline_definition( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[DescribePipelineDefinitionForExecutionResponse]: """ @@ -30999,7 +31331,7 @@ def get_pipeline_definition( def get_all_steps( self, sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[PipelineExecutionStep]: """ @@ -31052,7 +31384,7 @@ def get_all_steps( @Base.add_validate_call def get_all_parameters( self, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[Parameter]: """ @@ -31104,7 +31436,7 @@ def get_all_parameters( def retry( self, client_request_token: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -31152,7 +31484,7 @@ def send_execution_step_failure( self, callback_token: StrPipeVar, client_request_token: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -31202,7 +31534,7 @@ def send_execution_step_success( callback_token: StrPipeVar, output_parameters: Optional[List[OutputParameter]] = Unassigned(), client_request_token: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -31306,7 +31638,7 @@ def create( space_name: Optional[Union[StrPipeVar, object]] = Unassigned(), landing_uri: Optional[StrPipeVar] = Unassigned(), is_dual_stack_endpoint: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["PresignedDomainUrl"]: """ @@ -31417,7 +31749,7 @@ def create( expires_in_seconds: Optional[int] = Unassigned(), landing_uri: Optional[StrPipeVar] = Unassigned(), is_dual_stack_endpoint: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["PresignedDomainUrlWithPrincipalTag"]: """ @@ -31516,7 +31848,7 @@ def create( arn: StrPipeVar, expires_in_seconds: Optional[int] = Unassigned(), session_expiration_duration_in_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["PresignedMlflowAppUrl"]: """ @@ -31609,7 +31941,7 @@ def create( tracking_server_name: StrPipeVar, expires_in_seconds: Optional[int] = Unassigned(), session_expiration_duration_in_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["PresignedMlflowTrackingServerUrl"]: """ @@ -31699,7 +32031,7 @@ def create( cls, notebook_instance_name: Union[StrPipeVar, object], session_expiration_duration_in_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["PresignedNotebookInstanceUrl"]: """ @@ -31862,7 +32194,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), workflow_type: Optional[StrPipeVar] = Unassigned(), experiment_config: Optional[ExperimentConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ProcessingJob"]: """ @@ -31945,7 +32277,7 @@ def create( def get( cls, processing_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["ProcessingJob"]: """ @@ -31989,6 +32321,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeProcessingJobResponse") processing_job = cls(**transformed_response) + processing_job._session = session return processing_job @Base.add_validate_call @@ -32021,7 +32354,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_processing_job(**operation_input_args) # deserialize response and update self @@ -32049,7 +32382,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ProcessingJobName": self.processing_job_name, @@ -32080,7 +32413,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ProcessingJobName": self.processing_job_name, @@ -32179,7 +32512,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ProcessingJob"]: """ @@ -32242,7 +32575,7 @@ def get_all( list_method_kwargs=operation_input_args, ) -''' + class ProcessingJobInternal(Base): """ Class representing resource ProcessingJobInternal @@ -32364,7 +32697,7 @@ def create( fas_source_account: Optional[StrPipeVar] = Unassigned(), experiment_config: Optional[ExperimentConfig] = Unassigned(), identity_center_user_token: Optional[IdentityCenterUserToken] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["ProcessingJobInternal"]: """ @@ -32494,7 +32827,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ProcessingJobName": self.processing_job_name, @@ -32528,7 +32861,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ProcessingJobName": self.processing_job_name, @@ -32543,7 +32876,7 @@ def stop(self) -> None: logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") -''' + class Project(Base): """ Class representing resource Project @@ -32607,7 +32940,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), template_providers: Optional[List[CreateTemplateProvider]] = Unassigned(), workflow_disabled: Optional[bool] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Project"]: """ @@ -32676,7 +33009,7 @@ def create( def get( cls, project_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Project"]: """ @@ -32719,6 +33052,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeProjectOutput") project = cls(**transformed_response) + project._session = session return project @Base.add_validate_call @@ -32750,7 +33084,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_project(**operation_input_args) # deserialize response and update self @@ -32794,7 +33128,7 @@ def update( """ logger.info("Updating project resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ProjectName": self.project_name, @@ -32836,7 +33170,7 @@ def delete( ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "ProjectName": self.project_name, @@ -32926,7 +33260,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), project_status: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Project"]: """ @@ -33045,7 +33379,23 @@ def get_name(self) -> str: logger.error("Name attribute not found for object quota_allocation") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "quota_allocation_target": {"roles": {"type": "array", "items": {"type": "string"}}} + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "QuotaAllocation", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, @@ -33058,7 +33408,7 @@ def create( activation_state: Optional[ActivationStateV1] = Unassigned(), quota_allocation_description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["QuotaAllocation"]: """ @@ -33137,7 +33487,7 @@ def get( cls, quota_allocation_arn: StrPipeVar, quota_allocation_version: Optional[int] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["QuotaAllocation"]: """ @@ -33183,6 +33533,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeQuotaAllocationResponse") quota_allocation = cls(**transformed_response) + quota_allocation._session = session return quota_allocation @Base.add_validate_call @@ -33216,13 +33567,14 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_quota_allocation(**operation_input_args) # deserialize response and update self transform(response, "DescribeQuotaAllocationResponse", self) return self + @populate_inputs_decorator @Base.add_validate_call def update( self, @@ -33256,7 +33608,7 @@ def update( """ logger.info("Updating quota_allocation resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "QuotaAllocationArn": self.quota_allocation_arn, @@ -33300,7 +33652,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "QuotaAllocationArn": self.quota_allocation_arn, @@ -33464,7 +33816,7 @@ def get_all( cluster_arn: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["QuotaAllocation"]: """ @@ -33568,7 +33920,7 @@ def get_all( creation_time_before: Optional[datetime.datetime] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["ResourceCatalog"]: """ @@ -33635,7 +33987,7 @@ class SagemakerServicecatalogPortfolio(Base): @staticmethod @Base.add_validate_call def disable( - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -33668,7 +34020,7 @@ def disable( @staticmethod @Base.add_validate_call def enable( - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -33701,7 +34053,7 @@ def enable( @staticmethod @Base.add_validate_call def get_status( - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[str]: """ @@ -33737,6 +34089,13 @@ def get_status( return list(response.values())[0] +class Session(Base): + """ + Class representing resource Session + + """ + + class SharedModel(Base): """ Class representing resource SharedModel @@ -33787,7 +34146,7 @@ def create( comment: Optional[StrPipeVar] = Unassigned(), model_name: Optional[Union[StrPipeVar, object]] = Unassigned(), origin: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["SharedModel"]: """ @@ -33859,7 +34218,7 @@ def get( cls, shared_model_id: StrPipeVar, shared_model_version: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["SharedModel"]: """ @@ -33904,6 +34263,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeSharedModelResponse") shared_model = cls(**transformed_response) + shared_model._session = session return shared_model @Base.add_validate_call @@ -33936,7 +34296,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_shared_model(**operation_input_args) # deserialize response and update self @@ -33973,7 +34333,7 @@ def update( """ logger.info("Updating shared_model resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "SharedModelId": self.shared_model_id, @@ -34013,7 +34373,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "SharedModelId": self.shared_model_id, @@ -34035,7 +34395,7 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["SharedModel"]: """ @@ -34160,7 +34520,7 @@ def create( ownership_settings: Optional[OwnershipSettings] = Unassigned(), space_sharing_settings: Optional[SpaceSharingSettings] = Unassigned(), space_display_name: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Space"]: """ @@ -34233,7 +34593,7 @@ def get( cls, domain_id: StrPipeVar, space_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Space"]: """ @@ -34279,6 +34639,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeSpaceResponse") space = cls(**transformed_response) + space._session = session return space @Base.add_validate_call @@ -34312,7 +34673,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_space(**operation_input_args) # deserialize response and update self @@ -34347,7 +34708,7 @@ def update( """ logger.info("Updating space resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -34388,7 +34749,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -34543,7 +34904,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), domain_id_equals: Optional[StrPipeVar] = Unassigned(), space_name_contains: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Space"]: """ @@ -34644,7 +35005,7 @@ def create( studio_lifecycle_config_content: StrPipeVar, studio_lifecycle_config_app_type: StrPipeVar, tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["StudioLifecycleConfig"]: """ @@ -34713,7 +35074,7 @@ def create( def get( cls, studio_lifecycle_config_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["StudioLifecycleConfig"]: """ @@ -34757,6 +35118,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeStudioLifecycleConfigResponse") studio_lifecycle_config = cls(**transformed_response) + studio_lifecycle_config._session = session return studio_lifecycle_config @Base.add_validate_call @@ -34789,7 +35151,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_studio_lifecycle_config(**operation_input_args) # deserialize response and update self @@ -34817,7 +35179,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "StudioLifecycleConfigName": self.studio_lifecycle_config_name, @@ -34842,7 +35204,7 @@ def get_all( modified_time_after: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["StudioLifecycleConfig"]: """ @@ -34940,7 +35302,7 @@ def get_name(self) -> str: def get( cls, workteam_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["SubscribedWorkteam"]: """ @@ -34983,6 +35345,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeSubscribedWorkteamResponse") subscribed_workteam = cls(**transformed_response) + subscribed_workteam._session = session return subscribed_workteam @Base.add_validate_call @@ -35014,7 +35377,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_subscribed_workteam(**operation_input_args) # deserialize response and update self @@ -35026,7 +35389,7 @@ def refresh( def get_all( cls, name_contains: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["SubscribedWorkteam"]: """ @@ -35110,7 +35473,7 @@ def get_name(self) -> str: def get_all( cls, resource_arn: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Tag"]: """ @@ -35165,7 +35528,7 @@ def add_tags( cls, resource_arn: StrPipeVar, tags: List[Tag], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -35211,7 +35574,7 @@ def delete_tags( cls, resource_arn: StrPipeVar, tag_keys: List[StrPipeVar], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -35400,12 +35763,14 @@ def populate_inputs_decorator(create_func): def wrapper(*args, **kwargs): config_schema_for_resource = { "model_artifacts": {"s3_model_artifacts": {"type": "string"}}, - "resource_config": {"volume_kms_key_id": {"type": "string"}}, + "training_job_output": {"s3_training_job_output": {"type": "string"}}, "role_arn": {"type": "string"}, "output_data_config": { "s3_output_path": {"type": "string"}, "kms_key_id": {"type": "string"}, + "remove_job_name_from_s3_output_path": {"type": "boolean"}, }, + "resource_config": {"volume_kms_key_id": {"type": "string"}}, "vpc_config": { "security_group_ids": {"type": "array", "items": {"type": "string"}}, "subnets": {"type": "array", "items": {"type": "string"}}, @@ -35413,7 +35778,25 @@ def wrapper(*args, **kwargs): "checkpoint_config": {"s3_uri": {"type": "string"}}, "debug_hook_config": {"s3_output_path": {"type": "string"}}, "tensor_board_output_config": {"s3_output_path": {"type": "string"}}, + "upstream_platform_config": { + "credential_proxy_config": { + "customer_credential_provider_kms_key_id": {"type": "string"}, + "platform_credential_provider_kms_key_id": {"type": "string"}, + }, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, + "output_data_config": {"kms_key_id": {"type": "string"}}, + "checkpoint_config": {"s3_uri": {"type": "string"}}, + "enable_s3_context_keys_on_input_data": {"type": "boolean"}, + "execution_role": {"type": "string"}, + }, "profiler_config": {"s3_output_path": {"type": "string"}}, + "processing_job_config": { + "processing_output_config": {"kms_key_id": {"type": "string"}}, + "upstream_processing_output_config": {"kms_key_id": {"type": "string"}}, + }, } return create_func( *args, @@ -35467,7 +35850,7 @@ def create( mlflow_config: Optional[MlflowConfig] = Unassigned(), with_warm_pool_validation_error: Optional[bool] = Unassigned(), model_package_config: Optional[ModelPackageConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrainingJob"]: """ @@ -35602,7 +35985,7 @@ def create( def get( cls, training_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrainingJob"]: """ @@ -35646,6 +36029,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeTrainingJobResponse") training_job = cls(**transformed_response) + training_job._session = session return training_job @Base.add_validate_call @@ -35678,7 +36062,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_training_job(**operation_input_args) # deserialize response and update self @@ -35715,7 +36099,7 @@ def update( """ logger.info("Updating training_job resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingJobName": self.training_job_name, @@ -35757,7 +36141,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingJobName": self.training_job_name, @@ -35788,7 +36172,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().sagemaker_client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingJobName": self.training_job_name, @@ -35835,16 +36219,17 @@ def wait( instance_count = 1 # Default if not isinstance(self.resource_config, Unassigned): - if (hasattr(self.resource_config, 'instance_groups') and - self.resource_config.instance_groups and - not isinstance(self.resource_config.instance_groups, Unassigned)): + if ( + hasattr(self.resource_config, "instance_groups") + and self.resource_config.instance_groups + and not isinstance(self.resource_config.instance_groups, Unassigned) + ): instance_count = sum( instance_group.instance_count for instance_group in self.resource_config.instance_groups ) - elif hasattr(self.resource_config, 'instance_count'): + elif hasattr(self.resource_config, "instance_count"): instance_count = self.resource_config.instance_count - if logs: multi_stream_logger = MultiLogStreamHandler( log_group_name=f"/aws/sagemaker/TrainingJobs", @@ -35963,7 +36348,7 @@ def get_all( sort_order: Optional[StrPipeVar] = Unassigned(), warm_pool_status_equals: Optional[StrPipeVar] = Unassigned(), training_plan_arn_equals: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["TrainingJob"]: """ @@ -36030,7 +36415,7 @@ def get_all( list_method_kwargs=operation_input_args, ) -''' + class TrainingJobInternal(Base): """ Class representing resource TrainingJobInternal @@ -36170,7 +36555,7 @@ def create( fas_source_account: Optional[StrPipeVar] = Unassigned(), sts_context_map: Optional[Dict[StrPipeVar, StrPipeVar]] = Unassigned(), identity_center_user_token: Optional[IdentityCenterUserToken] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["TrainingJobInternal"]: """ @@ -36312,7 +36697,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingJobName": self.training_job_name, @@ -36346,7 +36731,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingJobName": self.training_job_name, @@ -36360,7 +36745,7 @@ def stop(self) -> None: logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") -''' + class TrainingPlan(Base): """ Class representing resource TrainingPlan @@ -36432,7 +36817,7 @@ def create( training_plan_offering_id: StrPipeVar, spare_instance_count_per_ultra_server: Optional[int] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrainingPlan"]: """ @@ -36499,7 +36884,7 @@ def create( def get( cls, training_plan_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrainingPlan"]: """ @@ -36543,6 +36928,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeTrainingPlanResponse") training_plan = cls(**transformed_response) + training_plan._session = session return training_plan @Base.add_validate_call @@ -36575,7 +36961,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_training_plan(**operation_input_args) # deserialize response and update self @@ -36617,7 +37003,7 @@ def update( """ logger.info("Updating training_plan resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingPlanName": self.training_plan_name, @@ -36656,7 +37042,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrainingPlanName": self.training_plan_name, @@ -36734,7 +37120,7 @@ def load( training_plan_arn: StrPipeVar, capacity_resource_arn: StrPipeVar, target_resources: List[StrPipeVar], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrainingPlan"]: """ @@ -36799,7 +37185,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), filters: Optional[List[TrainingPlanFilter]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["TrainingPlan"]: """ @@ -36986,7 +37372,7 @@ def create( credential_provider_function: Optional[StrPipeVar] = Unassigned(), credential_provider_encryption_key: Optional[StrPipeVar] = Unassigned(), experiment_config: Optional[ExperimentConfig] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TransformJob"]: """ @@ -37085,7 +37471,7 @@ def create( def get( cls, transform_job_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TransformJob"]: """ @@ -37129,6 +37515,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeTransformJobResponse") transform_job = cls(**transformed_response) + transform_job._session = session return transform_job @Base.add_validate_call @@ -37161,7 +37548,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_transform_job(**operation_input_args) # deserialize response and update self @@ -37189,7 +37576,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TransformJobName": self.transform_job_name, @@ -37220,7 +37607,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TransformJobName": self.transform_job_name, @@ -37319,7 +37706,7 @@ def get_all( status_equals: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["TransformJob"]: """ @@ -37382,7 +37769,7 @@ def get_all( list_method_kwargs=operation_input_args, ) -''' + class TransformJobInternal(Base): """ Class representing resource TransformJobInternal @@ -37498,7 +37885,7 @@ def create( billing_mode: Optional[StrPipeVar] = Unassigned(), fas_source_arn: Optional[StrPipeVar] = Unassigned(), fas_source_account: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["TransformJobInternal"]: """ @@ -37619,7 +38006,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TransformJobName": self.transform_job_name, @@ -37633,7 +38020,7 @@ def stop(self) -> None: logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") -''' + class Trial(Base): """ Class representing resource Trial @@ -37688,7 +38075,7 @@ def create( display_name: Optional[StrPipeVar] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Trial"]: """ @@ -37756,7 +38143,7 @@ def create( def get( cls, trial_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Trial"]: """ @@ -37800,6 +38187,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeTrialResponse") trial = cls(**transformed_response) + trial._session = session return trial @Base.add_validate_call @@ -37832,7 +38220,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_trial(**operation_input_args) # deserialize response and update self @@ -37865,7 +38253,7 @@ def update( """ logger.info("Updating trial resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrialName": self.trial_name, @@ -37903,7 +38291,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrialName": self.trial_name, @@ -37926,7 +38314,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Trial"]: """ @@ -38062,7 +38450,7 @@ def create( output_artifacts: Optional[Dict[StrPipeVar, TrialComponentArtifact]] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrialComponent"]: """ @@ -38139,7 +38527,7 @@ def create( def get( cls, trial_component_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["TrialComponent"]: """ @@ -38183,6 +38571,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeTrialComponentResponse") trial_component = cls(**transformed_response) + trial_component._session = session return trial_component @Base.add_validate_call @@ -38215,7 +38604,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_trial_component(**operation_input_args) # deserialize response and update self @@ -38262,7 +38651,7 @@ def update( """ logger.info("Updating trial_component resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrialComponentName": self.trial_component_name, @@ -38309,7 +38698,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrialComponentName": self.trial_component_name, @@ -38454,7 +38843,7 @@ def get_all( created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["TrialComponent"]: """ @@ -38520,7 +38909,7 @@ def get_all( def associate_trail( self, trial_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -38565,7 +38954,7 @@ def associate_trail( def disassociate_trail( self, trial_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -38610,7 +38999,7 @@ def batch_put_metrics( self, resource_arn: StrPipeVar, metric_data: List[RawMetricData], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> None: """ @@ -38655,7 +39044,7 @@ def batch_put_metrics( def batch_get_metrics( cls, metric_queries: List[MetricQuery], - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional[BatchGetMetricsResponse]: """ @@ -38770,7 +39159,7 @@ def create( output_artifacts: Optional[Dict[StrPipeVar, TrialComponentArtifact]] = Unassigned(), metadata_properties: Optional[MetadataProperties] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["TrialComponentInternal"]: """ @@ -38883,7 +39272,7 @@ def update( """ logger.info("Updating trial_component_internal resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "TrialComponentName": self.trial_component_name, @@ -38967,7 +39356,7 @@ def create( metadata_properties: Optional[MetadataProperties] = Unassigned(), source: Optional[InputTrialSource] = Unassigned(), customer_details: Optional[CustomerDetails] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> Optional["TrialInternal"]: """ @@ -39086,6 +39475,10 @@ def wrapper(*args, **kwargs): config_schema_for_resource = { "user_settings": { "execution_role": {"type": "string"}, + "environment_settings": { + "default_s3_artifact_path": {"type": "string"}, + "default_s3_kms_key_id": {"type": "string"}, + }, "security_groups": {"type": "array", "items": {"type": "string"}}, "sharing_settings": { "s3_output_path": {"type": "string"}, @@ -39111,6 +39504,10 @@ def wrapper(*args, **kwargs): "execution_role_arns": {"type": "array", "items": {"type": "string"}}, } }, + "emr_settings": { + "assumable_role_arns": {"type": "array", "items": {"type": "string"}}, + "execution_role_arns": {"type": "array", "items": {"type": "string"}}, + }, } } return create_func( @@ -39134,7 +39531,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), user_policy: Optional[StrPipeVar] = Unassigned(), user_settings: Optional[UserSettings] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["UserProfile"]: """ @@ -39210,7 +39607,7 @@ def get( cls, domain_id: StrPipeVar, user_profile_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["UserProfile"]: """ @@ -39258,6 +39655,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeUserProfileResponse") user_profile = cls(**transformed_response) + user_profile._session = session return user_profile @Base.add_validate_call @@ -39293,7 +39691,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_user_profile(**operation_input_args) # deserialize response and update self @@ -39329,7 +39727,7 @@ def update( """ logger.info("Updating user_profile resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -39370,7 +39768,7 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "DomainId": self.domain_id, @@ -39529,7 +39927,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), domain_id_equals: Optional[StrPipeVar] = Unassigned(), user_profile_name_contains: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["UserProfile"]: """ @@ -39645,7 +40043,7 @@ def create( tags: Optional[List[Tag]] = Unassigned(), workforce_vpc_config: Optional[WorkforceVpcConfigRequest] = Unassigned(), ip_address_type: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Workforce"]: """ @@ -39715,7 +40113,7 @@ def create( def get( cls, workforce_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Workforce"]: """ @@ -39758,6 +40156,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeWorkforceResponse") workforce = cls(**transformed_response) + workforce._session = session return workforce @Base.add_validate_call @@ -39789,7 +40188,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_workforce(**operation_input_args) # deserialize response and update self @@ -39831,7 +40230,7 @@ def update( """ logger.info("Updating workforce resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "WorkforceName": self.workforce_name, @@ -39871,7 +40270,7 @@ def delete( ``` """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "WorkforceName": self.workforce_name, @@ -40008,7 +40407,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), name_contains: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Workforce"]: """ @@ -40103,7 +40502,7 @@ def create( notification_configuration: Optional[NotificationConfiguration] = Unassigned(), worker_access_configuration: Optional[WorkerAccessConfiguration] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Workteam"]: """ @@ -40179,7 +40578,7 @@ def create( def get( cls, workteam_name: StrPipeVar, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["Workteam"]: """ @@ -40222,6 +40621,7 @@ def get( # deserialize the response transformed_response = transform(response, "DescribeWorkteamResponse") workteam = cls(**transformed_response) + workteam._session = session return workteam @Base.add_validate_call @@ -40253,7 +40653,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) response = client.describe_workteam(**operation_input_args) # deserialize response and update self @@ -40298,7 +40698,7 @@ def update( """ logger.info("Updating workteam resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "WorkteamName": self.workteam_name, @@ -40341,7 +40741,7 @@ def delete( ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, "_session", None)) operation_input_args = { "WorkteamName": self.workteam_name, @@ -40361,7 +40761,7 @@ def get_all( sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), name_contains: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["Workteam"]: """ @@ -40423,7 +40823,7 @@ def get_all_labeling_jobs( job_reference_code_contains: Optional[StrPipeVar] = Unassigned(), sort_by: Optional[StrPipeVar] = Unassigned(), sort_order: Optional[StrPipeVar] = Unassigned(), - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[str] = None, ) -> ResourceIterator[LabelingJob]: """ diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index adbcf6ec67..08fb48fb84 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -13,10 +13,10 @@ import datetime import warnings -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict from typing import List, Dict, Optional, Any, Union from sagemaker.core.utils.utils import Unassigned -from sagemaker.core.helper.pipeline_variable import StrPipeVar, IntPipeVar, BoolPipeVar +from sagemaker.core.helper.pipeline_variable import StrPipeVar # Suppress Pydantic warnings about field names shadowing parent attributes warnings.filterwarnings("ignore", message=".*shadows an attribute.*") @@ -823,7 +823,7 @@ class ModelPackageContainerDefinition(Base): """ container_hostname: Optional[StrPipeVar] = Unassigned() - image: Optional[StrPipeVar] = Unassigned() # Revert back to autogen version + image: Optional[StrPipeVar] = Unassigned() image_digest: Optional[StrPipeVar] = Unassigned() model_data_url: Optional[StrPipeVar] = Unassigned() model_data_source: Optional[ModelDataSource] = Unassigned() @@ -1324,10 +1324,10 @@ class ResourceConfig(Base): """ instance_type: Optional[StrPipeVar] = Unassigned() - instance_count: Optional[IntPipeVar] = Unassigned() - volume_size_in_gb: Optional[IntPipeVar] = Unassigned() + instance_count: Optional[int] = Unassigned() + volume_size_in_gb: Optional[int] = Unassigned() volume_kms_key_id: Optional[StrPipeVar] = Unassigned() - keep_alive_period_in_seconds: Optional[IntPipeVar] = Unassigned() + keep_alive_period_in_seconds: Optional[int] = Unassigned() capacity_reservation_ids: Optional[List[StrPipeVar]] = Unassigned() instance_groups: Optional[List[InstanceGroup]] = Unassigned() capacity_schedules_config: Optional[CapacitySchedulesConfig] = Unassigned() @@ -7542,6 +7542,52 @@ class MetricsConfig(Base): metric_publish_frequency_in_seconds: Optional[int] = Unassigned() +class CreateEndpointConfigInput(Base): + """ + CreateEndpointConfigInput + + Attributes + ---------------------- + endpoint_config_name: The name of the endpoint configuration. You specify this name in a CreateEndpoint request. + production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint. + data_capture_config + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + kms_key_id: The Amazon Resource Name (ARN) of a Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. The KmsKeyId can be any of the following formats: Key ID: 1234abcd-12ab-34cd-56ef-1234567890ab Key ARN: arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab Alias name: alias/ExampleAlias Alias name ARN: arn:aws:kms:us-west-2:111122223333:alias/ExampleAlias The KMS key policy must grant permission to the IAM role that you specify in your CreateEndpoint, UpdateEndpoint requests. For more information, refer to the Amazon Web Services Key Management Service section Using Key Policies in Amazon Web Services KMS Certain Nitro-based instances include local storage, dependent on the instance type. Local storage volumes are encrypted using a hardware module on the instance. You can't request a KmsKeyId when using an instance type with local storage. If any of the models that you specify in the ProductionVariants parameter use nitro-based instances with local storage, do not specify a value for the KmsKeyId parameter. If you specify a value for KmsKeyId when using any nitro-based instances with local storage, the call to CreateEndpointConfig fails. For a list of instance types that support local instance storage, see Instance Store Volumes. For more information about local instance storage encryption, see SSD Instance Store Volumes. + async_inference_config: Specifies configuration for how an endpoint performs asynchronous inference. This is a required field in order for your Endpoint to be invoked using InvokeEndpointAsync. + explainer_config: A member of CreateEndpointConfig that enables explainers. + shadow_production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. If you use this field, you can only specify one variant for ProductionVariants and one variant for ShadowProductionVariants. + execution_role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker AI can assume to perform actions on your behalf. For more information, see SageMaker AI Roles. To be able to pass this role to Amazon SageMaker AI, the caller of this action must have the iam:PassRole permission. + vpc_config + enable_network_isolation: Sets whether all model containers deployed to the endpoint are isolated. If they are, no inbound or outbound network calls can be made to or from the model containers. + metrics_config: The Configuration parameters for Utilization metrics. + """ + + endpoint_config_name: Union[StrPipeVar, object] + production_variants: List[ProductionVariant] + data_capture_config: Optional[DataCaptureConfig] = Unassigned() + tags: Optional[List[Tag]] = Unassigned() + kms_key_id: Optional[StrPipeVar] = Unassigned() + async_inference_config: Optional[AsyncInferenceConfig] = Unassigned() + explainer_config: Optional[ExplainerConfig] = Unassigned() + shadow_production_variants: Optional[List[ProductionVariant]] = Unassigned() + execution_role_arn: Optional[StrPipeVar] = Unassigned() + vpc_config: Optional[VpcConfig] = Unassigned() + enable_network_isolation: Optional[bool] = Unassigned() + metrics_config: Optional[MetricsConfig] = Unassigned() + + +class CreateEndpointConfigOutput(Base): + """ + CreateEndpointConfigOutput + + Attributes + ---------------------- + endpoint_config_arn: The Amazon Resource Name (ARN) of the endpoint configuration. + """ + + endpoint_config_arn: StrPipeVar + + class EndpointDeletionCondition(Base): """ EndpointDeletionCondition @@ -7592,6 +7638,40 @@ class DeploymentConfig(Base): auto_rollback_configuration: Optional[AutoRollbackConfig] = Unassigned() +class CreateEndpointInput(Base): + """ + CreateEndpointInput + + Attributes + ---------------------- + endpoint_name: The name of the endpoint.The name must be unique within an Amazon Web Services Region in your Amazon Web Services account. The name is case-insensitive in CreateEndpoint, but the case is preserved and must be matched in InvokeEndpoint. + endpoint_config_name: The name of an endpoint configuration. For more information, see CreateEndpointConfig. + graph_config_name + deletion_condition + deployment_config + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + """ + + endpoint_name: Union[StrPipeVar, object] + endpoint_config_name: Union[StrPipeVar, object] + graph_config_name: Optional[StrPipeVar] = Unassigned() + deletion_condition: Optional[EndpointDeletionCondition] = Unassigned() + deployment_config: Optional[DeploymentConfig] = Unassigned() + tags: Optional[List[Tag]] = Unassigned() + + +class CreateEndpointOutput(Base): + """ + CreateEndpointOutput + + Attributes + ---------------------- + endpoint_arn: The Amazon Resource Name (ARN) of the endpoint. + """ + + endpoint_arn: StrPipeVar + + class EvaluationJobModel(Base): """ EvaluationJobModel @@ -8577,9 +8657,9 @@ class InferenceComponentComputeResourceRequirements(Base): max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component. """ - min_memory_required_in_mb: Optional[int] = Unassigned() number_of_cpu_cores_required: Optional[float] = Unassigned() number_of_accelerator_devices_required: Optional[float] = Unassigned() + min_memory_required_in_mb: Optional[int] = Unassigned() max_memory_required_in_mb: Optional[int] = Unassigned() @@ -9492,6 +9572,44 @@ class InferenceExecutionConfig(Base): mode: StrPipeVar +class CreateModelInput(Base): + """ + CreateModelInput + + Attributes + ---------------------- + model_name: The name of the new model. + primary_container: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions. + containers: Specifies the containers in the inference pipeline. + inference_execution_config: Specifies details of how containers in a multi-container endpoint are called. + execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs. Deploying on ML compute instances is part of model hosting. For more information, see SageMaker Roles. To be able to pass this role to SageMaker, the caller of this API must have the iam:PassRole permission. + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + vpc_config: A VpcConfig object that specifies the VPC that you want your model to connect to. Control access to and from your model container by configuring the VPC. VpcConfig is used in hosting services and in batch transform. For more information, see Protect Endpoints by Using an Amazon Virtual Private Cloud and Protect Data in Batch Transform Jobs by Using an Amazon Virtual Private Cloud. + enable_network_isolation: Isolates the model container. No inbound or outbound network calls can be made to or from the model container. + """ + + model_name: Union[StrPipeVar, object] + primary_container: Optional[ContainerDefinition] = Unassigned() + containers: Optional[List[ContainerDefinition]] = Unassigned() + inference_execution_config: Optional[InferenceExecutionConfig] = Unassigned() + execution_role_arn: Optional[StrPipeVar] = Unassigned() + tags: Optional[List[Tag]] = Unassigned() + vpc_config: Optional[VpcConfig] = Unassigned() + enable_network_isolation: Optional[bool] = Unassigned() + + +class CreateModelOutput(Base): + """ + CreateModelOutput + + Attributes + ---------------------- + model_arn: The ARN of the model created in SageMaker. + """ + + model_arn: StrPipeVar + + class ModelPackageValidationProfile(Base): """ ModelPackageValidationProfile @@ -9770,7 +9888,7 @@ class ModelPackageSecurityConfig(Base): kms_key_id: The KMS Key ID (KMSKeyId) used for encryption of model package information. """ - kms_key_id: Optional[str] = Unassigned() + kms_key_id: StrPipeVar class ModelPackageModelCard(Base): @@ -10537,6 +10655,18 @@ class ExperimentConfig(Base): run_name: Optional[StrPipeVar] = Unassigned() +class CreateProcessingJobResponse(Base): + """ + CreateProcessingJobResponse + + Attributes + ---------------------- + processing_job_arn: The Amazon Resource Name (ARN) of the processing job. + """ + + processing_job_arn: StrPipeVar + + class ProvisioningParameter(Base): """ ProvisioningParameter @@ -11032,6 +11162,18 @@ class UpstreamPlatformConfig(Base): execution_role: Optional[StrPipeVar] = Unassigned() +class CreateTrainingJobResponse(Base): + """ + CreateTrainingJobResponse + + Attributes + ---------------------- + training_job_arn: The Amazon Resource Name (ARN) of the training job. + """ + + training_job_arn: StrPipeVar + + class DebugHookConfig(Base): """ DebugHookConfig @@ -11189,7 +11331,7 @@ class ServerlessJobConfig(Base): evaluator_arn job_spec """ - + base_model_arn: StrPipeVar job_type: StrPipeVar accept_eula: Optional[bool] = Unassigned() @@ -11264,6 +11406,18 @@ class DataProcessing(Base): join_source: Optional[StrPipeVar] = Unassigned() +class CreateTransformJobResponse(Base): + """ + CreateTransformJobResponse + + Attributes + ---------------------- + transform_job_arn: The Amazon Resource Name (ARN) of the transform job. + """ + + transform_job_arn: StrPipeVar + + class InputTrialComponentSource(Base): """ InputTrialComponentSource diff --git a/sagemaker-core/src/sagemaker/core/tools/codegen.py b/sagemaker-core/src/sagemaker/core/tools/codegen.py index dc30101da6..80e092dff3 100644 --- a/sagemaker-core/src/sagemaker/core/tools/codegen.py +++ b/sagemaker-core/src/sagemaker/core/tools/codegen.py @@ -11,7 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Generates the code for the service model.""" -from sagemaker.core.utils.utils import reformat_file_with_black from sagemaker.core.tools.shapes_codegen import ShapesCodeGen from sagemaker.core.tools.resources_codegen import ResourcesCodeGen from typing import Optional @@ -19,6 +18,14 @@ from sagemaker.core.tools.data_extractor import ServiceJsonData, load_service_jsons +# Generated files that should be reformatted after codegen +_GENERATED_FILES = [ + "src/sagemaker/core/resources.py", + "src/sagemaker/core/shapes/shapes.py", + "src/sagemaker/core/config_schema.py", +] + + def generate_code( shapes_code_gen: Optional[ShapesCodeGen] = None, resources_code_gen: Optional[ShapesCodeGen] = None, @@ -37,6 +44,10 @@ def generate_code( Returns: None """ + # Import lazily to avoid circular import through sagemaker.core.__init__ + # which imports processing -> resources (the file we are generating) + from sagemaker.core.utils.utils import reformat_file_with_black + service_json_data: ServiceJsonData = load_service_jsons() shapes_code_gen = shapes_code_gen or ShapesCodeGen() @@ -45,7 +56,10 @@ def generate_code( ) shapes_code_gen.generate_shapes() - reformat_file_with_black(".") + + # Only reformat the generated files, not the entire directory + for generated_file in _GENERATED_FILES: + reformat_file_with_black(generated_file) """ diff --git a/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py b/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py index 2780ca14e7..0140e72eb0 100644 --- a/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py +++ b/sagemaker-core/src/sagemaker/core/tools/resources_codegen.py @@ -181,7 +181,7 @@ def generate_imports(self) -> str: "import functools", "from pydantic import validate_call", "from typing import Dict, List, Literal, Optional, Union, Any\n" - "from boto3.session import Session", + "from boto3.session import Session as Boto3Session", "from rich.console import Group", "from rich.live import Live", "from rich.panel import Panel", @@ -940,7 +940,7 @@ def generate_create_method(self, resource_name: str, **kwargs) -> str: add_indent("cls,\n", 4) + create_args + "\n" - + add_indent("session: Optional[Session] = None,\n", 4) + + add_indent("session: Optional[Boto3Session] = None,\n", 4) + add_indent("region: Optional[str] = None,", 4) ) formatted_method = GENERIC_METHOD_TEMPLATE.format( @@ -1347,7 +1347,7 @@ def generate_start_method(self, resource_name: str, **kwargs) -> str: operation_metadata, False, resource_attributes ) exclude_resource_attrs = resource_attributes - method_args += add_indent("session: Optional[Session] = None,\n", 4) + method_args += add_indent("session: Optional[Boto3Session] = None,\n", 4) method_args += add_indent("region: Optional[str] = None,", 4) serialize_operation_input = SERIALIZE_INPUT_TEMPLATE.format( @@ -1448,7 +1448,7 @@ def generate_method(self, method: Method, resource_attributes: list): operation_metadata, False, resource_attributes ) exclude_resource_attrs = resource_attributes - method_args += add_indent("session: Optional[Session] = None,\n", 4) + method_args += add_indent("session: Optional[Boto3Session] = None,\n", 4) method_args += add_indent("region: Optional[str] = None,", 4) initialize_client = INITIALIZE_CLIENT_TEMPLATE.format(service_name=method.service_name) @@ -1573,7 +1573,7 @@ def generate_additional_get_all_method(self, method: Method, resource_attributes operation_metadata, False, resource_attributes, exclude_list ) exclude_resource_attrs = resource_attributes - method_args += add_indent("session: Optional[Session] = None,\n", 4) + method_args += add_indent("session: Optional[Boto3Session] = None,\n", 4) method_args += add_indent("region: Optional[str] = None,", 4) if method.return_type == method.resource_name: @@ -1664,19 +1664,21 @@ def _get_instance_count_ref(self, resource_name: str) -> str: """ if resource_name == "TrainingJob": - return """1 # Default - if not isinstance(self.resource_config, Unassigned): - if ( - hasattr(self.resource_config, "instance_groups") - and self.resource_config.instance_groups - and not isinstance(self.resource_config.instance_groups, Unassigned) - ): - instance_count = sum( - instance_group.instance_count - for instance_group in self.resource_config.instance_groups - ) - elif hasattr(self.resource_config, "instance_count"): - instance_count = self.resource_config.instance_count""" + return ( + "1 # Default\n" + "if not isinstance(self.resource_config, Unassigned):\n" + " if (\n" + ' hasattr(self.resource_config, "instance_groups")\n' + " and self.resource_config.instance_groups\n" + " and not isinstance(self.resource_config.instance_groups, Unassigned)\n" + " ):\n" + " instance_count = sum(\n" + " instance_group.instance_count\n" + " for instance_group in self.resource_config.instance_groups\n" + " )\n" + ' elif hasattr(self.resource_config, "instance_count"):\n' + " instance_count = self.resource_config.instance_count" + ) elif resource_name == "TransformJob": return "self.transform_resources.instance_count" elif resource_name == "ProcessingJob": diff --git a/sagemaker-core/src/sagemaker/core/tools/shapes_codegen.py b/sagemaker-core/src/sagemaker/core/tools/shapes_codegen.py index 2242804888..5df3d697f0 100644 --- a/sagemaker-core/src/sagemaker/core/tools/shapes_codegen.py +++ b/sagemaker-core/src/sagemaker/core/tools/shapes_codegen.py @@ -16,6 +16,7 @@ export PYTHONPATH=:$PYTHONPATH """ import os +from functools import lru_cache from sagemaker.core.utils.code_injection.codec import pascal_to_snake from sagemaker.core.tools.constants import ( @@ -247,12 +248,60 @@ def _filter_input_output_shapes(self, shape): required_output_shapes.add(method.return_type) if shape in operation_input_output_shapes and shape not in required_output_shapes: + # Before filtering out, check if this shape is transitively referenced + # by any operation input/output shape that is itself used as a resource + # class attribute. For example, CreateEndpointConfigInput is an operation + # input shape but is also a member type inside + # CreateEndpointConfigInputInternal, which feeds the + # EndpointConfigInternal resource class. + if shape in self._get_shapes_required_by_resources(): + return True return False return True + @lru_cache(maxsize=1) + def _get_shapes_required_by_resources(self): + """Collect all shapes transitively referenced by resource class attributes. + + Resource classes derive their attributes from operation input/output shapes. + Some of those attributes reference shapes that are also operation input/output + shapes (and would normally be filtered out). This method finds those shapes + so they can be kept. + """ + required = set() + resource_plan = self.resources_extractor.get_resource_plan() + + for _, row in resource_plan.iterrows(): + resource_name = row["resource_name"] + class_methods = row["class_methods"] + + # Determine which shapes feed this resource's class attributes + attr_shapes = [] + if "get" in class_methods: + op = self.combined_operations.get("Describe" + resource_name) + if op: + attr_shapes.append(op["output"]["shape"]) + elif "create" in class_methods: + op = self.combined_operations.get("Create" + resource_name) + if op: + attr_shapes.append(op["input"]["shape"]) + attr_shapes.append(op["output"]["shape"]) + + # Walk one level of members to find referenced structure shapes + for attr_shape in attr_shapes: + shape_def = self.combined_shapes.get(attr_shape, {}) + for member_attrs in shape_def.get("members", {}).values(): + member_shape_name = member_attrs.get("shape") + if member_shape_name and member_shape_name in self.combined_shapes: + member_shape_def = self.combined_shapes[member_shape_name] + if member_shape_def.get("type") == "structure": + required.add(member_shape_name) + + return required + def generate_shapes( self, - output_folder=GENERATED_CLASSES_LOCATION, + output_folder=GENERATED_CLASSES_LOCATION + "/shapes", file_name=SHAPES_CODEGEN_FILE_NAME, ) -> None: """ diff --git a/sagemaker-core/src/sagemaker/core/tools/templates.py b/sagemaker-core/src/sagemaker/core/tools/templates.py index 1c84c91668..a2113a10d4 100644 --- a/sagemaker-core/src/sagemaker/core/tools/templates.py +++ b/sagemaker-core/src/sagemaker/core/tools/templates.py @@ -39,7 +39,7 @@ class {class_name}: def create( cls, {create_args} - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["{resource_name}"]: {docstring} @@ -70,7 +70,7 @@ def create( def create( cls, {create_args} - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["{resource_name}"]: {docstring} @@ -101,7 +101,7 @@ def create( def load( cls, {import_args} - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["{resource_name}"]: {docstring} @@ -152,7 +152,7 @@ def update( ) -> Optional["{resource_name}"]: {docstring} logger.info("Updating {resource_lower} resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, '_session', None)) operation_input_args = {{ {operation_input_args} @@ -179,7 +179,7 @@ def update( ) -> Optional["{resource_name}"]: {docstring} logger.info("Updating {resource_lower} resource.") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, '_session', None)) operation_input_args = {{ {operation_input_args} @@ -213,7 +213,7 @@ def wrapper(*args, **kwargs): def get( cls, {describe_args} - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["{resource_name}"]: {docstring} @@ -232,6 +232,7 @@ def get( # deserialize the response transformed_response = transform(response, '{describe_operation_output_shape}') {resource_lower} = cls(**transformed_response) + {resource_lower}._session = session return {resource_lower} """ @@ -249,7 +250,7 @@ def refresh( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, '_session', None)) response = client.{operation}(**operation_input_args) # deserialize response and update self @@ -465,7 +466,7 @@ def delete( {delete_args} ) -> None: {docstring} - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client(session=getattr(self, '_session', None)) operation_input_args = {{ {operation_input_args} @@ -483,7 +484,7 @@ def delete( @Base.add_validate_call def stop(self) -> None: {docstring} - client = SageMakerClient().sagemaker_client + client = Base.get_sagemaker_client(session=getattr(self, '_session', None)) operation_input_args = {{ {operation_input_args} @@ -503,7 +504,7 @@ def stop(self) -> None: def get_all( cls, {get_all_args} - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["{resource}"]: {docstring} @@ -527,7 +528,7 @@ def get_all( @Base.add_validate_call def get_all( cls, - session: Optional[Session] = None, + session: Optional[Boto3Session] = None, region: Optional[StrPipeVar] = None, ) -> ResourceIterator["{resource}"]: """ diff --git a/sagemaker-core/tests/integ/test_session_wait_e2e.py b/sagemaker-core/tests/integ/test_session_wait_e2e.py new file mode 100644 index 0000000000..53ac0013c0 --- /dev/null +++ b/sagemaker-core/tests/integ/test_session_wait_e2e.py @@ -0,0 +1,708 @@ +"""End-to-end integration tests for session propagation in wait flows. + +Tests that wait=True works correctly for various resource types when using +sagemaker_session. These tests create real AWS resources and wait for them +to complete, verifying the session fix from GitHub issue #5765. + +Usage: + python -m pytest tests/integ/test_session_wait_e2e.py -v -s + # Or run individual tests: + python -m pytest tests/integ/test_session_wait_e2e.py::test_processing_job_wait -v -s + python -m pytest tests/integ/test_session_wait_e2e.py::test_training_job_wait -v -s + python -m pytest tests/integ/test_session_wait_e2e.py::test_training_job_wait_via_resource -v -s + python -m pytest tests/integ/test_session_wait_e2e.py::test_transform_job_wait -v -s + +Prerequisites: + - Valid AWS credentials configured + - IAM role with SageMaker permissions + - pip install sagemaker-core (from this repo) + +Note: These tests create real SageMaker jobs and incur AWS costs. + Each test takes 3-10 minutes to complete. +""" + +import os +import time +import tempfile +import pytest + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core import image_uris + + +# ── Shared fixtures ────────────────────────────────────────────────────────── + +@pytest.fixture(scope="module") +def sm_session(): + """Create a SageMaker session for all tests.""" + return Session() + + +@pytest.fixture(scope="module") +def role(sm_session): + """Get the execution role.""" + return get_execution_role() + + +@pytest.fixture(scope="module") +def region(sm_session): + """Get the region.""" + return sm_session.boto_region_name + + +@pytest.fixture(scope="module") +def training_image(region): + """Get a PyTorch training image URI.""" + return image_uris.retrieve( + framework="pytorch", + region=region, + version="2.2.0", + py_version="py310", + instance_type="ml.m5.large", + image_scope="training", + ) + + +@pytest.fixture(scope="module") +def processing_image(): + """Get a processing image URI.""" + return "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.2.0-cpu-py310" + + +# ── Test 1: Processing job wait via ScriptProcessor ───────────────────────── + +def test_processing_job_wait(sm_session, role, processing_image): + """Test that ScriptProcessor.run(wait=True) works with sagemaker_session. + + This is the primary flow reported in GitHub issue #5765. + The ScriptProcessor passes sagemaker_session, and wait=True triggers + ProcessingJob.refresh() which must use the session's credentials. + """ + from sagemaker.core.processing import ScriptProcessor + + script_dir = tempfile.mkdtemp() + script_path = os.path.join(script_dir, "hello.py") + with open(script_path, "w") as f: + f.write('print("Hello from processing job!")\n') + + processor = ScriptProcessor( + image_uri=processing_image, + command=["python3"], + role=role, + instance_count=1, + instance_type="ml.m5.large", + sagemaker_session=sm_session, + base_job_name="integ-session-proc", + ) + + start = time.time() + processor.run(code=script_path, wait=True, logs=False) + elapsed = time.time() - start + + assert processor.latest_job is not None + assert processor.latest_job.processing_job_status in ("Completed", "Failed", "Stopped") + print(f"\nProcessing job completed in {elapsed:.0f}s") + print(f"Job: {processor.latest_job.processing_job_name}") + print(f"Status: {processor.latest_job.processing_job_status}") + + # Verify the job actually completed (not just that wait returned) + assert processor.latest_job.processing_job_status == "Completed" + + +# ── Test 2: Training job wait via ModelTrainer ────────────────────────────── + +def test_training_job_wait(sm_session, role, region, training_image): + """Test that ModelTrainer.train(wait=True) works with sagemaker_session. + + This is the other primary flow reported in GitHub issue #5765. + ModelTrainer passes sagemaker_session, creates a TrainingJob, and + wait=True triggers TrainingJob.refresh() which must use the session. + """ + from sagemaker.train.model_trainer import ModelTrainer + from sagemaker.core.training.configs import Compute, SourceCode + + script_dir = tempfile.mkdtemp() + script_path = os.path.join(script_dir, "train.py") + with open(script_path, "w") as f: + f.write( + 'import os\n' + 'print("Hello from training job!")\n' + 'model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")\n' + 'os.makedirs(model_dir, exist_ok=True)\n' + 'with open(os.path.join(model_dir, "dummy.txt"), "w") as f:\n' + ' f.write("done")\n' + 'print("Training complete!")\n' + ) + + trainer = ModelTrainer( + training_image=training_image, + source_code=SourceCode(source_dir=script_dir, entry_script="train.py"), + compute=Compute(instance_type="ml.m5.large", instance_count=1), + role=role, + sagemaker_session=sm_session, + base_job_name="integ-session-train", + ) + + start = time.time() + trainer.train(wait=True, logs=False) + elapsed = time.time() - start + + job = trainer._latest_training_job + assert job is not None + print(f"\nTraining job completed in {elapsed:.0f}s") + print(f"Job: {job.training_job_name}") + print(f"Status: {job.training_job_status}") + + assert job.training_job_status == "Completed" + + +# ── Test 3: Training job wait via resource class directly ─────────────────── + +def test_training_job_wait_via_resource(sm_session, role, region, training_image): + """Test TrainingJob.create() + wait() using the resource class directly. + + This bypasses ModelTrainer and tests the resource class session + propagation directly: create() stores _session, wait() calls refresh() + which uses the stored _session. + + Uses a simple container command instead of a training script to avoid + script packaging complexity. + """ + from sagemaker.core.resources import TrainingJob + from sagemaker.core.shapes import ( + AlgorithmSpecification, + OutputDataConfig, + ResourceConfig, + StoppingCondition, + ) + + bucket = sm_session.default_bucket() + prefix = f"integ-session-direct-{int(time.time())}" + job_name = f"integ-direct-{int(time.time())}" + + # Use container_entrypoint + container_arguments to run inline code + # This avoids the script packaging issue entirely + training_job = TrainingJob.create( + training_job_name=job_name, + role_arn=role, + algorithm_specification=AlgorithmSpecification( + training_image=training_image, + training_input_mode="File", + container_entrypoint=["python3", "-c"], + container_arguments=[ + "import os; " + "print('Direct resource class training!'); " + "model_dir = os.environ.get('SM_MODEL_DIR', '/opt/ml/model'); " + "os.makedirs(model_dir, exist_ok=True); " + "open(os.path.join(model_dir, 'dummy.txt'), 'w').write('done'); " + "print('Training complete!')" + ], + ), + output_data_config=OutputDataConfig( + s3_output_path=f"s3://{bucket}/{prefix}/output", + ), + resource_config=ResourceConfig( + instance_type="ml.m5.large", + instance_count=1, + volume_size_in_gb=10, + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + session=sm_session.boto_session, + ) + + assert training_job is not None + assert hasattr(training_job, "_session"), "TrainingJob should have _session attribute" + assert training_job._session is sm_session.boto_session, "_session should be the boto session" + + print(f"\nCreated training job: {job_name}") + print(f"Waiting for completion...") + + start = time.time() + training_job.wait() + elapsed = time.time() - start + + print(f"Training job completed in {elapsed:.0f}s") + print(f"Status: {training_job.training_job_status}") + + assert training_job.training_job_status == "Completed" + + +# ── Test 4: Processing job wait via resource class directly ───────────────── + +def test_processing_job_wait_via_resource(sm_session, role, region): + """Test ProcessingJob.create() + wait() using the resource class directly.""" + from sagemaker.core.resources import ProcessingJob + from sagemaker.core.shapes import ( + AppSpecification, + ProcessingResources, + ProcessingClusterConfig, + ) + + job_name = f"integ-session-proc-direct-{int(time.time())}" + + processing_job = ProcessingJob.create( + processing_job_name=job_name, + role_arn=role, + app_specification=AppSpecification( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.2.0-cpu-py310", + container_entrypoint=["python3", "-c", 'print("Hello from direct processing!")'], + ), + processing_resources=ProcessingResources( + cluster_config=ProcessingClusterConfig( + instance_count=1, + instance_type="ml.m5.large", + volume_size_in_gb=10, + ), + ), + session=sm_session.boto_session, + ) + + assert processing_job is not None + assert hasattr(processing_job, "_session"), "ProcessingJob should have _session attribute" + + print(f"\nCreated processing job: {job_name}") + print(f"Waiting for completion...") + + start = time.time() + processing_job.wait() + elapsed = time.time() - start + + print(f"Processing job completed in {elapsed:.0f}s") + print(f"Status: {processing_job.processing_job_status}") + + assert processing_job.processing_job_status == "Completed" + + +# ── Test 5: Verify _session survives refresh cycle ────────────────────────── + +def test_session_survives_refresh(sm_session, role): + """Test that _session persists through multiple refresh() calls. + + Creates a processing job, then manually calls refresh() multiple times + to verify the session attribute isn't lost during deserialization. + """ + from sagemaker.core.resources import ProcessingJob + from sagemaker.core.shapes import ( + AppSpecification, + ProcessingResources, + ProcessingClusterConfig, + ) + + job_name = f"integ-session-refresh-{int(time.time())}" + + job = ProcessingJob.create( + processing_job_name=job_name, + role_arn=role, + app_specification=AppSpecification( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.2.0-cpu-py310", + container_entrypoint=["python3", "-c", 'print("Refresh test!")'], + ), + processing_resources=ProcessingResources( + cluster_config=ProcessingClusterConfig( + instance_count=1, + instance_type="ml.m5.large", + volume_size_in_gb=10, + ), + ), + session=sm_session.boto_session, + ) + + original_session = job._session + + # Refresh multiple times and verify _session persists + for i in range(3): + job.refresh() + assert job._session is original_session, ( + f"_session lost after refresh #{i+1}" + ) + print(f"Refresh #{i+1}: status={job.processing_job_status}, _session intact ✓") + time.sleep(2) + + # Now wait for completion + job.wait() + assert job._session is original_session, "_session lost after wait()" + assert job.processing_job_status == "Completed" + print(f"Final status: {job.processing_job_status}, _session intact ✓") + + +# ── Test 6: TrainingJob.get() with session, then wait ─────────────────────── + +def test_get_then_wait(sm_session, role, training_image): + """Test TrainingJob.get() with session, then call wait(). + + This tests the get() → _session storage → wait() → refresh() flow + without going through create(). + """ + from sagemaker.core.resources import TrainingJob + from sagemaker.core.shapes import ( + AlgorithmSpecification, + OutputDataConfig, + ResourceConfig, + StoppingCondition, + ) + + bucket = sm_session.default_bucket() + job_name = f"integ-get-wait-{int(time.time())}" + + # Create the job using container_entrypoint to avoid script packaging + TrainingJob.create( + training_job_name=job_name, + role_arn=role, + algorithm_specification=AlgorithmSpecification( + training_image=training_image, + training_input_mode="File", + container_entrypoint=["python3", "-c"], + container_arguments=[ + "import os; " + "print('Get-then-wait test!'); " + "model_dir = os.environ.get('SM_MODEL_DIR', '/opt/ml/model'); " + "os.makedirs(model_dir, exist_ok=True); " + "open(os.path.join(model_dir, 'dummy.txt'), 'w').write('done')" + ], + ), + output_data_config=OutputDataConfig( + s3_output_path=f"s3://{bucket}/integ-get-wait/output", + ), + resource_config=ResourceConfig( + instance_type="ml.m5.large", + instance_count=1, + volume_size_in_gb=10, + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + session=sm_session.boto_session, + ) + + # Now get() the job with a session and wait + job = TrainingJob.get( + training_job_name=job_name, + session=sm_session.boto_session, + ) + + assert job._session is sm_session.boto_session, "get() should store _session" + print(f"\nGot training job: {job_name}") + print(f"_session stored: ✓") + + start = time.time() + job.wait() + elapsed = time.time() - start + + print(f"Wait completed in {elapsed:.0f}s") + print(f"Status: {job.training_job_status}") + assert job.training_job_status == "Completed" + + +# ── Test 7: Transform job wait ────────────────────────────────────────────── + +def test_transform_job_wait(sm_session, role, region, training_image): + """Test TransformJob.create() + wait() with session. + + Creates a model from a training job output, then runs a batch transform + job and waits for it. Tests TransformJob.refresh() session propagation. + """ + from sagemaker.core.resources import Model, TransformJob + from sagemaker.core.shapes import ( + ContainerDefinition, + TransformInput, + TransformDataSource, + TransformS3DataSource, + TransformOutput, + TransformResources, + ) + + bucket = sm_session.default_bucket() + ts = int(time.time()) + + # Create a dummy model artifact (empty tar.gz) + model_dir = tempfile.mkdtemp() + model_tar = os.path.join(model_dir, "model.tar.gz") + import tarfile + with tarfile.open(model_tar, "w:gz") as tar: + # Add a dummy file + dummy_path = os.path.join(model_dir, "dummy.txt") + with open(dummy_path, "w") as f: + f.write("dummy model") + tar.add(dummy_path, arcname="dummy.txt") + + # Upload model artifact + model_s3_uri = sm_session.upload_data( + path=model_tar, + bucket=bucket, + key_prefix=f"integ-transform-{ts}/model", + ) + + # Create dummy input data + input_dir = tempfile.mkdtemp() + input_path = os.path.join(input_dir, "input.csv") + with open(input_path, "w") as f: + f.write("1,2,3\n4,5,6\n") + input_s3_uri = sm_session.upload_data( + path=input_dir, + bucket=bucket, + key_prefix=f"integ-transform-{ts}/input", + ) + + # Use a simple inference image that just echoes input + # sklearn image is lightweight and handles CSV + sklearn_image = image_uris.retrieve( + framework="sklearn", + region=region, + version="1.2-1", + instance_type="ml.m5.large", + image_scope="inference", + ) + + # Create model + model_name = f"integ-transform-model-{ts}" + model = Model.create( + model_name=model_name, + primary_container=ContainerDefinition( + image=sklearn_image, + model_data_url=model_s3_uri, + ), + execution_role_arn=role, + session=sm_session.boto_session, + ) + assert model is not None + assert model._session is sm_session.boto_session + print(f"\nCreated model: {model_name}") + + # Create transform job + transform_job_name = f"integ-transform-{ts}" + transform_job = TransformJob.create( + transform_job_name=transform_job_name, + model_name=model_name, + transform_input=TransformInput( + data_source=TransformDataSource( + s3_data_source=TransformS3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{bucket}/integ-transform-{ts}/input/", + ), + ), + content_type="text/csv", + ), + transform_output=TransformOutput( + s3_output_path=f"s3://{bucket}/integ-transform-{ts}/output/", + ), + transform_resources=TransformResources( + instance_type="ml.m5.large", + instance_count=1, + ), + session=sm_session.boto_session, + ) + + assert transform_job is not None + assert hasattr(transform_job, "_session") + assert transform_job._session is sm_session.boto_session + print(f"Created transform job: {transform_job_name}") + + start = time.time() + try: + transform_job.wait() + except Exception as e: + # Transform may fail because the model isn't a real inference model, + # but the wait mechanism itself should work (session propagation) + print(f"Transform job ended with: {type(e).__name__}: {e}") + # Verify the job reached a terminal state (not a credentials error) + transform_job.refresh() + assert transform_job.transform_job_status in ("Completed", "Failed", "Stopped"), ( + f"Unexpected status: {transform_job.transform_job_status}" + ) + + elapsed = time.time() - start + print(f"Transform job finished in {elapsed:.0f}s") + print(f"Status: {transform_job.transform_job_status}") + + # Clean up model + try: + model.delete() + print(f"Deleted model: {model_name}") + except Exception: + pass + + +# ── Test 8: Endpoint wait_for_status and wait_for_delete ──────────────────── + +def test_endpoint_wait_for_status_and_delete(sm_session, role, region): + """Test Endpoint.wait_for_status() and wait_for_delete() with session. + + Creates a model + endpoint config + endpoint, waits for InService, + then deletes and waits for deletion. Tests two different wait patterns + that both use refresh() internally. + """ + from sagemaker.core.resources import Model, EndpointConfig, Endpoint + from sagemaker.core.shapes import ( + ContainerDefinition, + ProductionVariant, + ) + + bucket = sm_session.default_bucket() + ts = int(time.time()) + + # Create a dummy model artifact + model_dir = tempfile.mkdtemp() + model_tar = os.path.join(model_dir, "model.tar.gz") + import tarfile + with tarfile.open(model_tar, "w:gz") as tar: + dummy_path = os.path.join(model_dir, "dummy.txt") + with open(dummy_path, "w") as f: + f.write("dummy model") + tar.add(dummy_path, arcname="dummy.txt") + + model_s3_uri = sm_session.upload_data( + path=model_tar, + bucket=bucket, + key_prefix=f"integ-endpoint-{ts}/model", + ) + + # Use a lightweight inference image + sklearn_image = image_uris.retrieve( + framework="sklearn", + region=region, + version="1.2-1", + instance_type="ml.m5.large", + image_scope="inference", + ) + + # Create model + model_name = f"integ-ep-model-{ts}" + model = Model.create( + model_name=model_name, + primary_container=ContainerDefinition( + image=sklearn_image, + model_data_url=model_s3_uri, + ), + execution_role_arn=role, + session=sm_session.boto_session, + ) + print(f"\nCreated model: {model_name}") + + # Create endpoint config + ep_config_name = f"integ-ep-config-{ts}" + ep_config = EndpointConfig.create( + endpoint_config_name=ep_config_name, + production_variants=[ + ProductionVariant( + variant_name="AllTraffic", + model_name=model_name, + instance_type="ml.m5.large", + initial_instance_count=1, + ), + ], + session=sm_session.boto_session, + ) + print(f"Created endpoint config: {ep_config_name}") + + # Create endpoint + ep_name = f"integ-ep-{ts}" + endpoint = Endpoint.create( + endpoint_name=ep_name, + endpoint_config_name=ep_config_name, + session=sm_session.boto_session, + ) + + assert endpoint is not None + assert hasattr(endpoint, "_session") + assert endpoint._session is sm_session.boto_session + print(f"Created endpoint: {ep_name}") + + # Wait for InService + print("Waiting for endpoint to reach InService...") + start = time.time() + endpoint.wait_for_status(target_status="InService") + elapsed = time.time() - start + print(f"Endpoint InService in {elapsed:.0f}s") + assert endpoint.endpoint_status == "InService" + + # Delete endpoint and wait for deletion + print("Deleting endpoint...") + endpoint.delete() + start = time.time() + endpoint.wait_for_delete() + elapsed = time.time() - start + print(f"Endpoint deleted in {elapsed:.0f}s") + + # Clean up endpoint config and model + try: + ep_config.delete() + print(f"Deleted endpoint config: {ep_config_name}") + except Exception: + pass + try: + model.delete() + print(f"Deleted model: {model_name}") + except Exception: + pass + + +# ── Test 9: CompilationJob wait ───────────────────────────────────────────── + +def test_compilation_job_wait(sm_session, role, region): + """Test CompilationJob.create() + wait() with session. + + Compiles a model for a target device. Tests CompilationJob.refresh() + session propagation. + """ + from sagemaker.core.resources import CompilationJob + from sagemaker.core.shapes import ( + InputConfig, + OutputConfig, + StoppingCondition, + ) + + bucket = sm_session.default_bucket() + ts = int(time.time()) + + # Create a dummy model artifact (Neo expects a tar.gz with model files) + model_dir = tempfile.mkdtemp() + model_tar = os.path.join(model_dir, "model.tar.gz") + import tarfile + with tarfile.open(model_tar, "w:gz") as tar: + dummy_path = os.path.join(model_dir, "model.pth") + with open(dummy_path, "w") as f: + f.write("dummy pytorch model") + tar.add(dummy_path, arcname="model.pth") + + model_s3_uri = sm_session.upload_data( + path=model_tar, + bucket=bucket, + key_prefix=f"integ-compile-{ts}/model", + ) + + job_name = f"integ-compile-{ts}" + + compilation_job = CompilationJob.create( + compilation_job_name=job_name, + role_arn=role, + input_config=InputConfig( + s3_uri=model_s3_uri, + data_input_config='{"input0": [1, 3, 224, 224]}', + framework="PYTORCH", + ), + output_config=OutputConfig( + s3_output_location=f"s3://{bucket}/integ-compile-{ts}/output/", + target_device="ml_m5", + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=900), + session=sm_session.boto_session, + ) + + assert compilation_job is not None + assert hasattr(compilation_job, "_session") + assert compilation_job._session is sm_session.boto_session + print(f"\nCreated compilation job: {job_name}") + + start = time.time() + try: + compilation_job.wait() + except Exception as e: + # Compilation may fail because the model isn't a real PyTorch model, + # but the wait mechanism should work (session propagation) + print(f"Compilation job ended with: {type(e).__name__}: {e}") + compilation_job.refresh() + assert compilation_job.compilation_job_status in ("COMPLETED", "FAILED", "STOPPED"), ( + f"Unexpected status: {compilation_job.compilation_job_status}" + ) + + elapsed = time.time() - start + print(f"Compilation job finished in {elapsed:.0f}s") + print(f"Status: {compilation_job.compilation_job_status}") diff --git a/sagemaker-core/tests/unit/generated/test_session_propagation.py b/sagemaker-core/tests/unit/generated/test_session_propagation.py new file mode 100644 index 0000000000..029ec912f1 --- /dev/null +++ b/sagemaker-core/tests/unit/generated/test_session_propagation.py @@ -0,0 +1,335 @@ +"""Tests for session propagation fix (GitHub issue #5765). + +Verifies that when a user passes a custom session to create() or get(), +the session is stored on the resource instance and used by all instance +methods (refresh, update, delete, stop, wait). + +Before this fix, instance methods called Base.get_sagemaker_client() with +no session argument, falling back to ambient/default credentials. This +caused NoCredentialsError when the user's session used different credentials +(e.g., assumed-role via STS). +""" + +import unittest +from unittest.mock import patch, MagicMock, call + +from boto3.session import Session as BotoSession + +from sagemaker.core.resources import ( + Base, + TrainingJob, + ProcessingJob, + Endpoint, + TransformJob, + Model, +) + + +class TestSessionStoredOnGet(unittest.TestCase): + """Test that get() stores the session on the resource instance.""" + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_get_stores_session_on_training_job(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.describe_training_job.return_value = { + "TrainingJobName": "test-job", + "TrainingJobStatus": "Completed", + } + mock_transform.return_value = { + "training_job_name": "test-job", + "training_job_status": "Completed", + } + + job = TrainingJob.get(training_job_name="test-job", session=mock_session) + + assert hasattr(job, "_session") + assert job._session is mock_session + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_get_stores_session_on_processing_job(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.describe_processing_job.return_value = { + "ProcessingJobName": "test-proc", + "ProcessingJobStatus": "Completed", + } + mock_transform.return_value = { + "processing_job_name": "test-proc", + "processing_job_status": "Completed", + } + + job = ProcessingJob.get(processing_job_name="test-proc", session=mock_session) + + assert hasattr(job, "_session") + assert job._session is mock_session + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_get_stores_session_on_endpoint(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.describe_endpoint.return_value = { + "EndpointName": "test-ep", + "EndpointStatus": "InService", + } + mock_transform.return_value = { + "endpoint_name": "test-ep", + "endpoint_status": "InService", + } + + ep = Endpoint.get(endpoint_name="test-ep", session=mock_session) + + assert hasattr(ep, "_session") + assert ep._session is mock_session + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_get_stores_none_when_no_session_passed(self, mock_get_client, mock_transform): + """Backward compatibility: _session is None when no session is passed.""" + client = MagicMock() + mock_get_client.return_value = client + client.describe_training_job.return_value = { + "TrainingJobName": "test-job", + "TrainingJobStatus": "Completed", + } + mock_transform.return_value = { + "training_job_name": "test-job", + "training_job_status": "Completed", + } + + job = TrainingJob.get(training_job_name="test-job") + + assert hasattr(job, "_session") + assert job._session is None + + +class TestSessionUsedByRefresh(unittest.TestCase): + """Test that refresh() uses the stored session.""" + + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_refresh_uses_stored_session(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.describe_training_job.return_value = { + "TrainingJobName": "test-job", + "TrainingJobStatus": "Completed", + } + + job = TrainingJob(training_job_name="test-job") + job._session = mock_session + + job.refresh() + + # Verify get_sagemaker_client was called with the stored session + mock_get_client.assert_called_with(session=mock_session) + + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_refresh_uses_none_session_when_not_set(self, mock_get_client, mock_transform): + """Backward compatibility: refresh works when _session is not set.""" + client = MagicMock() + mock_get_client.return_value = client + client.describe_processing_job.return_value = { + "ProcessingJobName": "test-proc", + "ProcessingJobStatus": "Completed", + } + + job = ProcessingJob(processing_job_name="test-proc") + # Don't set _session — getattr should return None + + job.refresh() + + mock_get_client.assert_called_with(session=None) + + +class TestSessionUsedByDelete(unittest.TestCase): + """Test that delete() uses the stored session.""" + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_delete_uses_stored_session(self, mock_get_client): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.delete_model.return_value = {} + + model = Model(model_name="test-model") + model._session = mock_session + + model.delete() + + mock_get_client.assert_called_with(session=mock_session) + + +class TestSessionUsedByStop(unittest.TestCase): + """Test that stop() uses the stored session instead of SageMakerClient().""" + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_stop_uses_stored_session(self, mock_get_client): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.stop_training_job.return_value = {} + + job = TrainingJob(training_job_name="test-job") + job._session = mock_session + + job.stop() + + mock_get_client.assert_called_with(session=mock_session) + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_stop_uses_none_session_when_not_set(self, mock_get_client): + """Backward compatibility: stop works when _session is not set.""" + client = MagicMock() + mock_get_client.return_value = client + client.stop_training_job.return_value = {} + + job = TrainingJob(training_job_name="test-job") + + job.stop() + + mock_get_client.assert_called_with(session=None) + + +class TestSessionUsedByUpdate(unittest.TestCase): + """Test that update() uses the stored session.""" + + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_update_uses_stored_session(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + client.update_training_job.return_value = {} + client.describe_training_job.return_value = { + "TrainingJobName": "test-job", + "TrainingJobStatus": "Completed", + } + + job = TrainingJob(training_job_name="test-job") + job._session = mock_session + + job.update() + + # update() calls get_sagemaker_client, then refresh() also calls it + # Both should use the stored session + for c in mock_get_client.call_args_list: + assert c == call(session=mock_session) + + +class TestSessionFlowsThroughCreate(unittest.TestCase): + """Test that create() -> get() stores session on the returned instance.""" + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_create_stores_session_via_get(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + + # create() calls the API then calls get() + client.create_model.return_value = {"ModelArn": "arn:aws:sagemaker:us-west-2:123:model/m"} + client.describe_model.return_value = { + "ModelName": "test-model", + "ModelArn": "arn:aws:sagemaker:us-west-2:123:model/m", + } + mock_transform.return_value = { + "model_name": "test-model", + "model_arn": "arn:aws:sagemaker:us-west-2:123:model/m", + } + + model = Model.create( + model_name="test-model", + session=mock_session, + ) + + # The session should be stored on the instance returned by get() + assert hasattr(model, "_session") + assert model._session is mock_session + + +class TestSessionPropagationEndToEnd(unittest.TestCase): + """End-to-end test: create with session, then refresh uses that session.""" + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_create_then_refresh_uses_same_session(self, mock_get_client, mock_transform): + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + + client.create_model.return_value = {"ModelArn": "arn:aws:sagemaker:us-west-2:123:model/m"} + client.describe_model.return_value = { + "ModelName": "test-model", + "ModelArn": "arn:aws:sagemaker:us-west-2:123:model/m", + } + mock_transform.return_value = { + "model_name": "test-model", + "model_arn": "arn:aws:sagemaker:us-west-2:123:model/m", + } + + # Step 1: Create with session + model = Model.create(model_name="test-model", session=mock_session) + + # Reset mock to track refresh calls + mock_get_client.reset_mock() + + # Step 2: Refresh should use the same session + model.refresh() + + mock_get_client.assert_called_with(session=mock_session) + + +class TestAllResourceTypesHaveSession(unittest.TestCase): + """Verify that all resource types with get() store _session.""" + + @patch("sagemaker.core.resources.validate_call", lambda **kwargs: lambda func: func) + @patch("sagemaker.core.resources.transform") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_multiple_resource_types_store_session(self, mock_get_client, mock_transform): + """Spot-check several resource types to verify _session is stored.""" + mock_session = MagicMock(spec=BotoSession) + client = MagicMock() + mock_get_client.return_value = client + + test_cases = [ + (TrainingJob, "describe_training_job", "TrainingJobName", "training_job_name"), + (ProcessingJob, "describe_processing_job", "ProcessingJobName", "processing_job_name"), + (TransformJob, "describe_transform_job", "TransformJobName", "transform_job_name"), + ] + + for resource_cls, describe_method, api_key, attr_key in test_cases: + with self.subTest(resource=resource_cls.__name__): + getattr(client, describe_method).return_value = { + api_key: "test-name", + } + mock_transform.return_value = { + attr_key: "test-name", + } + + instance = resource_cls.get(**{attr_key: "test-name"}, session=mock_session) + + assert hasattr(instance, "_session"), ( + f"{resource_cls.__name__} missing _session attribute" + ) + assert instance._session is mock_session, ( + f"{resource_cls.__name__}._session is not the passed session" + ) + + +if __name__ == "__main__": + unittest.main() From 9f439c9d23bbeeb2827301ef23fbbc1a66b94295 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 22 Apr 2026 12:13:34 -0700 Subject: [PATCH 2/2] fix: restore IntPipeVar types in shapes.py that were lost during regeneration --- sagemaker-core/src/sagemaker/core/shapes/shapes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index 08fb48fb84..7c598fdf8e 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -13,10 +13,10 @@ import datetime import warnings -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from typing import List, Dict, Optional, Any, Union from sagemaker.core.utils.utils import Unassigned -from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.core.helper.pipeline_variable import StrPipeVar, IntPipeVar, BoolPipeVar # Suppress Pydantic warnings about field names shadowing parent attributes warnings.filterwarnings("ignore", message=".*shadows an attribute.*") @@ -1324,10 +1324,10 @@ class ResourceConfig(Base): """ instance_type: Optional[StrPipeVar] = Unassigned() - instance_count: Optional[int] = Unassigned() - volume_size_in_gb: Optional[int] = Unassigned() + instance_count: Optional[IntPipeVar] = Unassigned() + volume_size_in_gb: Optional[IntPipeVar] = Unassigned() volume_kms_key_id: Optional[StrPipeVar] = Unassigned() - keep_alive_period_in_seconds: Optional[int] = Unassigned() + keep_alive_period_in_seconds: Optional[IntPipeVar] = Unassigned() capacity_reservation_ids: Optional[List[StrPipeVar]] = Unassigned() instance_groups: Optional[List[InstanceGroup]] = Unassigned() capacity_schedules_config: Optional[CapacitySchedulesConfig] = Unassigned()