Skip to content

Commit 5c0ba06

Browse files
committed
fix(sessions): annotate result_processor for mypy-diff; pyink-format test file
- Add Any type hints to DynamicPickleType.result_processor and its inner process() to clear mypy-diff [no-untyped-def]. - Reformat test_safe_unpickle.py to 2-space pyink style.
1 parent 9f9c67e commit 5c0ba06

3 files changed

Lines changed: 143 additions & 142 deletions

File tree

src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from google.adk.sessions import _session_util
3131
from google.adk.sessions.migration import _schema_check_utils
3232
from google.adk.sessions.schemas import v1
33-
from google.adk.sessions.schemas._safe_unpickle import safe_loads as _safe_pickle_loads
3433
from google.genai import types
3534
import sqlalchemy
3635
from sqlalchemy import create_engine

src/google/adk/sessions/schemas/v0.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ def process_bind_param(self, value, dialect):
111111
return pickle.dumps(value)
112112
return value
113113

114-
def result_processor(self, dialect, coltype):
114+
def result_processor(self, dialect: Any, coltype: Any) -> Any:
115115
if dialect.name in ("mysql", "spanner+spanner"):
116116
return super().result_processor(dialect, coltype)
117117

118-
def process(value):
118+
def process(value: Any) -> Any:
119119
if value is None:
120120
return None
121121
if isinstance(value, memoryview):

tests/unittests/sessions/test_safe_unpickle.py

Lines changed: 141 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -27,162 +27,164 @@
2727

2828

2929
def _make_global_payload(module: str, func: str, *args: str) -> bytes:
30-
"""Craft a pickle stream that calls module.func(*args)."""
31-
buf = io.BytesIO()
32-
buf.write(pickle.PROTO + struct.pack("B", 2))
33-
buf.write(b"c" + f"{module}\n{func}\n".encode())
34-
buf.write(b"(")
35-
for arg in args:
36-
encoded = arg.encode("utf-8")
37-
buf.write(
38-
pickle.SHORT_BINUNICODE + struct.pack("<B", len(encoded)) + encoded
39-
)
40-
buf.write(b"t")
41-
buf.write(b"R")
42-
buf.write(b".")
43-
return buf.getvalue()
30+
"""Craft a pickle stream that calls module.func(*args)."""
31+
buf = io.BytesIO()
32+
buf.write(pickle.PROTO + struct.pack("B", 2))
33+
buf.write(b"c" + f"{module}\n{func}\n".encode())
34+
buf.write(b"(")
35+
for arg in args:
36+
encoded = arg.encode("utf-8")
37+
buf.write(
38+
pickle.SHORT_BINUNICODE + struct.pack("<B", len(encoded)) + encoded
39+
)
40+
buf.write(b"t")
41+
buf.write(b"R")
42+
buf.write(b".")
43+
return buf.getvalue()
4444

4545

4646
class TestBlockedPayloads(unittest.TestCase):
47-
"""Malicious pickle payloads must be blocked."""
47+
"""Malicious pickle payloads must be blocked."""
4848

49-
def test_os_system(self):
50-
with self.assertRaises(pickle.UnpicklingError):
51-
safe_loads(_make_global_payload("os", "system", "echo pwned"))
49+
def test_os_system(self):
50+
with self.assertRaises(pickle.UnpicklingError):
51+
safe_loads(_make_global_payload("os", "system", "echo pwned"))
5252

53-
def test_subprocess_popen(self):
54-
with self.assertRaises(pickle.UnpicklingError):
55-
safe_loads(_make_global_payload("subprocess", "Popen", "id"))
53+
def test_subprocess_popen(self):
54+
with self.assertRaises(pickle.UnpicklingError):
55+
safe_loads(_make_global_payload("subprocess", "Popen", "id"))
5656

57-
def test_builtins_import(self):
58-
with self.assertRaises(pickle.UnpicklingError):
59-
safe_loads(_make_global_payload("builtins", "__import__", "os"))
57+
def test_builtins_import(self):
58+
with self.assertRaises(pickle.UnpicklingError):
59+
safe_loads(_make_global_payload("builtins", "__import__", "os"))
6060

61-
def test_posix_system(self):
62-
with self.assertRaises(pickle.UnpicklingError):
63-
safe_loads(_make_global_payload("posix", "system", "whoami"))
61+
def test_posix_system(self):
62+
with self.assertRaises(pickle.UnpicklingError):
63+
safe_loads(_make_global_payload("posix", "system", "whoami"))
6464

65-
def test_nt_system(self):
66-
with self.assertRaises(pickle.UnpicklingError):
67-
safe_loads(_make_global_payload("nt", "system", "whoami"))
65+
def test_nt_system(self):
66+
with self.assertRaises(pickle.UnpicklingError):
67+
safe_loads(_make_global_payload("nt", "system", "whoami"))
6868

