diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 7a0885544..b4b93b645 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -236,7 +236,7 @@ default Single appendEvent(Session session, Event event) { stateDelta.forEach( (key, value) -> { if (!key.startsWith(State.TEMP_PREFIX)) { - if (value == State.REMOVED) { + if (State.isRemoved(value)) { sessionState.remove(key); } else { sessionState.put(key, value); diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b2a584b11..dede50f66 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -244,7 +244,7 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - if (value == State.REMOVED) { + if (State.isRemoved(value)) { appState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .remove(appStateKey); @@ -255,7 +255,7 @@ public Single appendEvent(Session session, Event event) { } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - if (value == State.REMOVED) { + if (State.isRemoved(value)) { userState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) @@ -267,8 +267,13 @@ public Single appendEvent(Session session, Event event) { .put(userStateKey, value); } } else { - if (value == State.REMOVED) { - session.state().remove(key); + if (State.isRemoved(value)) { + Map s = session.state(); + if (s instanceof State state) { + state.removeWithoutDelta(key); + } else { + s.remove(key); + } } else { session.state().put(key, value); } @@ -333,12 +338,34 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess // Merge App State directly into the session's state map appState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); + .forEach( + (key, value) -> { + if (State.isRemoved(value)) { + if (sessionState instanceof State state) { + state.removeWithoutDelta(State.APP_PREFIX + key); + } else { + sessionState.remove(State.APP_PREFIX + key); + } + } else { + sessionState.put(State.APP_PREFIX + key, value); + } + }); userState .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); + .forEach( + (key, value) -> { + if (State.isRemoved(value)) { + if (sessionState instanceof State state) { + state.removeWithoutDelta(State.USER_PREFIX + key); + } else { + sessionState.remove(State.USER_PREFIX + key); + } + } else { + sessionState.put(State.USER_PREFIX + key, value); + } + }); return session; } diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..e3884a3a6 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -17,6 +17,7 @@ package com.google.adk.sessions; import com.fasterxml.jackson.annotation.JsonValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.Collection; import java.util.Map; import java.util.Map.Entry; @@ -33,6 +34,8 @@ public final class State implements ConcurrentMap { public static final String USER_PREFIX = "user:"; public static final String TEMP_PREFIX = "temp:"; + public static final String REMOVED_SENTINEL_STRING = "__ADK_SENTINEL_REMOVED__"; + /** Sentinel object to mark removed entries in the delta map. */ public static final Object REMOVED = RemovedSentinel.INSTANCE; @@ -129,6 +132,19 @@ public Object remove(Object key) { return state.remove(key); } + /** + * Removes a key from the state map without recording the removal in the delta map. This is + * intended for internal use when rebuilding state from an event stream where the removal is + * already known and doesn't need to be represented as a new change. + * + * @param key The key to remove. + * @return The previous value associated with key, or null if there was no mapping for key. + */ + @CanIgnoreReturnValue + public Object removeWithoutDelta(Object key) { + return state.remove(key); + } + @Override public boolean remove(Object key, Object value) { boolean removed = state.remove(key, value); @@ -170,6 +186,16 @@ public boolean hasDelta() { return !delta.isEmpty(); } + /** + * Checks if a value represents a removed state entry, accounting for deserialization from JSON. + * + * @param value The value to check. + * @return True if the value indicates removal, false otherwise. + */ + public static boolean isRemoved(Object value) { + return value == REMOVED || Objects.equals(value, REMOVED_SENTINEL_STRING); + } + private static final class RemovedSentinel { public static final RemovedSentinel INSTANCE = new RemovedSentinel(); @@ -179,7 +205,7 @@ private RemovedSentinel() { @JsonValue public String toJson() { - return "__ADK_SENTINEL_REMOVED__"; + return REMOVED_SENTINEL_STRING; } } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 41e156ffd..5fea6bab2 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; /** Unit tests for {@link InMemorySessionService}. */ @RunWith(JUnit4.class) @@ -214,6 +215,55 @@ public void appendEvent_removesState() { assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + @Test + public void appendEvent_removesStateFromJsonDeserializedSentinel() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + stateDeltaAdd.put("sessionKey", "sessionValue"); + stateDeltaAdd.put("_app_appKey", "appValue"); + stateDeltaAdd.put("_user_userKey", "userValue"); + + Event eventAdd = + Event.builder().actions(EventActions.builder().stateDelta(stateDeltaAdd).build()).build(); + + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify state is added + Session retrievedSessionAdd = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue"); + + // Prepare and append event to remove state using the String representation of the sentinel + // to simulate Jackson JSON deserialization. + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED_SENTINEL_STRING); + stateDeltaRemove.put("_app_appKey", State.REMOVED_SENTINEL_STRING); + stateDeltaRemove.put("_user_userKey", State.REMOVED_SENTINEL_STRING); + + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + unused = sessionService.appendEvent(session, eventRemove).blockingGet(); + + // Verify state is removed despite being a String instead of the State.REMOVED object + Session retrievedSessionRemove = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey"); + } + @Test public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); @@ -247,4 +297,108 @@ public void sequentialAgents_shareTempState() { assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output"); assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data"); } + + @Test + public void mergeGlobalState_removesSentinels() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + session.state().put(State.APP_PREFIX + "appKey", "appValue"); + session.state().put(State.USER_PREFIX + "userKey", "userValue"); + + // Use reflection to directly put REMOVED sentinel into appState and userState + java.lang.reflect.Field appStateField = + InMemorySessionService.class.getDeclaredField("appState"); + appStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap> appState = + (ConcurrentMap>) appStateField.get(sessionService); + appState.computeIfAbsent("app", k -> new ConcurrentHashMap<>()).put("appKey", State.REMOVED); + + java.lang.reflect.Field userStateField = + InMemorySessionService.class.getDeclaredField("userState"); + userStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap>> userState = + (ConcurrentMap>>) + userStateField.get(sessionService); + userState + .computeIfAbsent("app", k -> new ConcurrentHashMap<>()) + .computeIfAbsent("user", k -> new ConcurrentHashMap<>()) + .put("userKey", State.REMOVED); + + // Call getSession to trigger mergeWithGlobalState + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + + assertThat(retrievedSession.state()).doesNotContainKey(State.APP_PREFIX + "appKey"); + assertThat(retrievedSession.state()).doesNotContainKey(State.USER_PREFIX + "userKey"); + } + + @Test + public void appendEvent_withNullState_throwsNpeOnRemoval() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + java.lang.reflect.Field stateField = Session.class.getDeclaredField("state"); + stateField.setAccessible(true); + stateField.set(session, null); + + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED); + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + org.junit.Assert.assertThrows( + NullPointerException.class, + () -> { + sessionService.appendEvent(session, eventRemove).blockingGet(); + }); + } + + @Test + public void mergeGlobalState_withNullState_throwsNpeOnRemoval() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + // Inject REMOVED sentinel into appState and userState + java.lang.reflect.Field appStateField = + InMemorySessionService.class.getDeclaredField("appState"); + appStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap> appState = + (ConcurrentMap>) appStateField.get(sessionService); + appState.computeIfAbsent("app", k -> new ConcurrentHashMap<>()).put("appKey", State.REMOVED); + + java.lang.reflect.Field userStateField = + InMemorySessionService.class.getDeclaredField("userState"); + userStateField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap>> userState = + (ConcurrentMap>>) + userStateField.get(sessionService); + userState + .computeIfAbsent("app", k -> new ConcurrentHashMap<>()) + .computeIfAbsent("user", k -> new ConcurrentHashMap<>()) + .put("userKey", State.REMOVED); + + // Set session state to null AFTER injecting global state + java.lang.reflect.Field stateField = Session.class.getDeclaredField("state"); + stateField.setAccessible(true); + stateField.set(session, null); + + Event emptyEvent = Event.builder().build(); + + org.junit.Assert.assertThrows( + NullPointerException.class, + () -> { + sessionService.appendEvent(session, emptyEvent).blockingGet(); + }); + } } diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java new file mode 100644 index 000000000..8e4b599ca --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link State}. */ +@RunWith(JUnit4.class) +public final class StateTest { + + @Test + public void removedSentinel_serializesToJson() throws JsonProcessingException { + ObjectMapper mapper = new ObjectMapper(); + String json = mapper.writeValueAsString(State.REMOVED); + assertThat(json).isEqualTo("\"" + State.REMOVED_SENTINEL_STRING + "\""); + } +}