From 63f03c934cb00f0da49369f93ab3aa8c36b839c7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 10:57:11 -0700 Subject: [PATCH] fix: Handle JSON deserialized State.REMOVED sentinel The State.REMOVED sentinel object is serialized to a specific string in JSON. When deserialized, this becomes a String object, not the original sentinel instance. This change updates the logic to correctly identify removals when the value is either the State.REMOVED object or its string representation. Additionally, when applying state deltas from app/user state, removals are now handled using a new removeWithoutDelta method on the State object to avoid re-recording these removals as new deltas. PiperOrigin-RevId: 881521527 --- .../adk/sessions/BaseSessionService.java | 2 +- .../adk/sessions/InMemorySessionService.java | 39 ++++- .../java/com/google/adk/sessions/State.java | 28 +++- .../sessions/InMemorySessionServiceTest.java | 154 ++++++++++++++++++ .../com/google/adk/sessions/StateTest.java | 36 ++++ 5 files changed, 251 insertions(+), 8 deletions(-) create mode 100644 core/src/test/java/com/google/adk/sessions/StateTest.java diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 7a0885544..b4b93b645 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -236,7 +236,7 @@ default Single appendEvent(Session session, Event event) { stateDelta.forEach( (key, value) -> { if (!key.startsWith(State.TEMP_PREFIX)) { - if (value == State.REMOVED) { + if (State.isRemoved(value)) { sessionState.remove(key); } else { sessionState.put(key, value); diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b2a584b11..dede50f66 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -244,7 +244,7 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - if (value == State.REMOVED) { + if (State.isRemoved(value)) { appState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .remove(appStateKey); @@ -255,7 +255,7 @@ public Single appendEvent(Session session, Event event) { } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - if (value == State.REMOVED) { + if (State.isRemoved(value)) { userState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) @@ -267,8 +267,13 @@ public Single appendEvent(Session session, Event event) { .put(userStateKey, value); } } else { - if (value == State.REMOVED) { - session.state().remove(key); + if (State.isRemoved(value)) { + Map s = session.state(); + if (s instanceof State state) { + state.removeWithoutDelta(key); + } else { + s.remove(key); + } } else { session.state().put(key, value); } @@ -333,12 +338,34 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess // Merge App State directly into the session's state map appState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); + .forEach( + (key, value) -> { + if (State.isRemoved(value)) { + if (sessionState instanceof State state) { + state.removeWithoutDelta(State.APP_PREFIX + key); + } else { + sessionState.remove(State.APP_PREFIX + key); + } + } else { + sessionState.put(State.APP_PREFIX + key, value); + } + }); userState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); + .forEach( + (key, value) -> { + if (State.isRemoved(value)) { + if (sessionState instanceof State state) { + state.removeWithoutDelta(State.USER_PREFIX + key); + } else { + sessionState.remove(State.USER_PREFIX + key); + } + } else { + sessionState.put(State.USER_PREFIX + key, value); + } + }); return session; } diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..e3884a3a6 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -17,6 +17,7 @@ package com.google.adk.sessions; import com.fasterxml.jackson.annotation.JsonValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.Collection; import java.util.Map; import java.util.Map.Entry; @@ -33,6 +34,8 @@ public final class State implements ConcurrentMap { public static final String USER_PREFIX = "user:"; public static final String TEMP_PREFIX = "temp:"; + public static final String REMOVED_SENTINEL_STRING = "__ADK_SENTINEL_REMOVED__"; + /** Sentinel object to mark removed entries in the delta map. */ public static final Object REMOVED = RemovedSentinel.INSTANCE; @@ -129,6 +132,19 @@ public Object remove(Object key) { return state.remove(key); } + /** + * Removes a key from the state map without recording the removal in the delta map. This is + * intended for internal use when rebuilding state from an event stream where the removal is + * already known and doesn't need to be represented as a new change. + * + * @param key The key to remove. + * @return The previous value associated with key, or null if there was no mapping for key. + */ + @CanIgnoreReturnValue + public Object removeWithoutDelta(Object key) { + return state.remove(key); + } + @Override public boolean remove(Object key, Object value) { boolean removed = state.remove(key, value); @@ -170,6 +186,16 @@ public boolean hasDelta() { return !delta.isEmpty(); } + /** + * Checks if a value represents a removed state entry, accounting for deserialization from JSON. + * + * @param value The value to check. + * @return True if the value indicates removal, false otherwise. + */ + public static boolean isRemoved(Object value) { + return value == REMOVED || Objects.equals(value, REMOVED_SENTINEL_STRING); + } + private static final class RemovedSentinel { public static final RemovedSentinel INSTANCE = new RemovedSentinel(); @@ -179,7 +205,7 @@ private RemovedSentinel() { @JsonValue public String toJson() { - return "__ADK_SENTINEL_REMOVED__"; + return REMOVED_SENTINEL_STRING; } } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 41e156ffd..5fea6bab2 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; /** Unit tests for {@link InMemorySessionService}. */ @RunWith(JUnit4.class) @@ -214,6 +215,55 @@ public void appendEvent_removesState() { assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + @Test + public void appendEvent_removesStateFromJsonDeserializedSentinel() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + stateDeltaAdd.put("sessionKey", "sessionValue"); + stateDeltaAdd.put("_app_appKey", "appValue"); + stateDeltaAdd.put("_user_userKey", "userValue"); + + Event eventAdd = + Event.builder().actions(EventActions.builder().stateDelta(stateDeltaAdd).build()).build(); + + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify state is added + Session retrievedSessionAdd = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue"); + + // Prepare and append event to remove state using the String representation of the sentinel + // to simulate Jackson JSON deserialization. + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED_SENTINEL_STRING); + stateDeltaRemove.put("_app_appKey", State.REMOVED_SENTINEL_STRING); + stateDeltaRemove.put("_user_userKey", State.REMOVED_SENTINEL_STRING); + + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + unused = sessionService.appendEvent(session, eventRemove).blockingGet(); + + // Verify state is removed despite being a String instead of the State.REMOVED object + Session retrievedSessionRemove = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey"); + } + @Test public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); @@ -247,4 +297,108 @@ public void sequentialAgents_shareTempState() { assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output"); assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data"); } + + @Test + public void mergeGlobalState_removesSentinels() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + session.state().put(State.APP_PREFIX + "appKey", "appValue"); + session.state().put(State.USER_PREFIX + "userKey", "userValue"); + + // Use reflection to directly put REMOVED sentinel into appState and userState + java.lang.reflect.Field appStateField = + InMemorySessionService.class.getDeclaredField("appState"); + appStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap> appState = + (ConcurrentMap>) appStateField.get(sessionService); + appState.computeIfAbsent("app", k -> new ConcurrentHashMap<>()).put("appKey", State.REMOVED); + + java.lang.reflect.Field userStateField = + InMemorySessionService.class.getDeclaredField("userState"); + userStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap>> userState = + (ConcurrentMap>>) + userStateField.get(sessionService); + userState + .computeIfAbsent("app", k -> new ConcurrentHashMap<>()) + .computeIfAbsent("user", k -> new ConcurrentHashMap<>()) + .put("userKey", State.REMOVED); + + // Call getSession to trigger mergeWithGlobalState + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + + assertThat(retrievedSession.state()).doesNotContainKey(State.APP_PREFIX + "appKey"); + assertThat(retrievedSession.state()).doesNotContainKey(State.USER_PREFIX + "userKey"); + } + + @Test + public void appendEvent_withNullState_throwsNpeOnRemoval() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + java.lang.reflect.Field stateField = Session.class.getDeclaredField("state"); + stateField.setAccessible(true); + stateField.set(session, null); + + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED); + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + org.junit.Assert.assertThrows( + NullPointerException.class, + () -> { + sessionService.appendEvent(session, eventRemove).blockingGet(); + }); + } + + @Test + public void mergeGlobalState_withNullState_throwsNpeOnRemoval() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + // Inject REMOVED sentinel into appState and userState + java.lang.reflect.Field appStateField = + InMemorySessionService.class.getDeclaredField("appState"); + appStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap> appState = + (ConcurrentMap>) appStateField.get(sessionService); + appState.computeIfAbsent("app", k -> new ConcurrentHashMap<>()).put("appKey", State.REMOVED); + + java.lang.reflect.Field userStateField = + InMemorySessionService.class.getDeclaredField("userState"); + userStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap>> userState = + (ConcurrentMap>>) + userStateField.get(sessionService); + userState + .computeIfAbsent("app", k -> new ConcurrentHashMap<>()) + .computeIfAbsent("user", k -> new ConcurrentHashMap<>()) + .put("userKey", State.REMOVED); + + // Set session state to null AFTER injecting global state + java.lang.reflect.Field stateField = Session.class.getDeclaredField("state"); + stateField.setAccessible(true); + stateField.set(session, null); + + Event emptyEvent = Event.builder().build(); + + org.junit.Assert.assertThrows( + NullPointerException.class, + () -> { + sessionService.appendEvent(session, emptyEvent).blockingGet(); + }); + } } diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java new file mode 100644 index 000000000..8e4b599ca --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link State}. */ +@RunWith(JUnit4.class) +public final class StateTest { + + @Test + public void removedSentinel_serializesToJson() throws JsonProcessingException { + ObjectMapper mapper = new ObjectMapper(); + String json = mapper.writeValueAsString(State.REMOVED); + assertThat(json).isEqualTo("\"" + State.REMOVED_SENTINEL_STRING + "\""); + } +}