69-
def test_builtins_eval(self):
70-
with self.assertRaises(pickle.UnpicklingError):
71-
safe_loads(
72-
_make_global_payload(
73-
"builtins", "eval", "__import__('os').system('id')"
74-
)
75-
)
69+
def test_builtins_eval(self):
70+
with self.assertRaises(pickle.UnpicklingError):
71+
safe_loads(
72+
_make_global_payload(
73+
"builtins", "eval", "__import__('os').system('id')"
74+
)
75+
)
7676

7777

7878
class TestEventActionsRoundTrip(unittest.TestCase):
79-
"""Legitimate EventActions data must survive pickle -> safe_loads."""
80-
81-
def _round_trip(self, obj):
82-
return safe_loads(pickle.dumps(obj))
83-
84-
def test_string_values(self):
85-
original = {"state_delta": {"key": "value"}, "artifact_delta": {}}
86-
self.assertEqual(self._round_trip(original), original)
87-
88-
def test_nested_dict(self):
89-
original = {
90-
"state_delta": {
91-
"user_prefs": {"theme": "dark", "lang": "en"},
92-
"counter": 42,
93-
},
94-
"artifact_delta": {"files": ["a.txt", "b.txt"]},
95-
}
96-
self.assertEqual(self._round_trip(original), original)
97-
98-
def test_none_and_bool(self):
99-
original = {
100-
"skip_summarization": True,
101-
"requested_auth_configs": None,
102-
"escalate": False,
103-
}
104-
self.assertEqual(self._round_trip(original), original)
105-
106-
def test_empty_dict(self):
107-
self.assertEqual(self._round_trip({}), {})
79+
"""Legitimate EventActions data must survive pickle -> safe_loads."""
80+
81+
def _round_trip(self, obj):
82+
return safe_loads(pickle.dumps(obj))
83+
84+
def test_string_values(self):
85+
original = {"state_delta": {"key": "value"}, "artifact_delta": {}}
86+
self.assertEqual(self._round_trip(original), original)
87+
88+
def test_nested_dict(self):
89+
original = {
90+
"state_delta": {
91+
"user_prefs": {"theme": "dark", "lang": "en"},
92+
"counter": 42,
93+
},
94+
"artifact_delta": {"files": ["a.txt", "b.txt"]},
95+
}
96+
self.assertEqual(self._round_trip(original), original)
97+
98+
def test_none_and_bool(self):
99+
original = {
100+
"skip_summarization": True,
101+
"requested_auth_configs": None,
102+
"escalate": False,
103+
}
104+
self.assertEqual(self._round_trip(original), original)
105+
106+
def test_empty_dict(self):
107+
self.assertEqual(self._round_trip({}), {})
108+
108109

