Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions sagemaker-core/src/sagemaker/core/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,16 @@ def __init__(
code_location[:-1] if (code_location and code_location.endswith("/")) else code_location
)

def _s3_code_prefix(self):
"""Return the S3 prefix for code uploads, respecting code_location if set."""
if self.code_location:
return self.code_location
return s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
)

def _package_code(
self,
entry_point,
Expand Down Expand Up @@ -1155,9 +1165,7 @@ def _package_code(

# Upload to S3
s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
self._s3_code_prefix(),
job_name,
"source",
"sourcedir.tar.gz",
Expand Down Expand Up @@ -1320,9 +1328,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
runproc_file_str = self._generate_framework_script(user_script)
runproc_file_hash = hash_object(runproc_file_str)
s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix,
self._s3_code_prefix(),
_pipeline_config.pipeline_name,
"code",
runproc_file_hash,
Expand Down
76 changes: 76 additions & 0 deletions sagemaker-core/tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,82 @@ def test_package_code_source_dir_not_exists(self, mock_session):
)


def test_package_code_with_code_location(self, mock_session):
processor = FrameworkProcessor(
role="arn:aws:iam::123456789012:role/SageMakerRole",
image_uri="test-image:latest",
instance_count=1,
instance_type="ml.m5.xlarge",
sagemaker_session=mock_session,
code_location="s3://my-custom-bucket/my-prefix",
)

with tempfile.TemporaryDirectory() as tmpdir:
entry_point = os.path.join(tmpdir, "train.py")
with open(entry_point, "w") as f:
f.write("print('training')")

result = processor._package_code(
entry_point=entry_point,
source_dir=tmpdir,
requirements=None,
job_name="test-job",
kms_key=None,
)
assert result.startswith("s3://my-custom-bucket/my-prefix")
assert "sourcedir.tar.gz" in result

def test_package_code_without_code_location_uses_default_bucket(self, mock_session):
processor = FrameworkProcessor(
role="arn:aws:iam::123456789012:role/SageMakerRole",
image_uri="test-image:latest",
instance_count=1,
instance_type="ml.m5.xlarge",
sagemaker_session=mock_session,
)

with tempfile.TemporaryDirectory() as tmpdir:
entry_point = os.path.join(tmpdir, "train.py")
with open(entry_point, "w") as f:
f.write("print('training')")

result = processor._package_code(
entry_point=entry_point,
source_dir=tmpdir,
requirements=None,
job_name="test-job",
kms_key=None,
)
assert result.startswith("s3://test-bucket/sagemaker")
assert "sourcedir.tar.gz" in result

def test_package_code_with_code_location_trailing_slash(self, mock_session):
processor = FrameworkProcessor(
role="arn:aws:iam::123456789012:role/SageMakerRole",
image_uri="test-image:latest",
instance_count=1,
instance_type="ml.m5.xlarge",
sagemaker_session=mock_session,
code_location="s3://my-custom-bucket/my-prefix/",
)

with tempfile.TemporaryDirectory() as tmpdir:
entry_point = os.path.join(tmpdir, "train.py")
with open(entry_point, "w") as f:
f.write("print('training')")

result = processor._package_code(
entry_point=entry_point,
source_dir=tmpdir,
requirements=None,
job_name="test-job",
kms_key=None,
)
# Trailing slash is stripped in __init__, so same result
assert result.startswith("s3://my-custom-bucket/my-prefix")
assert "sourcedir.tar.gz" in result


class TestFrameworkProcessorRun:
def test_run_with_s3_code(self, mock_session):
processor = FrameworkProcessor(
Expand Down
Loading