diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..70d2dfbf2 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -24,6 +24,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import javax.annotation.Nullable; /** A {@link State} object that also keeps track of the changes to the state. */ @SuppressWarnings("ShouldNotSubclass") @@ -39,13 +40,22 @@ public final class State implements ConcurrentMap { private final ConcurrentMap state; private final ConcurrentMap delta; - public State(ConcurrentMap state) { - this(state, new ConcurrentHashMap<>()); - } - - public State(ConcurrentMap state, ConcurrentMap delta) { - this.state = Objects.requireNonNull(state); - this.delta = delta; + public State(Map state) { + this(state, null); + } + + public State(Map state, @Nullable Map delta) { + Objects.requireNonNull(state, "state is null"); + this.state = + state instanceof ConcurrentMap + ? (ConcurrentMap) state + : new ConcurrentHashMap<>(state); + this.delta = + delta == null + ? new ConcurrentHashMap<>() + : delta instanceof ConcurrentMap + ? (ConcurrentMap) delta + : new ConcurrentHashMap<>(delta); } @Override diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java new file mode 100644 index 000000000..e1fcaeadc --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -0,0 +1,50 @@ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StateTest { + @Test + public void constructor_nullDelta_createsEmptyConcurrentHashMap() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + State state = new State(stateMap, null); + assertThat(state.hasDelta()).isFalse(); + state.put("key", "value"); + assertThat(state.hasDelta()).isTrue(); + } + + @Test + public void constructor_nullState_throwsException() { + Assert.assertThrows(NullPointerException.class, () -> new State(null, new HashMap<>())); + } + + @Test + public void constructor_regularMapState() { + Map stateMap = new HashMap<>(); + stateMap.put("initial", "val"); + State state = new State(stateMap, null); + // It should have copied the contents + assertThat(state).containsEntry("initial", "val"); + state.put("key", "value"); + // The original map should NOT be updated because a copy was created + assertThat(stateMap).doesNotContainKey("key"); + } + + @Test + public void constructor_singleArgument() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + State state = new State(stateMap); + assertThat(state.hasDelta()).isFalse(); + state.put("key", "value"); + assertThat(state.hasDelta()).isTrue(); + } +}