Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,729 changes: 1,252 additions & 477 deletions sagemaker-core/src/sagemaker/core/config_schema.py

Large diffs are not rendered by default.

1,532 changes: 966 additions & 566 deletions sagemaker-core/src/sagemaker/core/resources.py

Large diffs are not rendered by default.

162 changes: 158 additions & 4 deletions sagemaker-core/src/sagemaker/core/shapes/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ class ModelPackageContainerDefinition(Base):
"""

container_hostname: Optional[StrPipeVar] = Unassigned()
image: Optional[StrPipeVar] = Unassigned() # Revert back to autogen version
image: Optional[StrPipeVar] = Unassigned()
image_digest: Optional[StrPipeVar] = Unassigned()
model_data_url: Optional[StrPipeVar] = Unassigned()
model_data_source: Optional[ModelDataSource] = Unassigned()
Expand Down Expand Up @@ -7542,6 +7542,52 @@ class MetricsConfig(Base):
metric_publish_frequency_in_seconds: Optional[int] = Unassigned()


class CreateEndpointConfigInput(Base):
    """
    CreateEndpointConfigInput

    Input shape for CreateEndpointConfig. Only endpoint_config_name and
    production_variants are declared without a default; every other field
    defaults to Unassigned().

    Attributes
    ----------------------
    endpoint_config_name: The name of the endpoint configuration. You specify this name in a CreateEndpoint request.
    production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint.
    data_capture_config
    tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources.
    kms_key_id: The Amazon Resource Name (ARN) of an Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. The KmsKeyId can be any of the following formats: Key ID: 1234abcd-12ab-34cd-56ef-1234567890ab Key ARN: arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab Alias name: alias/ExampleAlias Alias name ARN: arn:aws:kms:us-west-2:111122223333:alias/ExampleAlias The KMS key policy must grant permission to the IAM role that you specify in your CreateEndpoint, UpdateEndpoint requests. For more information, refer to the Amazon Web Services Key Management Service section Using Key Policies in Amazon Web Services KMS Certain Nitro-based instances include local storage, dependent on the instance type. Local storage volumes are encrypted using a hardware module on the instance. You can't request a KmsKeyId when using an instance type with local storage. If any of the models that you specify in the ProductionVariants parameter use nitro-based instances with local storage, do not specify a value for the KmsKeyId parameter. If you specify a value for KmsKeyId when using any nitro-based instances with local storage, the call to CreateEndpointConfig fails. For a list of instance types that support local instance storage, see Instance Store Volumes. For more information about local instance storage encryption, see SSD Instance Store Volumes.
    async_inference_config: Specifies configuration for how an endpoint performs asynchronous inference. This is a required field in order for your Endpoint to be invoked using InvokeEndpointAsync.
    explainer_config: A member of CreateEndpointConfig that enables explainers.
    shadow_production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. If you use this field, you can only specify one variant for ProductionVariants and one variant for ShadowProductionVariants.
    execution_role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker AI can assume to perform actions on your behalf. For more information, see SageMaker AI Roles. To be able to pass this role to Amazon SageMaker AI, the caller of this action must have the iam:PassRole permission.
    vpc_config
    enable_network_isolation: Sets whether all model containers deployed to the endpoint are isolated. If they are, no inbound or outbound network calls can be made to or from the model containers.
    metrics_config: The Configuration parameters for Utilization metrics.
    """

    # Required fields: declared without a default value.
    # NOTE(review): Union[StrPipeVar, object] admits any object at the
    # annotation level -- confirm this looseness is intended.
    endpoint_config_name: Union[StrPipeVar, object]
    production_variants: List[ProductionVariant]
    # Optional fields: each one defaults to Unassigned().
    data_capture_config: Optional[DataCaptureConfig] = Unassigned()
    tags: Optional[List[Tag]] = Unassigned()
    kms_key_id: Optional[StrPipeVar] = Unassigned()
    async_inference_config: Optional[AsyncInferenceConfig] = Unassigned()
    explainer_config: Optional[ExplainerConfig] = Unassigned()
    shadow_production_variants: Optional[List[ProductionVariant]] = Unassigned()
    execution_role_arn: Optional[StrPipeVar] = Unassigned()
    vpc_config: Optional[VpcConfig] = Unassigned()
    enable_network_isolation: Optional[bool] = Unassigned()
    metrics_config: Optional[MetricsConfig] = Unassigned()


class CreateEndpointConfigOutput(Base):
    """
    CreateEndpointConfigOutput

    Shape with a single required field holding an endpoint configuration ARN.

    Attributes
    ----------------------
    endpoint_config_arn: The Amazon Resource Name (ARN) of the endpoint configuration.
    """

    # Required: declared without a default value.
    endpoint_config_arn: StrPipeVar


class EndpointDeletionCondition(Base):
"""
EndpointDeletionCondition
Expand Down Expand Up @@ -7592,6 +7638,40 @@ class DeploymentConfig(Base):
auto_rollback_configuration: Optional[AutoRollbackConfig] = Unassigned()


