Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/ml/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from apache_beam.ml.inference.base import ModelT
from apache_beam.ml.inference.base import RunInferenceDLQ
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options_context import get_pipeline_options

_LOGGER = logging.getLogger(__name__)
_ATTRIBUTE_FILE_NAME = 'attributes.json'
Expand Down Expand Up @@ -591,7 +592,18 @@ def save_attributes(
def load_attributes(artifact_location):
with FileSystems.open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
'rb') as f:
return jsonpickle.decode(f.read())
# load_attributes runs eagerly during MLTransform.expand() at pipeline
# construction time, so the pipeline's options are available via the
# construction-time context.
pipeline_options = get_pipeline_options()
safe = True
if (pipeline_options is not None and
pipeline_options.is_compat_version_prior_to("2.75.0")):
# Keep the pre-2.75.0 jsonpickle behavior (safe=False permits
# eval-based decoding) for backwards compatibility with already-staged
# artifacts.
safe = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there actually a downside to always setting safe=False? IIUC, this is protecting against remote code execution exploits that allow escalating privileges. But if you are able to compromise this artifact file, I think you already functionally have code execution privileges since it allows you to define (semi) arbitrary transforms.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default in v 4+ is safe=True, so I think it is best if we align with the default behavior in the underlying library?

Is the benefit of always setting to False simplicity? I don't have a strong opinion here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think it is reasonable to just skip compat flag completely. I doubt it will break anyone, and if it does they can file a bug and we can add a compat flag in an upcoming beam version?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the benefit of always setting to False simplicity? I don't have a strong opinion here.

Yeah - IMO this is either a vulnerability which we should unconditionally patch, or it is not and we can retain a single behavior.

I also think it is reasonable to just skip compat flag completely. I doubt it will break anyone, and if it does they can file a bug and we can add a compat flag in an upcoming beam version?

I agree with this. I actually also realized that the compat flag doesn't work here anyways - these artifacts persist across (usually batch) pipeline runs, so it isn't coming from an update, its coming from a subsequent pipeline run.

My slight preference would be safe=False, but I'm good with either approach as long as we skip the compat flag piece and will defer to you here.

return jsonpickle.decode(f.read(), safe=safe)
Comment thread
claudevdm marked this conversation as resolved.


_transform_attribute_manager = _JsonPickleTransformAttributeManager
Expand Down
42 changes: 42 additions & 0 deletions sdks/python/apache_beam/ml/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dataclasses import dataclass
from typing import Any
from typing import Optional
from unittest import mock

import numpy as np
from parameterized import param
Expand All @@ -36,6 +37,7 @@
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms import base
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

Expand Down Expand Up @@ -841,6 +843,46 @@ def test_save_and_load_run_inference(self):
self.assertListEqual(
get_keys(model_handler), get_keys(loaded_model_handler))

@parameterized.expand([
# Pipelines pinned to a version older than 2.75.0 keep the pre-2.75.0
# jsonpickle behavior (safe=False, which permits eval-based decoding).
param(update_compatibility_version='2.74.0', expected_safe=False),
# The breaking-change version itself and newer decode securely.
param(update_compatibility_version='2.75.0', expected_safe=True),
# Pipelines that do not set the option (the common case) decode securely.
param(update_compatibility_version=None, expected_safe=True),
])
def test_load_attributes_safe_flag_follows_compat_version(
self, update_compatibility_version, expected_safe):
data = [{'x': 'Hello world'}, {'x': 'Apache Beam'}]
with beam.Pipeline() as p:
_ = (
p
| beam.Create(data)
| base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
FakeEmbeddingsManager(columns=['x'])))

# FakeEmbeddingsManager reverses the values of the embedded columns.
expected_data = [{'x': d['x'][::-1]} for d in data]

options = PipelineOptions(
update_compatibility_version=update_compatibility_version)
with mock.patch.object(base.jsonpickle,
'decode',
wraps=base.jsonpickle.decode) as mock_decode:
with beam.Pipeline(options=options) as p:
result = (
p
| beam.Create(data)
| base.MLTransform(read_artifact_location=self.artifact_location))
assert_that(result, equal_to(expected_data))

safe_flags = [
call.kwargs.get('safe') for call in mock_decode.call_args_list
]
self.assertEqual(safe_flags, [expected_safe])
Comment thread
claudevdm marked this conversation as resolved.

def test_mltransform_to_ptransform_wrapper(self):
transforms = [
FakeEmbeddingsManager(columns=['x']),
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def get_portability_package_data():
'fasteners>=0.3,<1.0',
'grpcio>=1.33.1,<2,!=1.48.0,!=1.59.*,!=1.60.*,!=1.61.*,!=1.62.0,!=1.62.1,!=1.66.*,!=1.67.*,!=1.68.*,!=1.69.*,!=1.70.*', # pylint: disable=line-too-long
'httplib2>=0.8,<0.32.0',
'jsonpickle>=3.0.0,<4.0.0',
'jsonpickle>=3.0.0,<5.0.0',
# numpy can have breaking changes in minor versions.
# Use a strict upper bound.
'numpy>=1.14.3,<2.5.0', # Update pyproject.toml as well.
Expand Down
Loading