109110
class TestRealEventActionsRoundTrip(unittest.TestCase):
110-
"""Smoke test: real EventActions instances survive pickle -> safe_loads."""
111-
112-
def _round_trip(self, obj):
113-
return safe_loads(pickle.dumps(obj))
114-
115-
def test_minimal_event_actions(self):
116-
original = EventActions()
117-
result = self._round_trip(original)
118-
self.assertIsInstance(result, EventActions)
119-
self.assertEqual(result.state_delta, {})
120-
self.assertEqual(result.artifact_delta, {})
121-
122-
def test_event_actions_with_state_delta(self):
123-
original = EventActions(
124-
state_delta={"user_name": "alice", "turn_count": 3, "active": True},
125-
artifact_delta={"report.pdf": 2},
126-
)
127-
result = self._round_trip(original)
128-
self.assertIsInstance(result, EventActions)
129-
self.assertEqual(result.state_delta, original.state_delta)
130-
self.assertEqual(result.artifact_delta, original.artifact_delta)
131-
132-
def test_event_actions_with_transfer_and_escalate(self):
133-
original = EventActions(
134-
transfer_to_agent="specialist_agent",
135-
escalate=True,
136-
skip_summarization=True,
137-
)
138-
result = self._round_trip(original)
139-
self.assertIsInstance(result, EventActions)
140-
self.assertEqual(result.transfer_to_agent, "specialist_agent")
141-
self.assertTrue(result.escalate)
142-
self.assertTrue(result.skip_summarization)
143-
144-
def test_event_actions_with_complex_state_values(self):
145-
original = EventActions(
146-
state_delta={
147-
"nested": {"a": [1, 2, 3], "b": None},
148-
"count": 42,
149-
"tags": ["ml", "security"],
150-
},
151-
)
152-
result = self._round_trip(original)
153-
self.assertIsInstance(result, EventActions)
154-
self.assertEqual(result.state_delta["nested"]["a"], [1, 2, 3])
155-
self.assertIsNone(result.state_delta["nested"]["b"])
111+
"""Smoke test: real EventActions instances survive pickle -> safe_loads."""
112+
113+
def _round_trip(self, obj):
114+
return safe_loads(pickle.dumps(obj))
115+
116+
def test_minimal_event_actions(self):
117+
original = EventActions()
118+
result = self._round_trip(original)
119+
self.assertIsInstance(result, EventActions)
120+
self.assertEqual(result.state_delta, {})
121+
self.assertEqual(result.artifact_delta, {})
122+
123+
def test_event_actions_with_state_delta(self):
124+
original = EventActions(
125+
state_delta={"user_name": "alice", "turn_count": 3, "active": True},
126+
artifact_delta={"report.pdf": 2},
127+
)
128+
result = self._round_trip(original)
129+
self.assertIsInstance(result, EventActions)
130+
self.assertEqual(result.state_delta, original.state_delta)
131+
self.assertEqual(result.artifact_delta, original.artifact_delta)
132+
133+
def test_event_actions_with_transfer_and_escalate(self):
134+
original = EventActions(
135+
transfer_to_agent="specialist_agent",
136+
escalate=True,
137+
skip_summarization=True,
138+
)
139+
result = self._round_trip(original)
140+
self.assertIsInstance(result, EventActions)
141+
self.assertEqual(result.transfer_to_agent, "specialist_agent")
142+
self.assertTrue(result.escalate)
143+
self.assertTrue(result.skip_summarization)
144+
145+
def test_event_actions_with_complex_state_values(self):
146+
original = EventActions(
147+
state_delta={
148+
"nested": {"a": [1, 2, 3], "b": None},
149+
"count": 42,
150+
"tags": ["ml", "security"],
151+
},
152+
)
153+
result = self._round_trip(original)
154+
self.assertIsInstance(result, EventActions)
155+
self.assertEqual(result.state_delta["nested"]["a"], [1, 2, 3])
156+
self.assertIsNone(result.state_delta["nested"]["b"])
156157

157158

158159
class TestEnvVarFallback(unittest.TestCase):
159-
"""ADK_ALLOW_UNSAFE_V0_PICKLE=1 must bypass RestrictedUnpickler."""
160-
161-
_ENV_KEY = "ADK_ALLOW_UNSAFE_V0_PICKLE"
162-
_PAYLOAD = _make_global_payload("collections", "Counter")
163-
164-
def test_blocked_without_env_var(self):
165-
old = os.environ.pop(self._ENV_KEY, None)
166-
try:
167-
with self.assertRaises(pickle.UnpicklingError):
168-
safe_loads(self._PAYLOAD)
169-
finally:
170-
if old is not None:
171-
os.environ[self._ENV_KEY] = old
172-
173-
def test_allowed_with_env_var(self):
174-
old = os.environ.get(self._ENV_KEY)
175-
try:
176-
os.environ[self._ENV_KEY] = "1"
177-
from collections import Counter
178-
result = safe_loads(self._PAYLOAD)
179-
self.assertIsInstance(result, Counter)
180-
finally:
181-
if old is None:
182-
os.environ.pop(self._ENV_KEY, None)
183-
else:
184-
os.environ[self._ENV_KEY] = old
160+
"""ADK_ALLOW_UNSAFE_V0_PICKLE=1 must bypass RestrictedUnpickler."""
161+
162+
_ENV_KEY = "ADK_ALLOW_UNSAFE_V0_PICKLE"
163+
_PAYLOAD = _make_global_payload("collections", "Counter")
164+
165+
def test_blocked_without_env_var(self):
166+
old = os.environ.pop(self._ENV_KEY, None)
167+
try:
168+
with self.assertRaises(pickle.UnpicklingError):
169+
safe_loads(self._PAYLOAD)
170+
finally:
171+
if old is not None:
172+
os.environ[self._ENV_KEY] = old
173+
174+
def test_allowed_with_env_var(self):
175+
old = os.environ.get(self._ENV_KEY)
176+
try:
177+
os.environ[self._ENV_KEY] = "1"
178+
from collections import Counter
179+
180+
result = safe_loads(self._PAYLOAD)
181+
self.assertIsInstance(result, Counter)
182+
finally:
183+
if old is None:
184+
os.environ.pop(self._ENV_KEY, None)
185+
else:
186+
os.environ[self._ENV_KEY] = old
185187

186188

187189
if __name__ == "__main__":
188-
unittest.main()
190+
unittest.main()

0 commit comments

Comments
 (0)