class CreateEndpointInput(Base):
    """
    CreateEndpointInput

    Input shape for CreateEndpoint. endpoint_name and endpoint_config_name
    are declared without a default; the remaining fields default to
    Unassigned().

    Attributes
    ----------------------
    endpoint_name: The name of the endpoint. The name must be unique within an Amazon Web Services Region in your Amazon Web Services account. The name is case-insensitive in CreateEndpoint, but the case is preserved and must be matched in InvokeEndpoint.
    endpoint_config_name: The name of an endpoint configuration. For more information, see CreateEndpointConfig.
    graph_config_name
    deletion_condition
    deployment_config
    tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources.
    """

    # Required fields: declared without a default value.
    # NOTE(review): Union[StrPipeVar, object] admits any object at the
    # annotation level -- confirm this looseness is intended.
    endpoint_name: Union[StrPipeVar, object]
    endpoint_config_name: Union[StrPipeVar, object]
    # Optional fields: each one defaults to Unassigned().
    graph_config_name: Optional[StrPipeVar] = Unassigned()
    deletion_condition: Optional[EndpointDeletionCondition] = Unassigned()
    deployment_config: Optional[DeploymentConfig] = Unassigned()
    tags: Optional[List[Tag]] = Unassigned()


class CreateEndpointOutput(Base):
    """
    CreateEndpointOutput

    Shape with a single required field holding an endpoint ARN.

    Attributes
    ----------------------
    endpoint_arn: The Amazon Resource Name (ARN) of the endpoint.
    """

    # Required: declared without a default value.
    endpoint_arn: StrPipeVar


class EvaluationJobModel(Base):
"""
EvaluationJobModel
Expand Down Expand Up @@ -8577,9 +8657,9 @@ class InferenceComponentComputeResourceRequirements(Base):
max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component.
"""

min_memory_required_in_mb: Optional[int] = Unassigned()
number_of_cpu_cores_required: Optional[float] = Unassigned()
number_of_accelerator_devices_required: Optional[float] = Unassigned()
min_memory_required_in_mb: Optional[int] = Unassigned()
max_memory_required_in_mb: Optional[int] = Unassigned()


Expand Down Expand Up @@ -9492,6 +9572,44 @@ class InferenceExecutionConfig(Base):
mode: StrPipeVar


class CreateModelInput(Base):
    """
    CreateModelInput

    Input shape for model creation. Only model_name is declared without a
    default; every other field defaults to Unassigned().

    Attributes
    ----------------------
    model_name: The name of the new model.
    primary_container: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions.
    containers: Specifies the containers in the inference pipeline.
    inference_execution_config: Specifies details of how containers in a multi-container endpoint are called.
    execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs. Deploying on ML compute instances is part of model hosting. For more information, see SageMaker Roles. To be able to pass this role to SageMaker, the caller of this API must have the iam:PassRole permission.
    tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources.
    vpc_config: A VpcConfig object that specifies the VPC that you want your model to connect to. Control access to and from your model container by configuring the VPC. VpcConfig is used in hosting services and in batch transform. For more information, see Protect Endpoints by Using an Amazon Virtual Private Cloud and Protect Data in Batch Transform Jobs by Using an Amazon Virtual Private Cloud.
    enable_network_isolation: Isolates the model container. No inbound or outbound network calls can be made to or from the model container.
    """

    # Required field: declared without a default value.
    # NOTE(review): Union[StrPipeVar, object] admits any object at the
    # annotation level -- confirm this looseness is intended.
    model_name: Union[StrPipeVar, object]
    # Optional fields: each one defaults to Unassigned().
    primary_container: Optional[ContainerDefinition] = Unassigned()
    containers: Optional[List[ContainerDefinition]] = Unassigned()
    inference_execution_config: Optional[InferenceExecutionConfig] = Unassigned()
    execution_role_arn: Optional[StrPipeVar] = Unassigned()
    tags: Optional[List[Tag]] = Unassigned()
    vpc_config: Optional[VpcConfig] = Unassigned()
    enable_network_isolation: Optional[bool] = Unassigned()


