diff --git a/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiver.java b/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiver.java index 6040820..9b41b2e 100644 --- a/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiver.java +++ b/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiver.java @@ -39,6 +39,27 @@ GetResult get( @BindArg("lastSeenCount") Long lastSeenCount ); + @Query( + "local current_count = redis.call(\"GET\", $countKey$)\n" + + "if not current_count then\n" + + " return {tostring(0)}\n" + + "else\n" + + " current_count = tonumber(current_count)\n" + + "end\n" + + "if current_count <= tonumber($lastSeenCount$) then\n" + + " return {tostring(current_count)}\n" + + "end\n" + + "local results = redis.call(\"LRANGE\", $listKey$, 0, current_count - tonumber($lastSeenCount$) - 1)\n" + + "results[#results + 1] = tostring(current_count)\n" + + "return results" + ) + @Mapper(GetWithDepthResultMapper.class) + GetResult getAndReturnCurrentDepth( + @BindKey("countKey") String countKey, + @BindKey("listKey") String listKey, + @BindArg("lastSeenCount") Long lastSeenCount + ); + @Query( "local current_count = redis.call(\"GET\", $countKey$)\n" + "if not current_count then\n" + @@ -85,8 +106,8 @@ GetResult get( ) @Mapper(GetBulkResultMapper.class) GetBulkResult getMulti( - @BindKey("allKeys") List inputKeys, - @BindArg("allArgs") List inputArgs + @BindKey("allKeys") List inputKeys, + @BindArg("allArgs") List inputArgs ); @Query( @@ -114,7 +135,7 @@ Long getDepth(@BindKey("countKey") String countKey, @BindKey("copyDepthToKey") String copyDepthToKey); } - public static class GetResultMapper implements ResultMapper> { + public static class GetResultMapper implements ResultMapper> { @Override public GetResult map(List result) { @@ -127,7 +148,24 @@ public GetResult map(List result) { } } - public static class GetBulkResultMapper implements ResultMapper>> { + public static class GetWithDepthResultMapper implements ResultMapper> { + + @Override + public GetResult map(List result) { + + if (result.size() == 0) { + throw new IllegalStateException("unexpected 0 length return from redis lua script"); + } + + if (result.size() == 1) { + return new GetResult(null, Long.valueOf(result.get(0))); + } + + return new GetResult(Lists.reverse(result.subList(0, result.size() - 1)), Long.valueOf(result.get(result.size() - 1))); + } + } + + public static class GetBulkResultMapper implements ResultMapper>> { @Override public GetBulkResult map(List> result) { @@ -136,17 +174,17 @@ public GetBulkResult map(List> result) { return null; } - List> listsResult = new ArrayList<>(); + List> listsResult = new ArrayList<>(); List listsSizes = new ArrayList<>(); - for (List each: result) { + for (List each : result) { if (each.size() == 0) { listsResult.add(each); listsSizes.add(0L); continue; } listsResult.add(Lists.reverse(each.subList(0, each.size() - 1))); - listsSizes.add(Long.valueOf(each.get(each.size()-1))); + listsSizes.add(Long.valueOf(each.get(each.size() - 1))); } return new GetBulkResult(listsResult, listsSizes); @@ -162,8 +200,32 @@ public GetResult get(String channel, Long lastSeenId) { return get(channel, lastSeenId, null); } + /** + * Gets new data from the channel and returns the current channel depth. + * This is unlike the get method, that returns the input channel depth if no data was found. + * + * This is useful because if the channel was reset, the client will see their id > the channel's id and should reset. + * This happens on redis clearing etc. + * + * @param channel the channel's name + * @param lastSeenId the last seen id by the client + * @return GetResult that contains the channel's latest depth + */ + @Override + public GetResult getAndReturnCurrentCount(String channel, Long lastSeenId) { + + try (Handle handle = rdbi.open()) { + DAO dao = handle.attach(DAO.class); + + return dao.getAndReturnCurrentDepth(ChannelPublisher.getChannelDepthKey(channel), + ChannelPublisher.getChannelQueueKey(channel), + lastSeenId); + } + } + @Override public GetResult get(String channel, Long lastSeenId, String copyDepthToKey) { + try (Handle handle = rdbi.open()) { DAO dao = handle.attach(DAO.class); if (copyDepthToKey == null) { diff --git a/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelReceiver.java b/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelReceiver.java index 1b67af1..aeba5c2 100644 --- a/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelReceiver.java +++ b/rdbi-recipes/src/main/java/com/lithium/dbi/rdbi/recipes/channel/ChannelReceiver.java @@ -4,6 +4,7 @@ public interface ChannelReceiver { GetResult get(String channel, Long lastSeenId); + GetResult getAndReturnCurrentCount(String channel, Long lastSeenId); GetResult get(String channel, Long lastSeenId, String copyDepthToKey); GetBulkResult getMulti(List channels, List lastSeenIds); Long getDepth(String channel); diff --git a/rdbi-recipes/src/test/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiverTest.java b/rdbi-recipes/src/test/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiverTest.java index 81925ac..428656e 100644 --- a/rdbi-recipes/src/test/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiverTest.java +++ b/rdbi-recipes/src/test/java/com/lithium/dbi/rdbi/recipes/channel/ChannelLuaReceiverTest.java @@ -18,9 +18,11 @@ import java.util.concurrent.atomic.AtomicLong; import static java.util.stream.Collectors.toList; -import static org.testng.AssertJUnit.assertEquals; -import static org.testng.AssertJUnit.assertTrue; -import static org.testng.AssertJUnit.fail; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + public class ChannelLuaReceiverTest { @@ -161,6 +163,37 @@ public void testEmptyChannelPublishAndReceive() throws Exception { } + @Test + public void testGetAndReturnCurrentCount() { + + final String channel = "channel1"; + + RDBI rdbi = new RDBI(new JedisPool("localhost")); + final ChannelPublisher channelPublisher = new ChannelPublisher(rdbi); + channelPublisher.resetChannel(channel); + + try { + final ChannelReceiver channelReceiver = new ChannelLuaReceiver(rdbi); + GetResult result = channelReceiver.getAndReturnCurrentCount(channel, 0L); + + assertEquals((Long) 0L, result.getDepth()); + + channelPublisher.publish(channel, "1"); + GetResult result2 = channelReceiver.getAndReturnCurrentCount(channel, 0L); + + assertEquals((Long) 1L, result2.getDepth()); + assertEquals(1, result2.getMessages().size()); + assertEquals("1", result2.getMessages().get(0)); + + GetResult result3 = channelReceiver.getAndReturnCurrentCount(channel, 1000L); + assertEquals((Long) 1L, result3.getDepth()); + assertNull(result3.getMessages()); + + } finally { + channelPublisher.resetChannel(channel); + } + } + @Test public void testMultiThreadedMultiChannelPublishAndReceive() throws InterruptedException { final Set channelSet = ImmutableSet.of("channel1", "channel2", "channel3", "channel4", "channel5"); @@ -176,39 +209,8 @@ public void testMultiThreadedMultiChannelPublishAndReceive() throws InterruptedE Map uuidMap = new HashMap<>(); - Thread thread1 = new Thread(new Runnable() { - @Override - public void run() { - for (int i = 0; i < messageAmount; i++) { - String stringVal = "value" + UUID.randomUUID(); - uuidMap.put(stringVal, 0); - final List value = ImmutableList.of(stringVal); - channelPublisher.publish(channelSet, value); - - if (Thread.interrupted()) { - return; - } - } - thread1Finished.set(true); - } - }); - - Thread thread2 = new Thread(new Runnable() { - @Override - public void run() { - for (int i = 0; i < messageAmount; i++) { - String stringVal = "value" + UUID.randomUUID(); - uuidMap.put(stringVal, 0); - final List value = ImmutableList.of(stringVal); - channelPublisher.publish(channelSet, value); - - if (Thread.interrupted()) { - return; - } - } - thread2Finished.set(true); - } - }); + Thread thread1 = randomPublish(channelSet, messageAmount, channelPublisher, thread1Finished, uuidMap); + Thread thread2 = randomPublish(channelSet, messageAmount, channelPublisher, thread2Finished, uuidMap); thread1.start(); thread2.start(); @@ -243,6 +245,22 @@ public void run() { } + private Thread randomPublish(Set channelSet, int messageAmount, ChannelPublisher channelPublisher, AtomicBoolean thread1Finished, Map uuidMap) { + return new Thread(() -> { + for (int i = 0; i < messageAmount; i++) { + String stringVal = "value" + UUID.randomUUID(); + uuidMap.put(stringVal, 0); + final List value = ImmutableList.of(stringVal); + channelPublisher.publish(channelSet, value); + + if (Thread.interrupted()) { + return; + } + } + thread1Finished.set(true); + }); + } + //ignored because this is a test to compare consecutive single channel gets vs. batch channel gets //results on a local redis instance //channels batch(ms) single(ms)