Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 180 additions & 7 deletions sagemaker-train/src/sagemaker/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,7 @@ def _build_driver_and_code_channels(cls, model_trainer):
fpath = os.path.join(root, f)
arcname = os.path.relpath(fpath, source_code.source_dir)
tar.add(fpath, arcname=arcname)
s3_client = session.boto_session.client(
"s3", region_name=session.boto_region_name
)
s3_client = session.boto_session.client("s3", region_name=session.boto_region_name)
s3_client.upload_file(tar_path, bucket, s3_key)
model_trainer.hyperparameters["sagemaker_submit_directory"] = (
f"s3://{bucket}/{s3_key}"
Expand Down Expand Up @@ -1260,7 +1258,19 @@ def _add_model_trainer(
def _start_tuning_job(self, inputs):
"""Start a new hyperparameter tuning job using HyperParameterTuningJob."""
tuning_job_config = self._build_tuning_job_config()
training_job_definition = self._build_training_job_definition(inputs)

is_multi_algo = (
self.model_trainer is None
and self.model_trainer_dict is not None
and len(self.model_trainer_dict) > 0
)

if is_multi_algo:
training_job_definition = None
training_job_definitions = self._build_training_job_definitions(inputs)
else:
training_job_definition = self._build_training_job_definition(inputs)
training_job_definitions = None

# Prepare autotune parameter
autotune_param = None
Expand All @@ -1281,6 +1291,7 @@ def _start_tuning_job(self, inputs):
"hyper_parameter_tuning_job_name": self._current_job_name,
"hyper_parameter_tuning_job_config": tuning_job_config,
"training_job_definition": training_job_definition,
"training_job_definitions": training_job_definitions,
"warm_start_config": self.warm_start_config,
"tags": tag_objects,
"autotune": autotune_param,
Expand Down Expand Up @@ -1474,9 +1485,7 @@ def _build_training_job_definition(self, inputs):

# Pass through the full OutputDataConfig from ModelTrainer so that
# kms_key_id, compression_type, and any other fields are preserved.
output_config = model_trainer.output_data_config or OutputDataConfig(
s3_output_path=None
)
output_config = model_trainer.output_data_config or OutputDataConfig(s3_output_path=None)

# Build resource config
resource_config = ResourceConfig(
Expand Down Expand Up @@ -1535,3 +1544,167 @@ def _build_training_job_definition(self, inputs):
pass

return definition

def _build_training_job_definitions(self, inputs):
    """Build one ``HyperParameterTrainingJobDefinition`` per model trainer.

    Used for multi-algorithm tuning: each entry in ``self.model_trainer_dict``
    (keyed by definition name) is translated into a training job definition
    carrying its own algorithm spec, input channels, resources, stopping
    condition, hyperparameter ranges, and tuning objective.

    Args:
        inputs: Input data for the tuning job. May be a dict keyed by
            definition name (each value being a str S3 URI, a dict of
            channel-name -> S3 URI, or a list of ``InputData``/``Channel``
            objects), or a single such value shared by all trainers.

    Returns:
        list: ``HyperParameterTrainingJobDefinition`` objects, one per
        entry in ``self.model_trainer_dict``, in dict iteration order.
    """
    # Imported locally — presumably to avoid a circular import with
    # sagemaker.core; TODO confirm against the rest of the module.
    from sagemaker.core.shapes import (
        HyperParameterTrainingJobDefinition,
        HyperParameterAlgorithmSpecification,
        HyperParameterTuningJobObjective,
        OutputDataConfig,
        ResourceConfig,
        StoppingCondition,
        ParameterRanges,
        Channel,
        DataSource,
        S3DataSource,
    )

    # Per-definition hyperparameter ranges, keyed by definition name.
    # NOTE(review): called as a method here — confirm
    # ``hyperparameter_ranges_dict`` is callable and not a plain attribute.
    all_ranges = self.hyperparameter_ranges_dict() or {}
    definitions = []

    for name, model_trainer in self.model_trainer_dict.items():
        # Input mode defaults to "File" when the trainer does not specify one.
        algorithm_spec = HyperParameterAlgorithmSpecification(
            training_image=model_trainer.training_image,
            training_input_mode=model_trainer.training_input_mode or "File",
        )

        # Metric definitions are stored with CamelCase keys (e.g. "Name",
        # "Regex"); the shapes API expects snake_case, so convert each key
        # by prefixing uppercase letters with "_" and lowercasing.
        if self.metric_definitions_dict and name in self.metric_definitions_dict:
            metric_defs_snake = []
            for metric_def in self.metric_definitions_dict[name]:
                metric_def_snake = {}
                for key, value in metric_def.items():
                    snake_key = "".join(
                        ["_" + c.lower() if c.isupper() else c for c in key]
                    ).lstrip("_")
                    metric_def_snake[snake_key] = value
                metric_defs_snake.append(metric_def_snake)
            algorithm_spec.metric_definitions = metric_defs_snake

        # Resolve this trainer's share of `inputs`: a dict of inputs is
        # assumed to be keyed by definition name; any other shape is
        # treated as shared across all trainers.
        input_data_config = []
        mt_inputs = inputs.get(name, {}) if isinstance(inputs, dict) else inputs
        if isinstance(mt_inputs, str):
            # A bare S3 URI becomes a single "training" channel.
            input_data_config = [
                Channel(
                    channel_name="training",
                    data_source=DataSource(
                        s3_data_source=S3DataSource(
                            s3_data_type="S3Prefix",
                            s3_uri=mt_inputs,
                            s3_data_distribution_type="FullyReplicated",
                        )
                    ),
                )
            ]
        elif isinstance(mt_inputs, dict):
            # Dict form: one channel per (channel_name -> S3 URI) pair.
            for channel_name, s3_uri in mt_inputs.items():
                input_data_config.append(
                    Channel(
                        channel_name=channel_name,
                        data_source=DataSource(
                            s3_data_source=S3DataSource(
                                s3_data_type="S3Prefix",
                                s3_uri=s3_uri,
                                s3_data_distribution_type="FullyReplicated",
                            )
                        ),
                    )
                )
        elif isinstance(mt_inputs, list):
            # List form: InputData entries are converted to Channels;
            # pre-built Channel objects pass through unchanged. Other
            # element types are silently ignored.
            for inp in mt_inputs:
                if isinstance(inp, InputData):
                    input_data_config.append(
                        Channel(
                            channel_name=inp.channel_name,
                            data_source=DataSource(
                                s3_data_source=S3DataSource(
                                    s3_data_type="S3Prefix",
                                    s3_uri=inp.data_source,
                                    s3_data_distribution_type="FullyReplicated",
                                )
                            ),
                        )
                    )
                elif isinstance(inp, Channel):
                    input_data_config.append(inp)

        # Merge in channels configured directly on the trainer, skipping
        # any channel name already supplied via `inputs` (explicit inputs
        # take precedence).
        if hasattr(model_trainer, "input_data_config") and model_trainer.input_data_config:
            for channel in model_trainer.input_data_config:
                if not any(c.channel_name == channel.channel_name for c in input_data_config):
                    input_data_config.append(channel)

        # Same precedence rule for tuner-injected channels (private
        # attribute set elsewhere in the tuner flow — TODO confirm where
        # ``_tuner_channels`` is populated).
        if hasattr(model_trainer, "_tuner_channels") and model_trainer._tuner_channels:
            for channel in model_trainer._tuner_channels:
                if not any(c.channel_name == channel.channel_name for c in input_data_config):
                    input_data_config.append(channel)

        # Pass through the trainer's full OutputDataConfig so fields such
        # as kms_key_id/compression_type are preserved; fall back to an
        # empty config when none was set.
        output_config = model_trainer.output_data_config or OutputDataConfig(
            s3_output_path=None
        )

        # Hard-coded fallbacks (ml.m5.xlarge / 1 instance / 30 GB) apply
        # only when the trainer has no compute config at all.
        resource_config = ResourceConfig(
            instance_type=(
                model_trainer.compute.instance_type if model_trainer.compute else "ml.m5.xlarge"
            ),
            instance_count=model_trainer.compute.instance_count if model_trainer.compute else 1,
            volume_size_in_gb=(
                model_trainer.compute.volume_size_in_gb if model_trainer.compute else 30
            ),
        )

        # Only max_runtime_in_seconds is carried over; other stopping
        # fields on the trainer (if any) are not propagated here.
        stopping_condition = StoppingCondition()
        if (
            model_trainer.stopping_condition
            and model_trainer.stopping_condition.max_runtime_in_seconds
        ):
            stopping_condition.max_runtime_in_seconds = (
                model_trainer.stopping_condition.max_runtime_in_seconds
            )

        # Ranges are looked up per definition name; the stored dict uses
        # CamelCase range-type keys (API wire names), mapped here to the
        # snake_case ParameterRanges fields.
        ranges_dict = all_ranges.get(name, {})
        parameter_ranges = ParameterRanges(
            integer_parameter_ranges=ranges_dict.get("IntegerParameterRanges", []),
            continuous_parameter_ranges=ranges_dict.get("ContinuousParameterRanges", []),
            categorical_parameter_ranges=ranges_dict.get("CategoricalParameterRanges", []),
        )

        # Each definition gets its own objective metric; a missing entry
        # in objective_metric_name_dict raises KeyError (intentional —
        # every algorithm must declare its metric).
        tuning_objective = HyperParameterTuningJobObjective(
            type=self.objective_type,
            metric_name=self.objective_metric_name_dict[name],
        )

        static_hps = (
            self.static_hyperparameters_dict.get(name, {})
            if self.static_hyperparameters_dict
            else {}
        )

        # An empty channel list is normalized to None so the field is
        # omitted from the request rather than sent as [].
        definition = HyperParameterTrainingJobDefinition(
            definition_name=name,
            algorithm_specification=algorithm_spec,
            role_arn=model_trainer.role,
            input_data_config=input_data_config if input_data_config else None,
            output_data_config=output_config,
            resource_config=resource_config,
            stopping_condition=stopping_condition,
            static_hyper_parameters=static_hps,
            hyper_parameter_ranges=parameter_ranges,
            tuning_objective=tuning_objective,
        )

        # Optional fields: only attach environment when it is a real dict.
        env = getattr(model_trainer, "environment", None)
        if env and isinstance(env, dict):
            definition.environment = env

        # Best-effort VPC config: any failure in the trainer's private
        # converter is deliberately swallowed so networking problems do
        # not block building the definition.
        networking = getattr(model_trainer, "networking", None)
        if networking and hasattr(networking, "_to_vpc_config"):
            try:
                vpc_config = networking._to_vpc_config()
                if vpc_config:
                    definition.vpc_config = vpc_config
            except Exception:
                pass

        definitions.append(definition)

    return definitions
Loading
Loading