class CreateModelOutput(Base):
    """
    CreateModelOutput

    Shape with a single required field holding a model ARN.

    Attributes
    ----------------------
    model_arn: The ARN of the model created in SageMaker.
    """

    # Required: declared without a default value.
    model_arn: StrPipeVar


class ModelPackageValidationProfile(Base):
"""
ModelPackageValidationProfile
Expand Down Expand Up @@ -9770,7 +9888,7 @@ class ModelPackageSecurityConfig(Base):
kms_key_id: The KMS Key ID (KMSKeyId) used for encryption of model package information.
"""

kms_key_id: Optional[str] = Unassigned()
kms_key_id: StrPipeVar


class ModelPackageModelCard(Base):
Expand Down Expand Up @@ -10537,6 +10655,18 @@ class ExperimentConfig(Base):
run_name: Optional[StrPipeVar] = Unassigned()


class CreateProcessingJobResponse(Base):
    """
    CreateProcessingJobResponse

    Shape with a single required field holding a processing job ARN.

    Attributes
    ----------------------
    processing_job_arn: The Amazon Resource Name (ARN) of the processing job.
    """

    # Required: declared without a default value.
    processing_job_arn: StrPipeVar


class ProvisioningParameter(Base):
"""
ProvisioningParameter
Expand Down Expand Up @@ -11032,6 +11162,18 @@ class UpstreamPlatformConfig(Base):
execution_role: Optional[StrPipeVar] = Unassigned()


class CreateTrainingJobResponse(Base):
    """
    CreateTrainingJobResponse

    Shape with a single required field holding a training job ARN.

    Attributes
    ----------------------
    training_job_arn: The Amazon Resource Name (ARN) of the training job.
    """

    # Required: declared without a default value.
    training_job_arn: StrPipeVar


class DebugHookConfig(Base):
"""
DebugHookConfig
Expand Down Expand Up @@ -11189,7 +11331,7 @@ class ServerlessJobConfig(Base):
evaluator_arn
job_spec
"""

base_model_arn: StrPipeVar
job_type: StrPipeVar
accept_eula: Optional[bool] = Unassigned()
Expand Down Expand Up @@ -11264,6 +11406,18 @@ class DataProcessing(Base):
join_source: Optional[StrPipeVar] = Unassigned()


class CreateTransformJobResponse(Base):
    """
    CreateTransformJobResponse

    Shape with a single required field holding a transform job ARN.

    Attributes
    ----------------------
    transform_job_arn: The Amazon Resource Name (ARN) of the transform job.
    """

    # Required: declared without a default value.
    transform_job_arn: StrPipeVar


class InputTrialComponentSource(Base):
"""
InputTrialComponentSource
Expand Down
18 changes: 16 additions & 2 deletions sagemaker-core/src/sagemaker/core/tools/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Generates the code for the service model."""
from sagemaker.core.utils.utils import reformat_file_with_black
from sagemaker.core.tools.shapes_codegen import ShapesCodeGen
from sagemaker.core.tools.resources_codegen import ResourcesCodeGen
from typing import Optional

from sagemaker.core.tools.data_extractor import ServiceJsonData, load_service_jsons


# Generated files that should be reformatted after codegen.
# generate_code() passes each entry to reformat_file_with_black individually
# rather than reformatting the whole directory.
# NOTE(review): the paths are relative, so they presumably resolve against the
# process working directory -- confirm codegen is always run from the
# sagemaker-core package root.
_GENERATED_FILES = [
    "src/sagemaker/core/resources.py",
    "src/sagemaker/core/shapes/shapes.py",
    "src/sagemaker/core/config_schema.py",
]


def generate_code(
shapes_code_gen: Optional[ShapesCodeGen] = None,
resources_code_gen: Optional[ShapesCodeGen] = None,
Expand All @@ -37,6 +44,10 @@ def generate_code(
Returns:
None
"""
# Import lazily to avoid circular import through sagemaker.core.__init__
# which imports processing -> resources (the file we are generating)
from sagemaker.core.utils.utils import reformat_file_with_black

service_json_data: ServiceJsonData = load_service_jsons()

shapes_code_gen = shapes_code_gen or ShapesCodeGen()
Expand All @@ -45,7 +56,10 @@ def generate_code(
)

shapes_code_gen.generate_shapes()
reformat_file_with_black(".")

# Only reformat the generated files, not the entire directory
for generated_file in _GENERATED_FILES:
reformat_file_with_black(generated_file)


"""
Expand Down
38 changes: 20 additions & 18 deletions sagemaker-core/src/sagemaker/core/tools/resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def generate_imports(self) -> str:
"import functools",
"from pydantic import validate_call",
"from typing import Dict, List, Literal, Optional, Union, Any\n"
"from boto3.session import Session",
"from boto3.session import Session as Boto3Session",
"from rich.console import Group",
"from rich.live import Live",
"from rich.panel import Panel",
Expand Down Expand Up @@ -940,7 +940,7 @@ def generate_create_method(self, resource_name: str, **kwargs) -> str:
add_indent("cls,\n", 4)
+ create_args
+ "\n"
+ add_indent("session: Optional[Session] = None,\n", 4)
+ add_indent("session: Optional[Boto3Session] = None,\n", 4)
+ add_indent("region: Optional[str] = None,", 4)
)
formatted_method = GENERIC_METHOD_TEMPLATE.format(
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def generate_start_method(self, resource_name: str, **kwargs) -> str:
operation_metadata, False, resource_attributes
)
exclude_resource_attrs = resource_attributes
method_args += add_indent("session: Optional[Session] = None,\n", 4)
method_args += add_indent("session: Optional[Boto3Session] = None,\n", 4)
method_args += add_indent("region: Optional[str] = None,", 4)

serialize_operation_input = SERIALIZE_INPUT_TEMPLATE.format(
Expand Down Expand Up @@ -1448,7 +1448,7 @@ def generate_method(self, method: Method, resource_attributes: list):
operation_metadata, False, resource_attributes
)
exclude_resource_attrs = resource_attributes
method_args += add_indent("session: Optional[Session] = None,\n", 4)
method_args += add_indent("session: Optional[Boto3Session] = None,\n", 4)
method_args += add_indent("region: Optional[str] = None,", 4)

initialize_client = INITIALIZE_CLIENT_TEMPLATE.format(service_name=method.service_name)
Expand Down Expand Up @@ -1573,7 +1573,7 @@ def generate_additional_get_all_method(self, method: Method, resource_attributes
operation_metadata, False, resource_attributes, exclude_list
)
exclude_resource_attrs = resource_attributes
method_args += add_indent("session: Optional[Session] = None,\n", 4)
method_args += add_indent("session: Optional[Boto3Session] = None,\n", 4)
method_args += add_indent("region: Optional[str] = None,", 4)

if method.return_type == method.resource_name:
Expand Down Expand Up @@ -1664,19 +1664,21 @@ def _get_instance_count_ref(self, resource_name: str) -> str:
"""

if resource_name == "TrainingJob":
return """1 # Default
if not isinstance(self.resource_config, Unassigned):
if (
hasattr(self.resource_config, "instance_groups")
and self.resource_config.instance_groups
and not isinstance(self.resource_config.instance_groups, Unassigned)
):
instance_count = sum(
instance_group.instance_count
for instance_group in self.resource_config.instance_groups
)
elif hasattr(self.resource_config, "instance_count"):
instance_count = self.resource_config.instance_count"""
return (
"1 # Default\n"
"if not isinstance(self.resource_config, Unassigned):\n"
" if (\n"
' hasattr(self.resource_config, "instance_groups")\n'
" and self.resource_config.instance_groups\n"
" and not isinstance(self.resource_config.instance_groups, Unassigned)\n"
" ):\n"
" instance_count = sum(\n"
" instance_group.instance_count\n"
" for instance_group in self.resource_config.instance_groups\n"
" )\n"
' elif hasattr(self.resource_config, "instance_count"):\n'
" instance_count = self.resource_config.instance_count"
)
elif resource_name == "TransformJob":
return "self.transform_resources.instance_count"
elif resource_name == "ProcessingJob":
Expand Down
Loading
Loading