diff --git a/backends/build.gradle.kts b/backends/build.gradle.kts index 35fedca3..22ee10b1 100644 --- a/backends/build.gradle.kts +++ b/backends/build.gradle.kts @@ -15,9 +15,11 @@ repositories { dependencies { api(project(":core")) api("com.squareup.okhttp3:okhttp:4.12.0") + api("redis.clients:jedis:7.1.0") compileOnly("com.velocitypowered:velocity-api:3.4.0-SNAPSHOT") compileOnly("org.spigotmc:spigot-api:1.21.1-R0.1-SNAPSHOT") compileOnly("dev.jorel:commandapi-spigot-core:11.1.0") + compileOnly("org.bstats:bstats-bukkit:3.2.0") } // plugin.yml and paper-plugin.yml contain @version@ placeholders. diff --git a/backends/bukkit/src/main/java/dev/objz/commandbridge/bukkit/Adapter.java b/backends/bukkit/src/main/java/dev/objz/commandbridge/bukkit/Adapter.java index 2b35f04e..c68c5f96 100644 --- a/backends/bukkit/src/main/java/dev/objz/commandbridge/bukkit/Adapter.java +++ b/backends/bukkit/src/main/java/dev/objz/commandbridge/bukkit/Adapter.java @@ -1,5 +1,7 @@ package dev.objz.commandbridge.bukkit; +import dev.objz.commandbridge.backends.net.BackendClient; +import dev.objz.commandbridge.backends.net.RedisClient; import dev.objz.commandbridge.backends.net.WsClient; import dev.objz.commandbridge.backends.net.in.ExecuteCommandHandler; import dev.objz.commandbridge.backends.net.in.RegistrationHandler; @@ -9,6 +11,7 @@ import dev.objz.commandbridge.backends.platform.cmd.CommandExecutor; import dev.objz.commandbridge.config.ConfigManager; import dev.objz.commandbridge.config.model.BackendsConfig; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.proto.MessageType; @@ -20,7 +23,7 @@ import java.time.Duration; public final class Adapter implements PlatformAdapter { - private WsClient client; + private BackendClient client; private BackendsConfig cfg; private Path dataDir; private JavaPlugin plugin; @@ -72,7 +75,9 @@ public void start(PlatformEnv env) throws Exception { () -> Bukkit.isPrimaryThread(), task -> Bukkit.getScheduler().runTask(plugin, task)); - this.client = new WsClient(cfg, dataDir, this); + this.client = cfg.endpointType() == EndpointType.REDIS + ? new RedisClient(cfg, dataDir, this) + : new WsClient(cfg, dataDir, this); try { client.start(); diff --git a/backends/folia/src/main/java/dev/objz/commandbridge/folia/Adapter.java b/backends/folia/src/main/java/dev/objz/commandbridge/folia/Adapter.java index 36649cbd..389db1dc 100644 --- a/backends/folia/src/main/java/dev/objz/commandbridge/folia/Adapter.java +++ b/backends/folia/src/main/java/dev/objz/commandbridge/folia/Adapter.java @@ -1,5 +1,7 @@ package dev.objz.commandbridge.folia; +import dev.objz.commandbridge.backends.net.BackendClient; +import dev.objz.commandbridge.backends.net.RedisClient; import dev.objz.commandbridge.backends.net.WsClient; import dev.objz.commandbridge.backends.net.in.ExecuteCommandHandler; import dev.objz.commandbridge.backends.net.in.RegistrationHandler; @@ -9,6 +11,7 @@ import dev.objz.commandbridge.backends.platform.cmd.CommandExecutor; import dev.objz.commandbridge.config.ConfigManager; import dev.objz.commandbridge.config.model.BackendsConfig; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.proto.MessageType; import io.papermc.paper.threadedregions.scheduler.ScheduledTask; @@ -21,7 +24,7 @@ import java.util.concurrent.TimeUnit; public final class Adapter implements PlatformAdapter { - private WsClient client; + private BackendClient client; private BackendsConfig cfg; private Path dataDir; private JavaPlugin plugin; @@ -69,7 +72,9 @@ public void start(PlatformEnv env) throws Exception { this.commandExecutor = new FoliaExecutor(plugin); } - this.client = new WsClient(cfg, dataDir, this); + this.client = cfg.endpointType() == EndpointType.REDIS + ? new RedisClient(cfg, dataDir, this) + : new WsClient(cfg, dataDir, this); try { client.start(); diff --git a/backends/paper/src/main/java/dev/objz/commandbridge/paper/Adapter.java b/backends/paper/src/main/java/dev/objz/commandbridge/paper/Adapter.java index 171e3c9a..ddaafe92 100644 --- a/backends/paper/src/main/java/dev/objz/commandbridge/paper/Adapter.java +++ b/backends/paper/src/main/java/dev/objz/commandbridge/paper/Adapter.java @@ -1,5 +1,7 @@ package dev.objz.commandbridge.paper; +import dev.objz.commandbridge.backends.net.BackendClient; +import dev.objz.commandbridge.backends.net.RedisClient; import dev.objz.commandbridge.backends.net.WsClient; import dev.objz.commandbridge.backends.net.in.ExecuteCommandHandler; import dev.objz.commandbridge.backends.net.in.RegistrationHandler; @@ -9,6 +11,7 @@ import dev.objz.commandbridge.backends.platform.cmd.CommandExecutor; import dev.objz.commandbridge.config.ConfigManager; import dev.objz.commandbridge.config.model.BackendsConfig; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.proto.MessageType; @@ -20,7 +23,7 @@ import java.time.Duration; public final class Adapter implements PlatformAdapter { - private WsClient client; + private BackendClient client; private BackendsConfig cfg; private Path dataDir; private JavaPlugin plugin; @@ -72,7 +75,9 @@ public void start(PlatformEnv env) throws Exception { () -> Bukkit.isPrimaryThread(), task -> Bukkit.getScheduler().runTask(plugin, task)); - this.client = new WsClient(cfg, dataDir, this); + this.client = cfg.endpointType() == EndpointType.REDIS + ? new RedisClient(cfg, dataDir, this) + : new WsClient(cfg, dataDir, this); try { client.start(); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/AuthHandler.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/AuthHandler.java index 62ec8c6a..f0b55a91 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/AuthHandler.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/AuthHandler.java @@ -5,7 +5,6 @@ import dev.objz.commandbridge.net.OutNode; import dev.objz.commandbridge.net.proto.MessageType; import dev.objz.commandbridge.backends.net.out.ctx.AuthRequestContext; -import io.undertow.websockets.core.WebSocketChannel; import java.time.Duration; import java.util.concurrent.atomic.AtomicReference; @@ -22,7 +21,7 @@ public AuthHandler(BackendsConfig cfg, OutNode outNode, AtomicReference< this.stateRef = stateRef; } - public boolean authenticate(WebSocketChannel channel) { + public boolean authenticate() { if (!Boolean.TRUE.equals(cfg.security().requireAuth())) { Log.warn("Auth disabled by config; continuing unauthenticated"); stateRef.set(ConnectionState.AUTHENTICATED); @@ -41,7 +40,7 @@ public boolean authenticate(WebSocketChannel channel) { }; Duration timeout = Duration.ofSeconds(cfg.timeouts().authTimeout()); - AuthRequestContext context = new AuthRequestContext(channel, timeout, statusUpdater); + AuthRequestContext context = new AuthRequestContext(timeout, statusUpdater); try { outNode.send(MessageType.AUTH_REQUEST, context); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/BackendClient.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/BackendClient.java new file mode 100644 index 00000000..14fa75ef --- /dev/null +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/BackendClient.java @@ -0,0 +1,32 @@ +package dev.objz.commandbridge.backends.net; + +import dev.objz.commandbridge.net.InNode; +import dev.objz.commandbridge.net.OutNode; +import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.proto.Envelope; +import dev.objz.commandbridge.scripting.model.enums.Location; + +public interface BackendClient extends AutoCloseable { + void start() throws Exception; + + void reconnect() throws Exception; + + void scheduleReconnection(); + + SendOperation send(Envelope request); + + ClientStatus status(); + + String serverId(); + + InNode inboundRouter(); + + OutNode outboundRouter(); + + void setLocation(Location location); + + void setServerId(String serverId); + + @Override + void close() throws Exception; +} diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/ConnectionHandler.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/ConnectionHandler.java index d7c759c2..5c357102 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/ConnectionHandler.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/ConnectionHandler.java @@ -49,7 +49,9 @@ private WebSocketChannel connectInternal(boolean isReconnecting) throws Exceptio : TlsMode.TOFU; String scheme = TlsResolver.schemeFor(mode); - String url = scheme + "://" + cfg.host() + ":" + cfg.port() + "/ws"; + String host = cfg.endpoints().websocket().host(); + int port = cfg.endpoints().websocket().port(); + String url = scheme + "://" + host + ":" + port + "/ws"; if (!isReconnecting) { Log.info("Connecting to {} as '{}'", url, cfg.clientId()); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/MessageRouter.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/MessageRouter.java index 1ff78268..87545b0d 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/MessageRouter.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/MessageRouter.java @@ -8,6 +8,7 @@ import dev.objz.commandbridge.net.OutNode; import dev.objz.commandbridge.net.ResponseAwaiter; import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.endpoints.WsEndpoint; import dev.objz.commandbridge.net.proto.MessageType; import dev.objz.commandbridge.scripting.model.enums.Location; import dev.objz.commandbridge.security.AuthService; @@ -27,7 +28,7 @@ public final class MessageRouter { private final AtomicReference stateRef; private final Runnable reconnectCallback; private final String secret; - private final Location location; + private volatile Location location; public MessageRouter( InNode inNode, @@ -46,9 +47,15 @@ public MessageRouter( this.location = location; } + public void setLocation(Location location) { + this.location = location; + } + public void setupChannel(WebSocketChannel channel) { - inNode.setSendOperationFactory((ch, envelope) -> new SendOperation(ch, envelope, awaiter)); - outNode.setSendOperationFactory(envelope -> new SendOperation(channel, envelope, awaiter)); + var endpoint = new WsEndpoint(channel); + + inNode.setSendOperationFactory((ep, envelope) -> new SendOperation(ep, envelope, awaiter)); + outNode.setSendOperationFactory(envelope -> new SendOperation(endpoint, envelope, awaiter)); inNode.setInboundTap(env -> { boolean matched = false; @@ -76,7 +83,7 @@ public void setupChannel(WebSocketChannel channel) { @Override protected void onFullTextMessage(WebSocketChannel ch, BufferedTextMessage message) { try { - inNode.onText(ch, message.getData()); + inNode.onText(endpoint, message.getData()); } catch (Throwable t) { Log.error(t, "Inbound message handling failed"); } diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/RedisClient.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/RedisClient.java new file mode 100644 index 00000000..dd8c93a1 --- /dev/null +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/RedisClient.java @@ -0,0 +1,303 @@ +package dev.objz.commandbridge.backends.net; + +import dev.objz.commandbridge.backends.platform.PlatformAdapter; +import dev.objz.commandbridge.config.model.BackendsConfig; +import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.InNode; +import dev.objz.commandbridge.net.OutNode; +import dev.objz.commandbridge.net.ResponseAwaiter; +import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.endpoints.RedisEndpoint; +import dev.objz.commandbridge.net.proto.Envelope; +import dev.objz.commandbridge.net.redis.RedisChannels; +import dev.objz.commandbridge.scripting.model.enums.Location; +import dev.objz.commandbridge.security.SecretLoader; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisPool; +import redis.clients.jedis.JedisPubSub; + +import java.nio.file.Path; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public final class RedisClient implements BackendClient { + private final BackendsConfig cfg; + private final Path dataDir; + private final PlatformAdapter adapter; + + private final ReconnectHandler reconnectHandler; + private final AuthHandler authHandler; + private final RedisMessageRouter messageRouter; + + private final InNode inNode = new InNode(); + private final OutNode outNode = new OutNode<>(); + private final ResponseAwaiter awaiter = new ResponseAwaiter(); + private final AtomicReference stateRef = new AtomicReference<>(ConnectionState.DISCONNECTED); + + private volatile JedisPool pool; + private volatile JedisPubSub subscriber; + private volatile Thread subscriberThread; + private volatile boolean running; + private volatile RedisEndpoint proxyEndpoint; + + private Location location = Location.BACKEND; + private String serverId; + + public RedisClient(BackendsConfig cfg, Path dataDir, PlatformAdapter adapter) { + this.cfg = Objects.requireNonNull(cfg); + this.dataDir = Objects.requireNonNull(dataDir); + this.adapter = Objects.requireNonNull(adapter); + + this.reconnectHandler = new ReconnectHandler(cfg, adapter, this::attemptReconnect); + this.authHandler = new AuthHandler(cfg, outNode, stateRef); + this.messageRouter = new RedisMessageRouter( + inNode, + outNode, + awaiter, + stateRef, + resolveSecret(), + location); + + outNode.setClientId(cfg.clientId()); + } + + @Override + public synchronized void start() throws Exception { + if (isConnected()) { + Log.debug("Already connected, skipping start"); + return; + } + + if (!reconnectHandler.isReconnecting()) { + reconnectHandler.stopReconnect(); + } + + try { + stateRef.set(ConnectionState.CONNECTING); + running = true; + pool = createPool(); + + proxyEndpoint = new RedisEndpoint("proxy", this::publishToProxy, this::isConnected); + messageRouter.setupEndpoint(proxyEndpoint); + startSubscriber(); + + stateRef.set(ConnectionState.CONNECTED); + authHandler.authenticate(); + + if (reconnectHandler.isReconnecting()) { + reconnectHandler.onReconnectSuccess(); + } + } catch (Exception e) { + stateRef.set(ConnectionState.DISCONNECTED); + running = false; + stopInternal(); + if (!reconnectHandler.isReconnecting()) { + Log.warn("Redis connection failed: {}", e.getMessage()); + } + throw e; + } + } + + @Override + public synchronized void reconnect() throws Exception { + Log.warn("Manual reconnection initiated"); + reconnectHandler.shutdown(); + messageRouter.clearTap(); + stopInternal(); + stateRef.set(ConnectionState.DISCONNECTED); + start(); + } + + @Override + public void scheduleReconnection() { + reconnectHandler.scheduleReconnect(); + } + + @Override + public SendOperation send(Envelope request) { + if (!isConnected() || proxyEndpoint == null) { + throw new IllegalStateException("Redis endpoint not connected"); + } + return new SendOperation(proxyEndpoint, request, awaiter); + } + + @Override + public synchronized void close() { + Log.debug("Closing RedisClient"); + reconnectHandler.shutdown(); + messageRouter.clearTap(); + stopInternal(); + stateRef.set(ConnectionState.DISCONNECTED); + Log.debug("RedisClient closed"); + } + + @Override + public ClientStatus status() { + return stateRef.get().toClientStatus(); + } + + @Override + public String serverId() { + return serverId; + } + + @Override + public InNode inboundRouter() { + return inNode; + } + + @Override + public OutNode outboundRouter() { + return outNode; + } + + @Override + public void setLocation(Location location) { + this.location = Objects.requireNonNull(location); + this.messageRouter.setLocation(this.location); + } + + @Override + public void setServerId(String serverId) { + this.serverId = serverId; + outNode.setServerId(serverId); + } + + private boolean isConnected() { + return running && pool != null; + } + + private void onConnectionLost() { + if (!running) { + return; + } + stateRef.set(ConnectionState.RECONNECTING); + stopInternal(); + reconnectHandler.scheduleReconnect(); + } + + private void attemptReconnect() { + try { + stopInternal(); + start(); + } catch (Exception e) { + throw new RuntimeException(e.getMessage(), e); + } + } + + private void startSubscriber() throws Exception { + CountDownLatch subscribed = new CountDownLatch(1); + subscriberThread = new Thread(() -> subscribeLoop(subscribed), "commandbridge-redis-client-sub"); + subscriberThread.setDaemon(true); + subscriberThread.start(); + + if (!subscribed.await(5, TimeUnit.SECONDS)) { + throw new IllegalStateException("Timed out waiting for Redis subscription"); + } + } + + private void subscribeLoop(CountDownLatch subscribed) { + try (Jedis jedis = pool.getResource()) { + JedisPubSub localSubscriber = new JedisPubSub() { + @Override + public void onSubscribe(String channel, int subscribedChannels) { + subscribed.countDown(); + } + + @Override + public void onMessage(String channel, String message) { + RedisEndpoint endpoint = proxyEndpoint; + if (endpoint != null) { + messageRouter.onText(endpoint, message); + } + } + }; + + subscriber = localSubscriber; + jedis.subscribe(localSubscriber, RedisChannels.clientInbound(cfg.clientId())); + } catch (Exception e) { + subscribed.countDown(); + if (running) { + Log.warn("Redis subscriber disconnected: {}", e.getMessage()); + onConnectionLost(); + } + } finally { + subscriber = null; + } + } + + private CompletableFuture publishToProxy(Envelope env) { + try { + String payload = Envelope.MAPPER.writeValueAsString(env); + try (Jedis jedis = pool.getResource()) { + jedis.publish(RedisChannels.PROXY_INBOUND, payload); + } + return CompletableFuture.completedFuture(null); + } catch (Exception e) { + if (running) { + Log.warn("Redis publish failed: {}", e.getMessage()); + onConnectionLost(); + } + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(e); + return failed; + } + } + + private JedisPool createPool() { + BackendsConfig.Endpoints.Redis redis = cfg.endpoints().redis(); + DefaultJedisClientConfig.Builder builder = DefaultJedisClientConfig.builder(); + if (redis.username() != null && !redis.username().isBlank()) { + builder.user(redis.username().trim()); + } + if (redis.password() != null && !redis.password().isBlank()) { + builder.password(redis.password()); + } + return new JedisPool(new HostAndPort(redis.host(), redis.port()), builder.build()); + } + + private void stopInternal() { + running = false; + + JedisPubSub sub = subscriber; + if (sub != null) { + try { + sub.unsubscribe(); + } catch (Exception ignore) { + } + } + + Thread t = subscriberThread; + subscriberThread = null; + if (t != null) { + t.interrupt(); + try { + t.join(1000L); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + JedisPool p = pool; + pool = null; + if (p != null) { + try { + p.close(); + } catch (Exception ignore) { + } + } + } + + private String resolveSecret() { + String s = cfg.security() != null ? cfg.security().secret() : null; + if (s != null && !s.isBlank()) { + return s; + } + return new SecretLoader(dataDir).loadOrCreate(); + } +} diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/RedisMessageRouter.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/RedisMessageRouter.java new file mode 100644 index 00000000..30bac4dd --- /dev/null +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/RedisMessageRouter.java @@ -0,0 +1,82 @@ +package dev.objz.commandbridge.backends.net; + +import dev.objz.commandbridge.backends.net.in.PingHandler; +import dev.objz.commandbridge.backends.net.out.AuthRequest; +import dev.objz.commandbridge.backends.net.out.InvokedCommandEvent; +import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; +import dev.objz.commandbridge.net.InNode; +import dev.objz.commandbridge.net.OutNode; +import dev.objz.commandbridge.net.ResponseAwaiter; +import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.proto.MessageType; +import dev.objz.commandbridge.scripting.model.enums.Location; +import dev.objz.commandbridge.security.AuthService; + +import java.util.concurrent.atomic.AtomicReference; + +public final class RedisMessageRouter { + private final InNode inNode; + private final OutNode outNode; + private final ResponseAwaiter awaiter; + private final AtomicReference stateRef; + private final String secret; + private volatile Location location; + + public RedisMessageRouter( + InNode inNode, + OutNode outNode, + ResponseAwaiter awaiter, + AtomicReference stateRef, + String secret, + Location location) { + this.inNode = inNode; + this.outNode = outNode; + this.awaiter = awaiter; + this.stateRef = stateRef; + this.secret = secret; + this.location = location; + } + + public void setLocation(Location location) { + this.location = location; + } + + public void setupEndpoint(Endpoint endpoint) { + inNode.setSendOperationFactory((ep, envelope) -> new SendOperation(ep, envelope, awaiter)); + outNode.setSendOperationFactory(envelope -> new SendOperation(endpoint, envelope, awaiter)); + + inNode.setInboundTap(env -> { + boolean matched = false; + try { + matched = awaiter.signal(env); + } catch (Exception ignore) { + } + + ConnectionState state = stateRef.get(); + if (state != ConnectionState.AUTHENTICATED) { + switch (env.type()) { + case AUTH_OK: + case AUTH_FAIL: + return matched; + default: + Log.warn("Dropping {} while unauthenticated", env.type()); + return true; + } + } + return matched; + }); + + outNode.register(MessageType.AUTH_REQUEST, new AuthRequest(new AuthService(secret), location)); + outNode.register(MessageType.INVOKED_COMMAND, new InvokedCommandEvent()); + inNode.register(MessageType.PING, new PingHandler()); + } + + public void onText(Endpoint endpoint, String text) { + inNode.onText(endpoint, text); + } + + public void clearTap() { + inNode.setInboundTap(null); + } +} diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/WsClient.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/WsClient.java index 637a8a98..15872ea5 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/WsClient.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/WsClient.java @@ -8,6 +8,7 @@ import dev.objz.commandbridge.net.OutNode; import dev.objz.commandbridge.net.ResponseAwaiter; import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.endpoints.WsEndpoint; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.scripting.model.enums.Location; import dev.objz.commandbridge.security.SecretLoader; @@ -17,7 +18,7 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; -public final class WsClient implements AutoCloseable { +public final class WsClient implements BackendClient { private final BackendsConfig cfg; private final Path dataDir; @@ -60,6 +61,7 @@ public WsClient(BackendsConfig cfg, Path dataDir, PlatformAdapter adapter) { outNode.setClientId(cfg.clientId()); } + @Override public synchronized void start() throws Exception { if (connectionHandler.isConnected()) { Log.debug("Already connected, skipping start"); @@ -87,7 +89,7 @@ public synchronized void start() throws Exception { messageRouter.setupChannel(channel); - authHandler.authenticate(channel); + authHandler.authenticate(); if (reconnectHandler.isReconnecting()) { reconnectHandler.onReconnectSuccess(); @@ -104,6 +106,7 @@ public synchronized void start() throws Exception { } } + @Override public synchronized void reconnect() throws Exception { Log.warn("Manual reconnection initiated"); reconnectHandler.shutdown(); @@ -114,21 +117,23 @@ public synchronized void reconnect() throws Exception { start(); } + @Override public void scheduleReconnection() { reconnectHandler.scheduleReconnect(); } + @Override public SendOperation send(Envelope request) { if (!connectionHandler.isChannelHealthy()) { - throw new IllegalStateException("WebSocket not connected or unhealthy"); + throw new IllegalStateException("Endpoint not connected or unhealthy"); } WebSocketChannel channel = connectionHandler.getChannel(); if (channel == null) { - throw new IllegalStateException("WebSocket channel is null"); + throw new IllegalStateException("Endpoint transport is null"); } - return new SendOperation(channel, request, awaiter); + return new SendOperation(new WsEndpoint(channel), request, awaiter); } @Override @@ -152,26 +157,33 @@ public synchronized void close() throws Exception { Log.debug("WsClient closed"); } + @Override public ClientStatus status() { return stateRef.get().toClientStatus(); } + @Override public String serverId() { return serverId; } + @Override public InNode inboundRouter() { return inNode; } + @Override public OutNode outboundRouter() { return outNode; } + @Override public void setLocation(Location location) { this.location = Objects.requireNonNull(location); + this.messageRouter.setLocation(this.location); } + @Override public void setServerId(String serverId) { this.serverId = serverId; outNode.setServerId(serverId); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/in/ExecuteCommandHandler.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/in/ExecuteCommandHandler.java index 9fb3d4aa..94b41fa7 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/in/ExecuteCommandHandler.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/in/ExecuteCommandHandler.java @@ -2,13 +2,13 @@ import dev.objz.commandbridge.backends.platform.cmd.CommandExecutor; import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.InboundHandler; import dev.objz.commandbridge.net.payloads.cmd.ExecuteCommand; import dev.objz.commandbridge.net.payloads.cmd.ExecuteCommandResult; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; import dev.objz.commandbridge.scripting.model.enums.RunAs; -import io.undertow.websockets.core.WebSocketChannel; import java.util.HashSet; import java.util.Objects; @@ -23,10 +23,10 @@ public ExecuteCommandHandler(CommandExecutor executor) { } @Override - public void accept(WebSocketChannel ch, Envelope env) { + public void accept(Endpoint endpoint, Envelope env) { if (env.payload() == null) { Log.warn("Received EXECUTE_COMMAND with null payload from '{}'", env.from()); - sendFailure(ch, env, null, null, "Null payload received"); + sendFailure(endpoint, env, null, null, "Null payload received"); return; } @@ -35,13 +35,13 @@ public void accept(WebSocketChannel ch, Envelope env) { exec = Envelope.MAPPER.treeToValue(env.payload(), ExecuteCommand.class); } catch (Exception e) { Log.error(e, "Failed to parse EXECUTE_COMMAND from '{}'", env.from()); - sendFailure(ch, env, null, null, "Failed to parse command: " + e.getMessage()); + sendFailure(endpoint, env, null, null, "Failed to parse command: " + e.getMessage()); return; } if (exec == null || exec.command() == null || exec.command().isBlank()) { Log.warn("Received empty EXECUTE_COMMAND from '{}'", env.from()); - sendFailure(ch, env, "", null, "Empty command received"); + sendFailure(endpoint, env, "", null, "Empty command received"); return; } @@ -61,24 +61,24 @@ public void accept(WebSocketChannel ch, Envelope env) { .thenAccept(result -> { if (result.isSuccess()) { Log.debug("Command '{}' executed successfully", exec.command()); - sendSuccess(ch, env, exec.command(), exec.uuid()); + sendSuccess(endpoint, env, exec.command(), exec.uuid()); } else { Log.warn("Command '{}' execution failed: {}", exec.command(), result.message()); - sendFailure(ch, env, exec.command(), exec.uuid(), result.message()); + sendFailure(endpoint, env, exec.command(), exec.uuid(), result.message()); } }) .exceptionally(ex -> { Log.error(ex, "Command '{}' execution threw exception", exec.command()); - sendFailure(ch, env, exec.command(), exec.uuid(), + sendFailure(endpoint, env, exec.command(), exec.uuid(), "Exception: " + ex.getMessage()); return null; }); } - private void sendSuccess(WebSocketChannel ch, Envelope env, String command, java.util.UUID playerUuid) { + private void sendSuccess(Endpoint endpoint, Envelope env, String command, java.util.UUID playerUuid) { ExecuteCommandResult result = ExecuteCommandResult.success(command, playerUuid); - reply(ch, env, MessageType.EXECUTE_COMMAND_RESULT, result) + reply(endpoint, env, MessageType.EXECUTE_COMMAND_RESULT, result) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send execution result: {}", ex.toString()); @@ -86,10 +86,10 @@ private void sendSuccess(WebSocketChannel ch, Envelope env, String command, java }); } - private void sendFailure(WebSocketChannel ch, Envelope env, String command, java.util.UUID playerUuid, + private void sendFailure(Endpoint endpoint, Envelope env, String command, java.util.UUID playerUuid, String message) { ExecuteCommandResult result = ExecuteCommandResult.failure(command, playerUuid, message); - reply(ch, env, MessageType.EXECUTE_COMMAND_RESULT, result) + reply(endpoint, env, MessageType.EXECUTE_COMMAND_RESULT, result) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send execution result: {}", ex.toString()); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/in/PingHandler.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/in/PingHandler.java index d52ed32e..59da8dc7 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/in/PingHandler.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/in/PingHandler.java @@ -1,23 +1,23 @@ package dev.objz.commandbridge.backends.net.in; import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.InboundHandler; import dev.objz.commandbridge.net.payloads.util.PingPayload; import dev.objz.commandbridge.net.payloads.util.PongPayload; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; -import io.undertow.websockets.core.WebSocketChannel; public final class PingHandler extends InboundHandler { @Override - public void accept(WebSocketChannel ch, Envelope env) { + public void accept(Endpoint endpoint, Envelope env) { try { PingPayload ping = Envelope.MAPPER.treeToValue(env.payload(), PingPayload.class); Log.debug("Received ping request from '{}' with timestamp {}", env.from(), ping.timestamp()); PongPayload pong = new PongPayload(ping.timestamp()); - reply(ch, env, MessageType.PONG, pong) + reply(endpoint, env, MessageType.PONG, pong) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send pong response: {}", ex.toString()); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/in/RegistrationHandler.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/in/RegistrationHandler.java index 272f6e9b..e02c86f8 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/in/RegistrationHandler.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/in/RegistrationHandler.java @@ -1,10 +1,11 @@ package dev.objz.commandbridge.backends.net.in; -import dev.objz.commandbridge.backends.net.WsClient; +import dev.objz.commandbridge.backends.net.BackendClient; import dev.objz.commandbridge.backends.platform.cmd.ArgumentMapper; import dev.objz.commandbridge.backends.platform.cmd.CommandRegistry; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.logging.Summary; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.InboundHandler; import dev.objz.commandbridge.net.payloads.cmd.CommandStub; import dev.objz.commandbridge.net.payloads.cmd.RegisterCommands; @@ -12,22 +13,21 @@ import dev.objz.commandbridge.net.payloads.feedback.FeedbackCollector; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; -import io.undertow.websockets.core.WebSocketChannel; import java.util.List; import java.util.Objects; public final class RegistrationHandler extends InboundHandler { - private final WsClient ws; + private final BackendClient client; private final CommandRegistry registry; - public RegistrationHandler(WsClient ws) { - this.ws = Objects.requireNonNull(ws); - this.registry = new CommandRegistry(new ArgumentMapper(), ws.outboundRouter()); + public RegistrationHandler(BackendClient client) { + this.client = Objects.requireNonNull(client); + this.registry = new CommandRegistry(new ArgumentMapper(), client.outboundRouter()); } @Override - public void accept(WebSocketChannel ch, Envelope env) { + public void accept(Endpoint endpoint, Envelope env) { RegisterCommands rc = null; try { rc = Envelope.MAPPER.treeToValue(env.payload(), RegisterCommands.class); @@ -46,7 +46,7 @@ public void accept(WebSocketChannel ch, Envelope env) { Feedback f = new Feedback(0, 0, 0, List.of("Received empty registration request"), List.of()); - reply(ch, env, MessageType.REGISTER_COMMANDS_RESULT, f).dispatch().exceptionally(ex -> { + reply(endpoint, env, MessageType.REGISTER_COMMANDS_RESULT, f).dispatch().exceptionally(ex -> { Log.warn("Failed to send FEEDBACK: {}", ex.toString()); return null; }); @@ -64,12 +64,12 @@ public void accept(WebSocketChannel ch, Envelope env) { + t.getMessage()); } } - ws.setServerId(env.from()); + client.setServerId(env.from()); Feedback f = fc.build(); Summary.feedbackSummary("Registration", f, env.from()); Summary.feedbackDetails(f, env.from(), true); - reply(ch, env, MessageType.REGISTER_COMMANDS_RESULT, f).dispatch().exceptionally(ex -> { + reply(endpoint, env, MessageType.REGISTER_COMMANDS_RESULT, f).dispatch().exceptionally(ex -> { Log.warn("Failed to send FEEDBACK: {}", ex.toString()); return null; }); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/net/out/ctx/AuthRequestContext.java b/backends/src/main/java/dev/objz/commandbridge/backends/net/out/ctx/AuthRequestContext.java index 65661051..a5b0f053 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/net/out/ctx/AuthRequestContext.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/net/out/ctx/AuthRequestContext.java @@ -1,7 +1,5 @@ package dev.objz.commandbridge.backends.net.out.ctx; -import io.undertow.websockets.core.WebSocketChannel; - import java.time.Duration; import java.util.Objects; import java.util.function.Consumer; @@ -9,12 +7,10 @@ import dev.objz.commandbridge.backends.net.ClientStatus; public final class AuthRequestContext { - public final WebSocketChannel ch; public final Duration timeout; public final Consumer statusSink; - public AuthRequestContext(WebSocketChannel ch, Duration timeout, Consumer statusSink) { - this.ch = Objects.requireNonNull(ch); + public AuthRequestContext(Duration timeout, Consumer statusSink) { this.timeout = Objects.requireNonNull(timeout); this.statusSink = Objects.requireNonNull(statusSink); } diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/platform/bootstrap/BukkitMain.java b/backends/src/main/java/dev/objz/commandbridge/backends/platform/bootstrap/BukkitMain.java index 215431d0..59f2459c 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/platform/bootstrap/BukkitMain.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/platform/bootstrap/BukkitMain.java @@ -4,6 +4,8 @@ import dev.objz.commandbridge.backends.platform.PlatformDetector; import dev.objz.commandbridge.backends.platform.PlatformDetector.Platform; import dev.objz.commandbridge.logging.Log; + +import org.bstats.bukkit.Metrics; import org.bukkit.plugin.java.JavaPlugin; public final class BukkitMain extends JavaPlugin { @@ -44,6 +46,8 @@ public void onLoad() { @Override public void onEnable() { + int pluginID = 22008; + new Metrics(this, pluginID); try { var env = new PlatformAdapter.PlatformEnv(getDataFolder().toPath()); adapter.start(env); diff --git a/backends/src/main/java/dev/objz/commandbridge/backends/platform/cmd/ClientCommands.java b/backends/src/main/java/dev/objz/commandbridge/backends/platform/cmd/ClientCommands.java index ae2073a2..be7614dc 100644 --- a/backends/src/main/java/dev/objz/commandbridge/backends/platform/cmd/ClientCommands.java +++ b/backends/src/main/java/dev/objz/commandbridge/backends/platform/cmd/ClientCommands.java @@ -2,7 +2,7 @@ import dev.jorel.commandapi.CommandAPICommand; import dev.jorel.commandapi.executors.CommandExecutor; -import dev.objz.commandbridge.backends.net.WsClient; +import dev.objz.commandbridge.backends.net.BackendClient; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.util.MM; import net.kyori.adventure.audience.Audience; @@ -14,7 +14,7 @@ public final class ClientCommands { private ClientCommands() { } - public static void register(WsClient client) { + public static void register(BackendClient client) { new CommandAPICommand("commandbridgeclient") .withAliases("cbc") .withPermission("commandbridge.admin") diff --git a/backends/velocity/src/main/java/dev/objz/commandbridge/velocity/Adapter.java b/backends/velocity/src/main/java/dev/objz/commandbridge/velocity/Adapter.java index 3879d8aa..2e899b4c 100644 --- a/backends/velocity/src/main/java/dev/objz/commandbridge/velocity/Adapter.java +++ b/backends/velocity/src/main/java/dev/objz/commandbridge/velocity/Adapter.java @@ -3,6 +3,8 @@ import com.velocitypowered.api.proxy.ProxyServer; import com.velocitypowered.api.scheduler.ScheduledTask; +import dev.objz.commandbridge.backends.net.BackendClient; +import dev.objz.commandbridge.backends.net.RedisClient; import dev.objz.commandbridge.backends.net.WsClient; import dev.objz.commandbridge.backends.net.in.ExecuteCommandHandler; import dev.objz.commandbridge.backends.net.in.RegistrationHandler; @@ -13,6 +15,7 @@ import dev.objz.commandbridge.backends.platform.cmd.CommandExecutor; import dev.objz.commandbridge.config.ConfigManager; import dev.objz.commandbridge.config.model.BackendsConfig; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.proto.MessageType; import dev.objz.commandbridge.scripting.model.enums.Location; @@ -21,7 +24,7 @@ import java.time.Duration; public final class Adapter implements PlatformAdapter { - private WsClient client; + private BackendClient client; private BackendsConfig cfg; private Path dataDir; private ProxyServer proxy; @@ -80,7 +83,9 @@ public void start(PlatformEnv env) throws Exception { this.commandExecutor = new VelocityExecutor(proxy, pluginInstance); } - this.client = new WsClient(cfg, dataDir, this); + this.client = cfg.endpointType() == EndpointType.REDIS + ? new RedisClient(cfg, dataDir, this) + : new WsClient(cfg, dataDir, this); client.setLocation(Location.VELOCITY); diff --git a/core/src/main/java/dev/objz/commandbridge/config/ConfigManager.java b/core/src/main/java/dev/objz/commandbridge/config/ConfigManager.java index f5d8a5ca..17f14526 100644 --- a/core/src/main/java/dev/objz/commandbridge/config/ConfigManager.java +++ b/core/src/main/java/dev/objz/commandbridge/config/ConfigManager.java @@ -7,14 +7,20 @@ import dev.objz.commandbridge.config.profile.VelocityConfigProfile; import dev.objz.commandbridge.logging.Log; import org.spongepowered.configurate.ConfigurationNode; +import org.spongepowered.configurate.serialize.SerializationException; +import org.spongepowered.configurate.objectmapping.meta.Setting; import org.spongepowered.configurate.yaml.NodeStyle; import org.spongepowered.configurate.yaml.YamlConfigurationLoader; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.RecordComponent; import java.nio.file.Files; import java.nio.file.Path; import java.util.Comparator; import java.util.HashSet; +import java.util.Locale; import java.util.Map; import java.util.Set; @@ -53,6 +59,7 @@ public boolean load(Class modelClass) { ConfigurationNode root = loader.load(); ConfigProfile profile = profileOf(modelClass); + T defaults = profile.defaults(); Set valid = new HashSet<>(ConfigKeys.topLevelKeysOf(modelClass)); for (var key : root.childrenMap().keySet()) { @@ -68,7 +75,14 @@ public boolean load(Class modelClass) { } } - T loaded = root.get(modelClass, profile.defaults()); + boolean enumOk = validateEnumValues(root, modelClass, ""); + if (!enumOk) { + this.current = defaults; + Log.error("Invalid config.yml"); + return false; + } + + T loaded = root.get(modelClass, defaults); var result = profile.normalize(loaded); this.current = result.config(); @@ -78,6 +92,10 @@ public boolean load(Class modelClass) { } return result.ok(); + } catch (SerializationException e) { + Log.error("Invalid config.yml: {}", e.getMessage()); + this.current = profileOf(modelClass).defaults(); + return false; } catch (Exception e) { Log.error("Error loading config.yml: " + e.getMessage(), e); this.current = profileOf(modelClass).defaults(); @@ -144,4 +162,102 @@ private static int levenshtein(String a, String b) { } return costs[b.length()]; } + + private static boolean validateEnumValues(ConfigurationNode node, Class recordType, String path) { + if (!recordType.isRecord() || node == null) { + return true; + } + + boolean ok = true; + for (RecordComponent rc : recordType.getRecordComponents()) { + String key = yamlKeyFor(recordType, rc); + Class componentType = rc.getType(); + ConfigurationNode child = node.node(key); + String fullPath = path == null || path.isBlank() ? key : path + "." + key; + + if (componentType.isEnum()) { + if (child.virtual()) { + continue; + } + + Object raw = child.raw(); + String rawText = raw == null ? null : String.valueOf(raw).trim(); + Class> enumType = (Class>) componentType; + + if (rawText == null || rawText.isBlank()) { + Log.error("'{}' must be set", fullPath); + ok = false; + continue; + } + + String canonical = canonicalEnumValue(enumType, rawText); + if (canonical != null) { + if (!canonical.equals(rawText)) { + try { + child.set(canonical); + } catch (SerializationException ignored) { + } + } + continue; + } + + Set values = enumConstants(enumType); + String suggestion = findClosest(rawText.toUpperCase(Locale.ROOT), values); + if (suggestion != null) { + Log.warn("Invalid value '{}' for '{}'. did you mean '{}' ?", rawText, fullPath, suggestion); + } else { + Log.error("Invalid value '{}' for '{}'. expected one of {}", rawText, fullPath, values); + } + ok = false; + } else if (componentType.isRecord()) { + ok &= validateEnumValues(child, componentType, fullPath); + } + } + + return ok; + } + + private static String yamlKeyFor(Class recordType, RecordComponent rc) { + try { + Method m = recordType.getMethod(rc.getName()); + Setting s = m.getAnnotation(Setting.class); + if (s != null && !s.value().isBlank()) { + return s.value(); + } + } catch (NoSuchMethodException ignored) { + } + + try { + Field f = recordType.getDeclaredField(rc.getName()); + Setting s = f.getAnnotation(Setting.class); + if (s != null && !s.value().isBlank()) { + return s.value(); + } + } catch (NoSuchFieldException ignored) { + } + + Setting s = rc.getAnnotation(Setting.class); + if (s != null && !s.value().isBlank()) { + return s.value(); + } + + return rc.getName().replaceAll("(? enumConstants(Class> enumType) { + Set values = new HashSet<>(); + for (Enum constant : enumType.getEnumConstants()) { + values.add(constant.name()); + } + return values; + } + + private static String canonicalEnumValue(Class> enumType, String raw) { + for (Enum constant : enumType.getEnumConstants()) { + if (constant.name().equalsIgnoreCase(raw)) { + return constant.name(); + } + } + return null; + } } diff --git a/core/src/main/java/dev/objz/commandbridge/config/model/BackendsConfig.java b/core/src/main/java/dev/objz/commandbridge/config/model/BackendsConfig.java index 94c4cbb5..d14cca2f 100644 --- a/core/src/main/java/dev/objz/commandbridge/config/model/BackendsConfig.java +++ b/core/src/main/java/dev/objz/commandbridge/config/model/BackendsConfig.java @@ -5,13 +5,33 @@ @ConfigSerializable public record BackendsConfig( - @Setting("host") String host, - @Setting("port") int port, @Setting("client-id") String clientId, + @Setting("endpoint-type") EndpointType endpointType, + @Setting("endpoints") Endpoints endpoints, @Setting("security") Security security, @Setting("timeouts") Timeouts timeouts, @Setting("limits") Limits limits, @Setting("debug") boolean debug) { + @ConfigSerializable + public static record Endpoints( + @Setting("websocket") WebSocket websocket, + @Setting("redis") Redis redis) { + + @ConfigSerializable + public static record WebSocket( + @Setting("host") String host, + @Setting("port") int port) { + } + + @ConfigSerializable + public static record Redis( + @Setting("host") String host, + @Setting("port") int port, + @Setting("username") String username, + @Setting("password") String password) { + } + } + @ConfigSerializable public static record Security( @Setting("tls-mode") TlsMode tlsMode, @@ -35,9 +55,11 @@ public static record Limits( public static BackendsConfig defaults() { return new BackendsConfig( - "127.0.0.1", - 8765, "survival-1", + EndpointType.WEBSOCKET, + new Endpoints( + new Endpoints.WebSocket("127.0.0.1", 8765), + new Endpoints.Redis("127.0.0.1", 6379, "", "")), new Security(TlsMode.TOFU, "", "change-me", true), new Timeouts(5, 60, 5), new Limits(60), diff --git a/core/src/main/java/dev/objz/commandbridge/config/model/EndpointType.java b/core/src/main/java/dev/objz/commandbridge/config/model/EndpointType.java new file mode 100644 index 00000000..5cba45c5 --- /dev/null +++ b/core/src/main/java/dev/objz/commandbridge/config/model/EndpointType.java @@ -0,0 +1,6 @@ +package dev.objz.commandbridge.config.model; + +public enum EndpointType { + WEBSOCKET, + REDIS +} diff --git a/core/src/main/java/dev/objz/commandbridge/config/model/VelocityConfig.java b/core/src/main/java/dev/objz/commandbridge/config/model/VelocityConfig.java index 23cb215c..7e00528d 100644 --- a/core/src/main/java/dev/objz/commandbridge/config/model/VelocityConfig.java +++ b/core/src/main/java/dev/objz/commandbridge/config/model/VelocityConfig.java @@ -6,14 +6,37 @@ @ConfigSerializable public record VelocityConfig( @Setting("act-as-client") boolean actAsClient, - @Setting("bind-host") String bindHost, - @Setting("bind-port") int bindPort, @Setting("server-id") String serverId, + @Setting("endpoint-type") EndpointType endpointType, + @Setting("endpoints") Endpoints endpoints, @Setting("heartbeat") Heartbeat heartbeat, @Setting("security") Security security, @Setting("timeouts") Timeouts timeouts, @Setting("limits") Limits limits, @Setting("debug") boolean debug) { + + @ConfigSerializable + public static record Endpoints( + @Setting("websocket") WebSocket webSocket, + @Setting("redis") Redis redis + + ) { + @ConfigSerializable + public static record WebSocket( + @Setting("bind-host") String bindHost, + @Setting("bind-port") int bindPort) { + } + + @ConfigSerializable + public static record Redis( + @Setting("host") String host, + @Setting("port") int port, + @Setting("username") String username, + @Setting("password") String password) { + } + + } + @ConfigSerializable public static record Heartbeat( @Setting("app-ping-seconds") int appPingSeconds, @@ -47,9 +70,11 @@ public static record Limits( public static VelocityConfig defaults() { return new VelocityConfig( false, - "0.0.0.0", - 8765, "proxy-1", + EndpointType.WEBSOCKET, + new Endpoints( + new Endpoints.WebSocket("0.0.0.0", 8765), + new Endpoints.Redis("127.0.0.1", 6379, "", "")), new Heartbeat(10, 60), new Security(true, 10, TlsMode.TOFU, "", "", "PKCS12"), new Timeouts(5, 5), diff --git a/core/src/main/java/dev/objz/commandbridge/config/profile/BackendsConfigProfile.java b/core/src/main/java/dev/objz/commandbridge/config/profile/BackendsConfigProfile.java index db8d91eb..ffb53cca 100644 --- a/core/src/main/java/dev/objz/commandbridge/config/profile/BackendsConfigProfile.java +++ b/core/src/main/java/dev/objz/commandbridge/config/profile/BackendsConfigProfile.java @@ -1,6 +1,7 @@ package dev.objz.commandbridge.config.profile; import dev.objz.commandbridge.config.model.BackendsConfig; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.config.model.TlsMode; import dev.objz.commandbridge.logging.Log; @@ -15,39 +16,87 @@ public Result normalize(BackendsConfig in) { BackendsConfig d = defaults(); boolean ok = true; - String host = in.host(); - if (host == null || host.isBlank()) { - Log.error("'host' must not be empty"); - host = d.host(); + String clientId = in.clientId(); + if (clientId == null || clientId.isBlank()) { + Log.warn("'client-id' missing or blank"); + clientId = d.clientId(); + ok = false; + } else { + clientId = clientId.trim(); + } + + EndpointType endpointType = in.endpointType(); + if (endpointType == null) { + Log.error("'endpoint-type' must be set"); + endpointType = d.endpointType(); + ok = false; + } + boolean websocketMode = endpointType == EndpointType.WEBSOCKET; + boolean redisMode = endpointType == EndpointType.REDIS; + + BackendsConfig.Endpoints endpointsIn = in.endpoints() != null ? in.endpoints() : d.endpoints(); + + BackendsConfig.Endpoints.WebSocket wsIn = endpointsIn.websocket() != null + ? endpointsIn.websocket() + : d.endpoints().websocket(); + + String wsHost = wsIn.host(); + if (wsHost == null || wsHost.isBlank()) { + Log.error("'endpoints.websocket.host' must not be empty"); + wsHost = d.endpoints().websocket().host(); ok = false; } else { - host = host.trim(); - if (host.startsWith("ws://") || host.startsWith("wss://")) { - Log.warn("'host' must NOT include ws:// or wss:// "); - host = host.replaceFirst("^wss?://", ""); // normalize anyway + wsHost = wsHost.trim(); + if (wsHost.startsWith("ws://") || wsHost.startsWith("wss://")) { + if (websocketMode) { + Log.warn("'endpoints.websocket.host' must NOT include ws:// or wss://"); + } + wsHost = wsHost.replaceFirst("^wss?://", ""); } } - // port - int port = in.port(); - if (port <= 0 || port > 65535) { - Log.error("'port' must be between 1 and 65535"); - port = d.port(); + int wsPort = wsIn.port(); + if (wsPort <= 0 || wsPort > 65535) { + Log.error("'endpoints.websocket.port' must be between 1 and 65535"); + wsPort = d.endpoints().websocket().port(); ok = false; } - // client-id - String clientId = in.clientId(); - if (clientId == null || clientId.isBlank()) { - Log.warn("'client-id' missing or blank"); - clientId = d.clientId(); + BackendsConfig.Endpoints.Redis redisIn = endpointsIn.redis() != null + ? endpointsIn.redis() + : d.endpoints().redis(); + + String redisHost = redisIn.host(); + if (redisHost == null || redisHost.isBlank()) { + Log.error("'endpoints.redis.host' must not be empty"); + redisHost = d.endpoints().redis().host(); + ok = false; + } else { + redisHost = redisHost.trim(); + if (redisHost.startsWith("redis://") || redisHost.startsWith("rediss://")) { + if (redisMode) { + Log.warn("'endpoints.redis.host' must NOT include redis:// or rediss://"); + } + redisHost = redisHost.replaceFirst("^rediss?://", ""); + } + } + + int redisPort = redisIn.port(); + if (redisPort <= 0 || redisPort > 65535) { + Log.error("'endpoints.redis.port' must be between 1 and 65535"); + redisPort = d.endpoints().redis().port(); ok = false; } - // limits - BackendsConfig.Limits limitsIn = in.limits(); - int inboundMessagesSec = (limitsIn != null ? limitsIn.inboundMessagesSec() - : d.limits().inboundMessagesSec()); + String redisUsername = redisIn.username() == null ? "" : redisIn.username().trim(); + String redisPassword = redisIn.password() == null ? "" : redisIn.password(); + + BackendsConfig.Endpoints endpointsOut = new BackendsConfig.Endpoints( + new BackendsConfig.Endpoints.WebSocket(wsHost, wsPort), + new BackendsConfig.Endpoints.Redis(redisHost, redisPort, redisUsername, redisPassword)); + + BackendsConfig.Limits limitsIn = in.limits() != null ? in.limits() : d.limits(); + int inboundMessagesSec = limitsIn.inboundMessagesSec(); if (inboundMessagesSec <= 0) { Log.error("'limits.inbound-messages-per-sec' must be positive"); inboundMessagesSec = d.limits().inboundMessagesSec(); @@ -55,11 +104,10 @@ public Result normalize(BackendsConfig in) { } BackendsConfig.Limits limitsOut = new BackendsConfig.Limits(inboundMessagesSec); - // timeouts - BackendsConfig.Timeouts timeoutsIn = in.timeouts(); - int authTimeout = (timeoutsIn != null ? timeoutsIn.authTimeout() : d.timeouts().authTimeout()); - int reconnectTimeout = (timeoutsIn != null ? timeoutsIn.reconnectTimeout() : d.timeouts().reconnectTimeout()); - int reconnectInterval = (timeoutsIn != null ? timeoutsIn.reconnectInterval() : d.timeouts().reconnectInterval()); + BackendsConfig.Timeouts timeoutsIn = in.timeouts() != null ? in.timeouts() : d.timeouts(); + int authTimeout = timeoutsIn.authTimeout(); + int reconnectTimeout = timeoutsIn.reconnectTimeout(); + int reconnectInterval = timeoutsIn.reconnectInterval(); if (authTimeout <= 0) { Log.error("'timeouts.auth-timeout' must be > 0"); authTimeout = d.timeouts().authTimeout(); @@ -75,46 +123,54 @@ public Result normalize(BackendsConfig in) { reconnectInterval = d.timeouts().reconnectInterval(); ok = false; } - BackendsConfig.Timeouts timeoutsOut = new BackendsConfig.Timeouts(authTimeout, reconnectTimeout, reconnectInterval); + BackendsConfig.Timeouts timeoutsOut = new BackendsConfig.Timeouts(authTimeout, reconnectTimeout, + reconnectInterval); - // security - BackendsConfig.Security secIn = in.security(); - TlsMode tlsMode = (secIn.tlsMode() != null) ? secIn.tlsMode() : d.security().tlsMode(); + BackendsConfig.Security secIn = in.security() != null ? in.security() : d.security(); + TlsMode tlsMode = secIn.tlsMode() != null ? secIn.tlsMode() : d.security().tlsMode(); if (secIn.tlsMode() == null) { Log.error("'security.tls-mode' must be set"); ok = false; } + String tlsPin = (secIn.tlsPin() != null && !secIn.tlsPin().isBlank()) ? secIn.tlsPin().trim() : d.security().tlsPin(); + String secret = secIn.secret(); if (secret == null || secret.isBlank()) { Log.error("'security.secret' must not be empty"); secret = d.security().secret(); ok = false; - } else if (secret.toLowerCase().contains("change-me")) { + } else if (websocketMode && secret.toLowerCase().contains("change-me")) { Log.warn("'security.secret' contains 'change-me'. replace with real key"); } + Boolean requireAuth = secIn.requireAuth(); if (requireAuth == null) { - Log.warn("'security.require-auth' must be set"); + if (websocketMode) { + Log.warn("'security.require-auth' must be set"); + } requireAuth = d.security().requireAuth(); - } else if (Boolean.FALSE.equals(requireAuth)) { + } else if (websocketMode && Boolean.FALSE.equals(requireAuth)) { Log.warn("Authentication is disabled! This is insecure and should not be used"); } - if (tlsMode == TlsMode.PLAIN && Boolean.TRUE.equals(requireAuth)) { + + if (endpointType == EndpointType.WEBSOCKET && tlsMode == TlsMode.PLAIN && Boolean.TRUE.equals(requireAuth)) { Log.warn("'tls-mode=PLAIN' with 'require-auth=true' is unusual. consider TLS"); } + BackendsConfig.Security secOut = new BackendsConfig.Security(tlsMode, tlsPin, secret, requireAuth); BackendsConfig out = new BackendsConfig( - host, port, clientId, + clientId, + endpointType, + endpointsOut, secOut, timeoutsOut, limitsOut, in.debug()); return new Result<>(out, ok); - } } diff --git a/core/src/main/java/dev/objz/commandbridge/config/profile/VelocityConfigProfile.java b/core/src/main/java/dev/objz/commandbridge/config/profile/VelocityConfigProfile.java index cacb41d1..d69251a7 100644 --- a/core/src/main/java/dev/objz/commandbridge/config/profile/VelocityConfigProfile.java +++ b/core/src/main/java/dev/objz/commandbridge/config/profile/VelocityConfigProfile.java @@ -1,5 +1,6 @@ package dev.objz.commandbridge.config.profile; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.config.model.TlsMode; import dev.objz.commandbridge.config.model.VelocityConfig; import dev.objz.commandbridge.logging.Log; @@ -15,117 +16,165 @@ public Result normalize(VelocityConfig in) { VelocityConfig d = defaults(); boolean ok = true; - // bind-host - String bindHost = in.bindHost(); + String serverId = in.serverId(); + if (serverId == null || serverId.isBlank()) { + serverId = d.serverId(); + Log.warn("'server-id' missing or blank. defaulting to '{}'", serverId); + } else { + serverId = serverId.trim(); + } + + EndpointType endpointType = in.endpointType(); + if (endpointType == null) { + Log.error("'endpoint-type' must be set"); + endpointType = d.endpointType(); + ok = false; + } + boolean websocketMode = endpointType == EndpointType.WEBSOCKET; + boolean redisMode = endpointType == EndpointType.REDIS; + + VelocityConfig.Endpoints endpointsIn = in.endpoints() != null ? in.endpoints() : d.endpoints(); + + VelocityConfig.Endpoints.WebSocket wsIn = endpointsIn.webSocket() != null + ? endpointsIn.webSocket() + : d.endpoints().webSocket(); + + String bindHost = wsIn.bindHost(); if (bindHost == null || bindHost.isBlank()) { - Log.error("'bind-host' must not be empty"); - bindHost = d.bindHost(); + Log.error("'endpoints.websocket.bind-host' must not be empty"); + bindHost = d.endpoints().webSocket().bindHost(); ok = false; } else { bindHost = bindHost.trim(); if (bindHost.startsWith("ws://") || bindHost.startsWith("wss://")) { - Log.warn("'bind-host' must NOT include ws:// or wss://"); + if (websocketMode) { + Log.warn("'endpoints.websocket.bind-host' must NOT include ws:// or wss://"); + } bindHost = bindHost.replaceFirst("^wss?://", ""); } } - // bind-port - int bindPort = in.bindPort(); + int bindPort = wsIn.bindPort(); if (bindPort <= 0 || bindPort > 65535) { - Log.error("'bind-port' must be between 1 and 65535"); - bindPort = d.bindPort(); + Log.error("'endpoints.websocket.bind-port' must be between 1 and 65535"); + bindPort = d.endpoints().webSocket().bindPort(); ok = false; } - // server-id - String serverId = in.serverId(); - if (serverId == null || serverId.isBlank()) { - serverId = d.serverId(); - Log.warn("'server-id' missing or blank. defaulting to '{}'", serverId); + VelocityConfig.Endpoints.Redis redisIn = endpointsIn.redis() != null + ? endpointsIn.redis() + : d.endpoints().redis(); + + String redisHost = redisIn.host(); + if (redisHost == null || redisHost.isBlank()) { + Log.error("'endpoints.redis.host' must not be empty"); + redisHost = d.endpoints().redis().host(); + ok = false; } else { - serverId = serverId.trim(); + redisHost = redisHost.trim(); + if (redisHost.startsWith("redis://") || redisHost.startsWith("rediss://")) { + if (redisMode) { + Log.warn("'endpoints.redis.host' must NOT include redis:// or rediss://"); + } + redisHost = redisHost.replaceFirst("^rediss?://", ""); + } } - // heartbeat - VelocityConfig.Heartbeat hb = in.heartbeat(); - int appPing = hb.appPingSeconds(); - int stale = hb.staleAfterSeconds(); + int redisPort = redisIn.port(); + if (redisPort <= 0 || redisPort > 65535) { + Log.error("'endpoints.redis.port' must be between 1 and 65535"); + redisPort = d.endpoints().redis().port(); + ok = false; + } + + String redisUsername = redisIn.username() == null ? "" : redisIn.username().trim(); + String redisPassword = redisIn.password() == null ? "" : redisIn.password(); + + VelocityConfig.Endpoints endpointsOut = new VelocityConfig.Endpoints( + new VelocityConfig.Endpoints.WebSocket(bindHost, bindPort), + new VelocityConfig.Endpoints.Redis(redisHost, redisPort, redisUsername, redisPassword)); + + VelocityConfig.Heartbeat hbIn = in.heartbeat() != null ? in.heartbeat() : d.heartbeat(); + int appPing = hbIn.appPingSeconds(); + int stale = hbIn.staleAfterSeconds(); if (appPing <= 0) { - Log.error("'app-ping-seconds' must be > 0"); + Log.error("'heartbeat.app-ping-seconds' must be > 0"); appPing = d.heartbeat().appPingSeconds(); ok = false; } if (stale < appPing) { - Log.error("'stale-after-seconds' must be >= 'app-ping-seconds'"); + Log.error("'heartbeat.stale-after-seconds' must be >= 'heartbeat.app-ping-seconds'"); stale = Math.max(appPing, d.heartbeat().staleAfterSeconds()); ok = false; } - hb = new VelocityConfig.Heartbeat(appPing, stale); + VelocityConfig.Heartbeat hbOut = new VelocityConfig.Heartbeat(appPing, stale); - VelocityConfig.Timeouts to = in.timeouts(); - int registerTimeout = to.registerTimeout(); + VelocityConfig.Timeouts toIn = in.timeouts() != null ? in.timeouts() : d.timeouts(); + int registerTimeout = toIn.registerTimeout(); if (registerTimeout <= 0) { Log.error("'timeouts.register-timeout' must be > 0"); registerTimeout = d.timeouts().registerTimeout(); ok = false; } - int pingTimeout = to.pingTimeout(); + int pingTimeout = toIn.pingTimeout(); if (pingTimeout <= 0) { Log.error("'timeouts.ping-timeout' must be > 0"); pingTimeout = d.timeouts().pingTimeout(); ok = false; } - to = new VelocityConfig.Timeouts(registerTimeout, pingTimeout); + VelocityConfig.Timeouts toOut = new VelocityConfig.Timeouts(registerTimeout, pingTimeout); - // limits - VelocityConfig.Limits limits = in.limits(); - int maxCon = limits.maxConnections(); - int maxMsg = limits.maxMessageSizeBytes(); - int inboundPerSec = limits.inboundMessagesSec(); + VelocityConfig.Limits limitsIn = in.limits() != null ? in.limits() : d.limits(); + int maxCon = limitsIn.maxConnections(); + int maxMsg = limitsIn.maxMessageSizeBytes(); + int inboundPerSec = limitsIn.inboundMessagesSec(); if (maxCon <= 0) { - Log.error("'max-connections' must be positive"); + Log.error("'limits.max-connections' must be positive"); maxCon = d.limits().maxConnections(); ok = false; } if (maxMsg < 1024) { - Log.error("'max-message-size-bytes' too small"); + Log.error("'limits.max-message-size-bytes' too small"); maxMsg = Math.max(1024, d.limits().maxMessageSizeBytes()); ok = false; } if (inboundPerSec <= 0) { - Log.error("'inbound-messages-per-sec' must be positive"); + Log.error("'limits.inbound-messages-per-sec' must be positive"); inboundPerSec = d.limits().inboundMessagesSec(); ok = false; } - limits = new VelocityConfig.Limits(inboundPerSec, maxCon, maxMsg); - - // security - VelocityConfig.Security secIn = in.security(); + VelocityConfig.Limits limitsOut = new VelocityConfig.Limits(inboundPerSec, maxCon, maxMsg); - TlsMode tlsMode = (secIn.tlsMode() != null) ? secIn.tlsMode() : d.security().tlsMode(); + VelocityConfig.Security secIn = in.security() != null ? in.security() : d.security(); + TlsMode tlsMode = secIn.tlsMode() != null ? secIn.tlsMode() : d.security().tlsMode(); if (secIn.tlsMode() == null) { Log.error("'security.tls-mode' must be set"); ok = false; } - boolean requireAuth = secIn.requireAuth(); // primitive + boolean requireAuth = secIn.requireAuth(); int authTimeout = secIn.authTimeoutSeconds(); + if (authTimeout <= 0) { + Log.error("'security.auth-timeout-seconds' must be > 0"); + authTimeout = d.security().authTimeoutSeconds(); + ok = false; + } String keystorePath = emptyToNull(secIn.keystorePath()); - String keystorePassword = secIn.keystorePassword(); // may be null + String keystorePassword = secIn.keystorePassword(); String keystoreType = (secIn.keystoreType() == null || secIn.keystoreType().isBlank()) ? d.security().keystoreType() : secIn.keystoreType().trim(); - if (!requireAuth) { + if (websocketMode && !requireAuth) { Log.warn("Authentication is disabled! This is insecure and should not be used"); } - if (tlsMode == TlsMode.PLAIN && requireAuth) { + + if (websocketMode && tlsMode == TlsMode.PLAIN && requireAuth) { Log.warn("'tls-mode=PLAIN' with 'require-auth=true' is unusual. consider TLS"); } - // STRICT requires keystore pieces - if (tlsMode == TlsMode.STRICT) { + if (endpointType == EndpointType.WEBSOCKET && tlsMode == TlsMode.STRICT) { if (keystorePath == null || keystorePath.isBlank()) { Log.error("'security.keystore-path' is required in STRICT mode"); ok = false; @@ -146,8 +195,15 @@ public Result normalize(VelocityConfig in) { VelocityConfig.Security secOut = new VelocityConfig.Security( requireAuth, authTimeout, tlsMode, keystorePath, keystorePassword, keystoreType); - VelocityConfig out = new VelocityConfig(in.actAsClient(), bindHost, bindPort, serverId, hb, secOut, to, - limits, + VelocityConfig out = new VelocityConfig( + in.actAsClient(), + serverId, + endpointType, + endpointsOut, + hbOut, + secOut, + toOut, + limitsOut, in.debug()); return new Result<>(out, ok); } diff --git a/core/src/main/java/dev/objz/commandbridge/net/Endpoint.java b/core/src/main/java/dev/objz/commandbridge/net/Endpoint.java new file mode 100644 index 00000000..55d29eef --- /dev/null +++ b/core/src/main/java/dev/objz/commandbridge/net/Endpoint.java @@ -0,0 +1,15 @@ +package dev.objz.commandbridge.net; + +import java.util.concurrent.CompletableFuture; + +import dev.objz.commandbridge.net.proto.Envelope; + +public interface Endpoint { + CompletableFuture send(Envelope env); + + boolean isOpen(); + + default String describe() { + return "unknown"; + } +} diff --git a/core/src/main/java/dev/objz/commandbridge/net/InNode.java b/core/src/main/java/dev/objz/commandbridge/net/InNode.java index bb9674ce..a553a3e4 100644 --- a/core/src/main/java/dev/objz/commandbridge/net/InNode.java +++ b/core/src/main/java/dev/objz/commandbridge/net/InNode.java @@ -3,7 +3,6 @@ import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; -import io.undertow.websockets.core.WebSocketChannel; import java.util.EnumMap; import java.util.Map; @@ -15,7 +14,7 @@ public class InNode { private final Map handlers; private Predicate inboundTap; - private BiFunction sendOperationFactory; + private BiFunction sendOperationFactory; public InNode() { this.handlers = new EnumMap<>(MessageType.class); @@ -25,7 +24,7 @@ public void setInboundTap(Predicate tap) { this.inboundTap = tap; } - public InNode setSendOperationFactory(BiFunction factory) { + public InNode setSendOperationFactory(BiFunction factory) { this.sendOperationFactory = factory; return this; } @@ -38,12 +37,13 @@ public InNode register(MessageType type, InboundHandler handler) { return this; } - public void onText(WebSocketChannel ch, String text) { + public void onText(Endpoint endpoint, String text) { final Envelope env; try { env = Envelope.MAPPER.readValue(text, Envelope.class); } catch (Exception e) { - Log.warn("Bad JSON from {}: {}", ch.getSourceAddress(), e.getMessage()); + String source = endpoint != null ? endpoint.describe() : "unknown"; + Log.warn("Bad JSON from {}: {}", source, e.getMessage()); return; } @@ -62,7 +62,7 @@ public void onText(WebSocketChannel ch, String text) { } try { - handler.accept(ch, env); + handler.accept(endpoint, env); } catch (Exception ex) { Log.error(ex, "Handler failure for type {}", env.type()); } diff --git a/core/src/main/java/dev/objz/commandbridge/net/InboundHandler.java b/core/src/main/java/dev/objz/commandbridge/net/InboundHandler.java index cb9c8dc7..89263e45 100644 --- a/core/src/main/java/dev/objz/commandbridge/net/InboundHandler.java +++ b/core/src/main/java/dev/objz/commandbridge/net/InboundHandler.java @@ -2,41 +2,40 @@ import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; -import io.undertow.websockets.core.WebSocketChannel; import java.util.function.BiFunction; public abstract class InboundHandler { - private BiFunction sendOperationFactory; + private BiFunction sendOperationFactory; - public void setSendOperationFactory(BiFunction factory) { + public void setSendOperationFactory(BiFunction factory) { this.sendOperationFactory = factory; } /** * Handles an inbound message * - * @param ch The WebSocket channel + * @param endpoint The transport endpoint * @param env The received envelope */ - public abstract void accept(WebSocketChannel ch, Envelope env); + public abstract void accept(Endpoint endpoint, Envelope env); /** * Sends a reply to the original request * - * @param ch The WebSocket channel + * @param endpoint The transport endpoint * @param request The original request envelope * @param replyType The message type for the reply * @param payload The payload object * @return SendOperation */ - protected SendOperation reply(WebSocketChannel ch, Envelope request, MessageType replyType, Object payload) { + protected SendOperation reply(Endpoint endpoint, Envelope request, MessageType replyType, Object payload) { if (sendOperationFactory == null) { - throw new IllegalStateException("SendOperation factory not configured"); + throw new IllegalStateException("Inbound send factory not configured"); } var payloadNode = Envelope.MAPPER.valueToTree(payload); Envelope replyEnv = Envelope.reply(request, replyType, request.to(), payloadNode); - return sendOperationFactory.apply(ch, replyEnv); + return sendOperationFactory.apply(endpoint, replyEnv); } } diff --git a/core/src/main/java/dev/objz/commandbridge/net/OutNode.java b/core/src/main/java/dev/objz/commandbridge/net/OutNode.java index 949ed962..01bf6d49 100644 --- a/core/src/main/java/dev/objz/commandbridge/net/OutNode.java +++ b/core/src/main/java/dev/objz/commandbridge/net/OutNode.java @@ -3,8 +3,6 @@ import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.proto.MessageType; -import io.undertow.websockets.core.WebSocketChannel; - import java.util.EnumMap; import java.util.Map; import java.util.Objects; @@ -20,7 +18,7 @@ public class OutNode { private final Map> handlers; private Function sendOperationFactory; - private BiFunction channelSendOperationFactory; + private BiFunction endpointSendFactory; private String clientId; private String serverId; @@ -33,9 +31,9 @@ public OutNode setSendOperationFactory(Function fact return this; } - public OutNode setChannelSendOperationFactory( - BiFunction factory) { - this.channelSendOperationFactory = factory; + public OutNode setEndpointSendFactory( + BiFunction factory) { + this.endpointSendFactory = factory; return this; } @@ -53,7 +51,7 @@ public OutNode register(MessageType type, OutboundHandler ha Objects.requireNonNull(type); Objects.requireNonNull(handler); handler.setSendOperationFactory(sendOperationFactory); - handler.setChannelSendOperationFactory(channelSendOperationFactory); + handler.setEndpointSendFactory(endpointSendFactory); handler.setClientId(clientId); handler.setServerId(serverId); @SuppressWarnings("unchecked") @@ -78,9 +76,9 @@ public SendOperation send(MessageType type, T context) { throw new IllegalStateException("No OutboundHandler registered for " + type); } - if (sendOperationFactory == null && channelSendOperationFactory == null) { + if (sendOperationFactory == null && endpointSendFactory == null) { throw new IllegalStateException( - "SendOperation factory not configured. Call setSendOperationFactory() first."); + "Send factory not configured. Call setSendOperationFactory() or setEndpointSendFactory() first."); } try { diff --git a/core/src/main/java/dev/objz/commandbridge/net/OutboundHandler.java b/core/src/main/java/dev/objz/commandbridge/net/OutboundHandler.java index 6fa90f3e..74144243 100644 --- a/core/src/main/java/dev/objz/commandbridge/net/OutboundHandler.java +++ b/core/src/main/java/dev/objz/commandbridge/net/OutboundHandler.java @@ -1,7 +1,6 @@ package dev.objz.commandbridge.net; import dev.objz.commandbridge.net.proto.Envelope; -import io.undertow.websockets.core.WebSocketChannel; import java.util.function.BiFunction; import java.util.function.Function; @@ -15,14 +14,14 @@ public abstract class OutboundHandler { protected String clientId; protected String serverId; private Function sendOperationFactory; - private BiFunction channelSendOperationFactory; + private BiFunction endpointSendFactory; public void setSendOperationFactory(Function factory) { this.sendOperationFactory = factory; } - public void setChannelSendOperationFactory(BiFunction factory) { - this.channelSendOperationFactory = factory; + public void setEndpointSendFactory(BiFunction factory) { + this.endpointSendFactory = factory; } public void setClientId(String clientId) { @@ -45,20 +44,20 @@ public void setServerId(String serverId) { */ protected SendOperation send(Envelope envelope) { if (sendOperationFactory == null) { - throw new IllegalStateException("SendOperation factory not configured"); + throw new IllegalStateException("Default send factory not configured"); } return sendOperationFactory.apply(envelope); } /** - * @param channel The WebSocket channel to send to + * @param endpoint The endpoint to send to * @param envelope The envelope to send * @return SendOperation */ - protected SendOperation send(WebSocketChannel channel, Envelope envelope) { - if (channelSendOperationFactory == null) { - throw new IllegalStateException("Channel SendOperation factory not configured"); + protected SendOperation send(Endpoint endpoint, Envelope envelope) { + if (endpointSendFactory == null) { + throw new IllegalStateException("Endpoint send factory not configured"); } - return channelSendOperationFactory.apply(channel, envelope); + return endpointSendFactory.apply(endpoint, envelope); } } diff --git a/core/src/main/java/dev/objz/commandbridge/net/SendOperation.java b/core/src/main/java/dev/objz/commandbridge/net/SendOperation.java index ba94ac4d..633265e9 100644 --- a/core/src/main/java/dev/objz/commandbridge/net/SendOperation.java +++ b/core/src/main/java/dev/objz/commandbridge/net/SendOperation.java @@ -2,8 +2,6 @@ import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; -import io.undertow.websockets.core.WebSocketChannel; -import io.undertow.websockets.core.WebSockets; import java.time.Duration; import java.util.Objects; @@ -12,7 +10,7 @@ public final class SendOperation { - private final WebSocketChannel ch; + private final Endpoint endpoint; private final Envelope request; private final ResponseAwaiter awaiter; @@ -20,10 +18,10 @@ public final class SendOperation { private Predicate matcher; private Duration timeout = Duration.ofSeconds(15); - public SendOperation(WebSocketChannel ch, + public SendOperation(Endpoint endpoint, Envelope request, ResponseAwaiter awaiter) { - this.ch = Objects.requireNonNull(ch); + this.endpoint = Objects.requireNonNull(endpoint); this.request = Objects.requireNonNull(request); this.awaiter = Objects.requireNonNull(awaiter); } @@ -42,6 +40,7 @@ public SendOperation timeout(Duration timeout) { this.timeout = Objects.requireNonNull(timeout); return this; } + // sends the package and awaits for a response public CompletableFuture await() { Predicate m = (matcher != null) @@ -51,22 +50,15 @@ public CompletableFuture await() { var clientId = request.to(); var fut = awaiter.await(clientId, String.valueOf(request.id()), m, timeout); - try { - WebSockets.sendText(Envelope.MAPPER.writeValueAsString(request), ch, null); - } catch (Exception e) { - fut.completeExceptionally(e); - } + endpoint.send(request).whenComplete((v, ex) -> { + if (ex != null) + fut.completeExceptionally(ex); + }); return fut; } + // sends the package without waiting for anything public CompletableFuture dispatch() { - try { - WebSockets.sendText(Envelope.MAPPER.writeValueAsString(request), ch, null); - return CompletableFuture.completedFuture(null); - } catch (Exception e) { - var cf = new CompletableFuture(); - cf.completeExceptionally(e); - return cf; - } + return endpoint.send(request); } } diff --git a/core/src/main/java/dev/objz/commandbridge/net/endpoints/RedisEndpoint.java b/core/src/main/java/dev/objz/commandbridge/net/endpoints/RedisEndpoint.java new file mode 100644 index 00000000..d7cc2f39 --- /dev/null +++ b/core/src/main/java/dev/objz/commandbridge/net/endpoints/RedisEndpoint.java @@ -0,0 +1,42 @@ +package dev.objz.commandbridge.net.endpoints; + +import dev.objz.commandbridge.net.Endpoint; +import dev.objz.commandbridge.net.proto.Envelope; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.BooleanSupplier; +import java.util.function.Function; + +public final class RedisEndpoint implements Endpoint { + private final String id; + private final Function> sender; + private final BooleanSupplier openSupplier; + + public RedisEndpoint(String id, + Function> sender, + BooleanSupplier openSupplier) { + this.id = Objects.requireNonNull(id); + this.sender = Objects.requireNonNull(sender); + this.openSupplier = Objects.requireNonNull(openSupplier); + } + + public String id() { + return id; + } + + @Override + public CompletableFuture send(Envelope env) { + return sender.apply(env); + } + + @Override + public boolean isOpen() { + return openSupplier.getAsBoolean(); + } + + @Override + public String describe() { + return "redis:" + id; + } +} diff --git a/core/src/main/java/dev/objz/commandbridge/net/endpoints/WsEndpoint.java b/core/src/main/java/dev/objz/commandbridge/net/endpoints/WsEndpoint.java new file mode 100644 index 00000000..8ab7d848 --- /dev/null +++ b/core/src/main/java/dev/objz/commandbridge/net/endpoints/WsEndpoint.java @@ -0,0 +1,46 @@ +package dev.objz.commandbridge.net.endpoints; + +import java.util.concurrent.CompletableFuture; + +import dev.objz.commandbridge.net.Endpoint; +import dev.objz.commandbridge.net.proto.Envelope; +import io.undertow.websockets.core.WebSocketChannel; +import io.undertow.websockets.core.WebSockets; + +public class WsEndpoint implements Endpoint { + + private final WebSocketChannel ch; + + public WsEndpoint(WebSocketChannel ch) { + this.ch = ch; + } + + public WebSocketChannel channel() { + return ch; + } + + @Override + public boolean isOpen() { + return ch != null && ch.isOpen(); + } + + @Override + public String describe() { + if (ch == null || ch.getSourceAddress() == null) + return "unknown"; + return ch.getSourceAddress().toString(); + } + + @Override + public CompletableFuture send(Envelope env) { + try { + WebSockets.sendText(Envelope.MAPPER.writeValueAsString(env), ch, null); + return CompletableFuture.completedFuture(null); + } catch (Exception e) { + var cf = new CompletableFuture(); + cf.completeExceptionally(e); + return cf; + } + } + +} diff --git a/core/src/main/java/dev/objz/commandbridge/net/redis/RedisChannels.java b/core/src/main/java/dev/objz/commandbridge/net/redis/RedisChannels.java new file mode 100644 index 00000000..583832b0 --- /dev/null +++ b/core/src/main/java/dev/objz/commandbridge/net/redis/RedisChannels.java @@ -0,0 +1,16 @@ +package dev.objz.commandbridge.net.redis; + +public final class RedisChannels { + public static final String PROXY_INBOUND = "commandbridge:proxy:in"; + private static final String CLIENT_PREFIX = "commandbridge:client:"; + + private RedisChannels() { + } + + public static String clientInbound(String clientId) { + if (clientId == null || clientId.isBlank()) { + return CLIENT_PREFIX + "unknown"; + } + return CLIENT_PREFIX + clientId.trim(); + } +} diff --git a/dist/build.gradle.kts b/dist/build.gradle.kts index a0c25917..46b7b285 100644 --- a/dist/build.gradle.kts +++ b/dist/build.gradle.kts @@ -30,7 +30,9 @@ dependencies { implementation("org.snakeyaml:snakeyaml-engine:2.10") implementation("com.google.code.gson:gson:2.13.2") implementation("net.kyori:adventure-text-minimessage:4.17.0") - implementation("org.bstats:bstats-velocity:3.1.0") + implementation("org.bstats:bstats-velocity:3.2.0") + implementation("org.bstats:bstats-bukkit:3.2.0") + implementation("redis.clients:jedis:7.1.0") } val pluginVersion: Provider = providers.gradleProperty("pluginVersion") diff --git a/velocity/build.gradle.kts b/velocity/build.gradle.kts index 31c6be54..693d4d3e 100644 --- a/velocity/build.gradle.kts +++ b/velocity/build.gradle.kts @@ -18,6 +18,7 @@ repositories { dependencies { implementation(project(":core")) + implementation("redis.clients:jedis:7.1.0") compileOnly("com.velocitypowered:velocity-api:3.4.0-SNAPSHOT") annotationProcessor("com.velocitypowered:velocity-api:3.4.0-SNAPSHOT") @@ -29,7 +30,7 @@ dependencies { compileOnly("org.spongepowered:configurate-core:4.2.0") compileOnly("dev.jorel:commandapi-velocity-core:11.1.0") - compileOnly("org.bstats:bstats-velocity:3.1.0") + compileOnly("org.bstats:bstats-velocity:3.2.0") compileOnly("net.william278:papiproxybridge:1.8.4") @@ -49,4 +50,3 @@ tasks.jar { // version replacement and packaging, so exclude it from the velocity jar. exclude("velocity-plugin.json") } - diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/Main.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/Main.java index 29fa3a42..2b74d763 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/Main.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/Main.java @@ -11,6 +11,7 @@ import com.velocitypowered.api.proxy.Player; import com.velocitypowered.api.proxy.ProxyServer; import dev.objz.commandbridge.config.ConfigManager; +import dev.objz.commandbridge.config.model.EndpointType; import dev.objz.commandbridge.config.model.VelocityConfig; import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.net.InNode; @@ -30,6 +31,8 @@ import dev.objz.commandbridge.velocity.dispatch.CommandEntry; import dev.objz.commandbridge.velocity.cmd.bridge.framework.ArgumentBridge; import dev.objz.commandbridge.velocity.cmd.bridge.packetevents.PacketEventsArgumentBridge; +import dev.objz.commandbridge.velocity.net.EndpointServer; +import dev.objz.commandbridge.velocity.net.RedisServer; import dev.objz.commandbridge.velocity.net.WsServer; import dev.objz.commandbridge.velocity.net.in.AuthHandler; import dev.objz.commandbridge.velocity.net.in.ExecuteCommandHandler; @@ -65,7 +68,7 @@ public final class Main { private final Metrics.Factory metrics; private ConfigManager configManager; - private WsServer ws; + private EndpointServer endpointServer; private RegistrationManager registrations; private InNode inNode; private OutNode outNode; @@ -142,12 +145,24 @@ public void onProxyInitialization(ProxyInitializeEvent e) { outNode = new OutNode<>(); outNode.setServerId(cfg.serverId()); - var tls = TlsResolver.resolveServer(dataDir, cfg.security()); - ws = tls.enabled() - ? new WsServer(cfg.bindHost(), cfg.bindPort(), sessions, inNode, - true, tls.context()) - : new WsServer(cfg.bindHost(), cfg.bindPort(), sessions, inNode); - ws.start(); + if (cfg.endpointType() == EndpointType.REDIS) { + var redis = cfg.endpoints().redis(); + endpointServer = new RedisServer( + redis.host(), + redis.port(), + redis.username(), + redis.password(), + sessions, + inNode); + } else { + var wsCfg = cfg.endpoints().webSocket(); + var tls = TlsResolver.resolveServer(dataDir, cfg.security()); + endpointServer = tls.enabled() + ? new WsServer(wsCfg.bindHost(), wsCfg.bindPort(), sessions, inNode, + true, tls.context()) + : new WsServer(wsCfg.bindHost(), wsCfg.bindPort(), sessions, inNode); + } + endpointServer.start(); scriptManager = new ScriptManager(dataDir, platformFeatures); scriptManager.loadAll(); @@ -176,8 +191,14 @@ public void onProxyInitialization(ProxyInitializeEvent e) { command.register(); Log.debug("Config loaded:"); - Log.debug(" Host: {}", cfg.bindHost()); - Log.debug(" Port: {}", cfg.bindPort()); + Log.debug(" Endpoint Type: {}", cfg.endpointType()); + if (cfg.endpointType() == EndpointType.WEBSOCKET) { + Log.debug(" WS Host: {}", cfg.endpoints().webSocket().bindHost()); + Log.debug(" WS Port: {}", cfg.endpoints().webSocket().bindPort()); + } else { + Log.debug(" Redis Host: {}", cfg.endpoints().redis().host()); + Log.debug(" Redis Port: {}", cfg.endpoints().redis().port()); + } Log.debug(" Server ID: {}", cfg.serverId()); checkForUpdate(); @@ -195,21 +216,21 @@ public void onProxyShutdown(ProxyShutdownEvent e) { } } - if (ws != null) { - ws.stop(); + if (endpointServer != null) { + endpointServer.stop(); } } private void installRoutes() { var secret = new SecretLoader(dataDir).loadOrCreate(); var auth = new AuthService(secret); - authHandler = new AuthHandler(auth, sessions, ws); + authHandler = new AuthHandler(auth, sessions, endpointServer); authHandler.register(inNode); inNode.register(MessageType.INVOKED_COMMAND, new InvokedCommandHandler(sessions, commandEntry)); inNode.register(MessageType.EXECUTE_COMMAND_RESULT, new ExecuteCommandHandler(proxy)); - outNode.setChannelSendOperationFactory((ch, env) -> ws.send(ch, env)); + outNode.setEndpointSendFactory((endpoint, env) -> endpointServer.send(endpoint, env)); outNode.register(MessageType.REGISTER_COMMANDS, new RegistrationRequest()); outNode.register(MessageType.PING, new PingRequest()); outNode.register(MessageType.EXECUTE_COMMAND, new ExecuteCommandRequest()); diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ListCommand.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ListCommand.java index 1ae85a8d..a14de89d 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ListCommand.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ListCommand.java @@ -52,9 +52,7 @@ private void renderChat(CommandSource sender, RenderContext ctx) { int index = 0; for (var s : authenticated) { String id = s.id() != null ? s.id() : "unknown"; - String address = s.ch() != null && s.ch().getSourceAddress() != null - ? s.ch().getSourceAddress().toString() - : "unknown"; + String address = s.endpoint() != null ? s.endpoint().describe() : "unknown"; String platform = s.location() != null ? s.location().name() : "unknown"; Component header = MM.parse("<" + Theme.C_ACCENT + ">• ") @@ -101,9 +99,7 @@ private void renderConsole() { for (var s : authenticated) { String id = s.id() != null ? s.id() : "unknown"; - String address = s.ch() != null && s.ch().getSourceAddress() != null - ? s.ch().getSourceAddress().toString() - : "unknown"; + String address = s.endpoint() != null ? s.endpoint().describe() : "unknown"; String platform = s.location() != null ? s.location().name() : "unknown"; table.addRow(id, address, platform); } @@ -115,7 +111,7 @@ private void renderConsole() { private List getAuthenticatedClients() { List authenticated = new ArrayList<>(); for (ClientSession s : sessions) { - if (s.status() == AuthStatus.AUTH_OK && s.ch() != null && s.ch().isOpen()) { + if (s.status() == AuthStatus.AUTH_OK && s.endpoint() != null && s.endpoint().isOpen()) { authenticated.add(s); } } diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/PingCommand.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/PingCommand.java index cb2ffd12..41828fe8 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/PingCommand.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/PingCommand.java @@ -20,8 +20,6 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.ArrayList; -import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -64,9 +62,7 @@ private void pingAll(CommandSource sender) { for (ClientSession session : activeClients) { String id = session.id(); - String address = session.ch() != null && session.ch().getSourceAddress() != null - ? session.ch().getSourceAddress().toString() - : "unknown"; + String address = session.endpoint() != null ? session.endpoint().describe() : "unknown"; try { outNode.send( @@ -102,9 +98,7 @@ private void pingSingle(CommandSource sender, String clientId) { return; } - String address = session.ch() != null && session.ch().getSourceAddress() != null - ? session.ch().getSourceAddress().toString() - : "unknown"; + String address = session.endpoint() != null ? session.endpoint().describe() : "unknown"; try { outNode.send( @@ -264,7 +258,8 @@ private void renderClientNotFoundConsole(String clientId) { private List getActiveClients() { List activeClients = new ArrayList<>(); for (ClientSession session : sessions) { - if (session.status() == AuthStatus.AUTH_OK && session.ch() != null && session.ch().isOpen()) { + if (session.status() == AuthStatus.AUTH_OK && session.endpoint() != null + && session.endpoint().isOpen()) { activeClients.add(session); } } diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ReloadCommand.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ReloadCommand.java index fe042609..3a25a179 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ReloadCommand.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/cli/subcommands/ReloadCommand.java @@ -93,9 +93,7 @@ public void execute(CommandSource sender) { for (ClientSession session : activeClients) { String clientId = session.id(); - String address = session.ch() != null && session.ch().getSourceAddress() != null - ? session.ch().getSourceAddress().toString() - : "unknown"; + String address = session.endpoint() != null ? session.endpoint().describe() : "unknown"; Set scripts = registrationManager .getScriptsForSession(session); @@ -141,12 +139,9 @@ public void execute(CommandSource sender) { String clientId = session.id(); if (sentTo.containsKey(clientId) && !results.containsKey(clientId)) { - String address = session.ch() != null && session - .ch() - .getSourceAddress() != null - ? session.ch().getSourceAddress() - .toString() - : "unknown"; + String address = session.endpoint() != null + ? session.endpoint().describe() + : "unknown"; results.put(clientId, new ReloadResult( ReloadStatus.TIMEOUT, @@ -264,8 +259,8 @@ private void renderChatError(RenderContext ctx, String error, String details) { private List getActiveClients() { List activeClients = new ArrayList<>(); for (ClientSession session : sessionHub) { - if (session.status() == AuthStatus.AUTH_OK && session.ch() != null - && session.ch().isOpen()) { + if (session.status() == AuthStatus.AUTH_OK && session.endpoint() != null + && session.endpoint().isOpen()) { activeClients.add(session); } } diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/EndpointServer.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/EndpointServer.java new file mode 100644 index 00000000..6169efc9 --- /dev/null +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/EndpointServer.java @@ -0,0 +1,15 @@ +package dev.objz.commandbridge.velocity.net; + +import dev.objz.commandbridge.net.Endpoint; +import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.proto.Envelope; + +public interface EndpointServer { + void start(); + + void stop(); + + SendOperation send(Endpoint endpoint, Envelope request); + + void close(Endpoint endpoint); +} diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/RedisServer.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/RedisServer.java new file mode 100644 index 00000000..fbff2868 --- /dev/null +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/RedisServer.java @@ -0,0 +1,223 @@ +package dev.objz.commandbridge.velocity.net; + +import com.fasterxml.jackson.databind.JsonNode; +import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; +import dev.objz.commandbridge.net.InNode; +import dev.objz.commandbridge.net.ResponseAwaiter; +import dev.objz.commandbridge.net.SendOperation; +import dev.objz.commandbridge.net.endpoints.RedisEndpoint; +import dev.objz.commandbridge.net.proto.Envelope; +import dev.objz.commandbridge.net.redis.RedisChannels; +import dev.objz.commandbridge.velocity.net.session.SessionHub; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisPool; +import redis.clients.jedis.JedisPubSub; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +public final class RedisServer implements EndpointServer { + private final String host; + private final int port; + private final String username; + private final String password; + private final SessionHub sessions; + private final InNode inNode; + private final ResponseAwaiter responses = new ResponseAwaiter(); + private final Map endpointsByClient = new ConcurrentHashMap<>(); + + private volatile boolean running; + private volatile JedisPool pool; + private volatile JedisPubSub subscriber; + private volatile Thread subscriberThread; + + public RedisServer( + String host, + int port, + String username, + String password, + SessionHub sessions, + InNode inNode) { + this.host = host; + this.port = port; + this.username = username; + this.password = password; + this.sessions = sessions; + this.inNode = inNode; + + this.inNode.setInboundTap(this::signalInbound); + this.inNode.setSendOperationFactory(this::createSendOperation); + } + + @Override + public synchronized void start() { + if (running) { + return; + } + + try { + pool = createPool(); + running = true; + startSubscriber(); + Log.success(true, "Redis endpoint listening via '{}:{}'", host, port); + } catch (Exception e) { + running = false; + closePool(); + throw (e instanceof RuntimeException re) ? re : new RuntimeException(e); + } + } + + @Override + public synchronized void stop() { + running = false; + + JedisPubSub activeSub = subscriber; + if (activeSub != null) { + try { + activeSub.unsubscribe(); + } catch (Exception ignore) { + } + } + + Thread t = subscriberThread; + if (t != null) { + t.interrupt(); + try { + t.join(1000L); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + closePool(); + endpointsByClient.clear(); + sessions.clear(); + Log.info("Redis endpoint server stopped"); + } + + @Override + public SendOperation send(Endpoint endpoint, Envelope request) { + return new SendOperation(endpoint, request, responses); + } + + @Override + public void close(Endpoint endpoint) { + if (endpoint == null) { + return; + } + sessions.remove(endpoint); + if (endpoint instanceof RedisEndpoint redisEndpoint) { + endpointsByClient.remove(redisEndpoint.id(), redisEndpoint); + } + } + + private SendOperation createSendOperation(Endpoint endpoint, Envelope env) { + return new SendOperation(endpoint, env, responses); + } + + private void startSubscriber() { + subscriberThread = new Thread(this::subscribeLoop, "commandbridge-redis-proxy-sub"); + subscriberThread.setDaemon(true); + subscriberThread.start(); + } + + private void subscribeLoop() { + while (running) { + try (Jedis jedis = pool.getResource()) { + JedisPubSub localSubscriber = new JedisPubSub() { + @Override + public void onMessage(String channel, String message) { + handleInbound(message); + } + }; + subscriber = localSubscriber; + jedis.subscribe(localSubscriber, RedisChannels.PROXY_INBOUND); + } catch (Exception ex) { + if (!running) { + return; + } + Log.warn("Redis subscribe loop failed: {}", ex.getMessage()); + try { + Thread.sleep(1000L); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + } finally { + subscriber = null; + } + } + } + + private void handleInbound(String message) { + String clientId = parseFrom(message); + if (clientId == null || clientId.isBlank()) { + Log.warn("Dropping Redis message without a valid 'from' client id"); + return; + } + + RedisEndpoint endpoint = endpointsByClient.computeIfAbsent(clientId, + id -> new RedisEndpoint(id, env -> publishToClient(id, env), () -> running)); + inNode.onText(endpoint, message); + } + + private String parseFrom(String text) { + try { + JsonNode node = Envelope.MAPPER.readTree(text); + JsonNode from = node.get("from"); + return from != null ? from.asText() : null; + } catch (Exception e) { + Log.warn("Bad JSON from Redis: {}", e.getMessage()); + return null; + } + } + + private CompletableFuture publishToClient(String clientId, Envelope env) { + try { + String payload = Envelope.MAPPER.writeValueAsString(env); + try (Jedis jedis = pool.getResource()) { + jedis.publish(RedisChannels.clientInbound(clientId), payload); + } + return CompletableFuture.completedFuture(null); + } catch (Exception e) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(e); + return failed; + } + } + + private boolean signalInbound(Envelope env) { + try { + return responses.signal(env); + } catch (Exception ex) { + Log.debug("Awaiter signal failed: {}", ex.toString()); + return false; + } + } + + private JedisPool createPool() { + DefaultJedisClientConfig.Builder builder = DefaultJedisClientConfig.builder(); + if (username != null && !username.isBlank()) { + builder.user(username.trim()); + } + if (password != null && !password.isBlank()) { + builder.password(password); + } + return new JedisPool(new HostAndPort(host, port), builder.build()); + } + + private void closePool() { + JedisPool p = pool; + pool = null; + if (p != null) { + try { + p.close(); + } catch (Exception ignore) { + } + } + } +} diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/WsServer.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/WsServer.java index 507e6991..5a541c06 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/WsServer.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/WsServer.java @@ -1,10 +1,12 @@ package dev.objz.commandbridge.velocity.net; import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.ResponseAwaiter; import dev.objz.commandbridge.net.SendOperation; -import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.InNode; +import dev.objz.commandbridge.net.endpoints.WsEndpoint; +import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.velocity.net.session.ClientSession; import dev.objz.commandbridge.velocity.net.session.SessionHub; import io.undertow.Handlers; @@ -27,7 +29,7 @@ import java.io.IOException; import java.net.BindException; -public final class WsServer { +public final class WsServer implements EndpointServer { private final String host; private final int port; @@ -56,17 +58,19 @@ public WsServer(String host, int port, SessionHub sessions, InNode inNode, this.inNode.setSendOperationFactory(this::createSendOperation); } - private SendOperation createSendOperation(WebSocketChannel ch, Envelope env) { - return new SendOperation(ch, env, responses); + private SendOperation createSendOperation(Endpoint endpoint, Envelope env) { + return new SendOperation(endpoint, env, responses); } + @Override public void start() { WebSocketConnectionCallback cb = (WebSocketHttpExchange ex, WebSocketChannel ch) -> { + final WsEndpoint endpoint = new WsEndpoint(ch); ch.getReceiveSetter().set(new AbstractReceiveListener() { @Override protected void onFullTextMessage(WebSocketChannel c, BufferedTextMessage msg) { - inNode.onText(c, msg.getData()); + inNode.onText(endpoint, msg.getData()); } @Override @@ -90,9 +94,11 @@ protected void onClose(WebSocketChannel channel, } try { - Log.warn("WebSocket closed: {}", channel.getSourceAddress()); + Log.warn("Endpoint closed: {}", endpoint.describe()); IoUtils.safeClose(channel); } catch (Throwable ignore) { + } finally { + sessions.remove(endpoint); } } }); @@ -136,25 +142,40 @@ protected void onClose(WebSocketChannel channel, } } + @Override public void stop() { try { - if (server != null) + if (server != null) { for (ClientSession s : sessions) { - close(s.ch()); + close(s.endpoint()); } - server.stop(); + server.stop(); + } } catch (Exception ignore) { } sessions.clear(); Log.info("WebSocket server stopped"); } - public SendOperation send(WebSocketChannel ch, Envelope request) { - return new SendOperation(ch, request, responses); + @Override + public SendOperation send(Endpoint endpoint, Envelope request) { + return new SendOperation(endpoint, request, responses); + } + + @Override + public void close(Endpoint endpoint) { + if (endpoint == null) + return; + + sessions.remove(endpoint); + if (!(endpoint instanceof WsEndpoint wsEndpoint)) + return; + + closeChannel(wsEndpoint.channel(), endpoint); } // always safe close - public void close(WebSocketChannel ch) { + private void closeChannel(WebSocketChannel ch, Endpoint endpoint) { if (ch == null) return; @@ -162,7 +183,7 @@ public void close(WebSocketChannel ch) { try { org.xnio.IoUtils.safeClose(ch); } finally { - sessions.remove(ch); + sessions.remove(endpoint); } return; } @@ -173,7 +194,7 @@ public void close(WebSocketChannel ch) { } ch.addCloseTask(c -> { - sessions.remove(c); + sessions.remove(endpoint); org.xnio.IoUtils.safeClose(c); }); @@ -187,7 +208,7 @@ public void complete(WebSocketChannel channel, Void context) { try { org.xnio.IoUtils.safeClose(channel); } finally { - sessions.remove(channel); + sessions.remove(endpoint); } } @@ -196,7 +217,7 @@ public void onError(WebSocketChannel channel, Void context, Throwable cause) { try { org.xnio.IoUtils.safeClose(channel); } finally { - sessions.remove(channel); + sessions.remove(endpoint); } } }); diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/AuthHandler.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/AuthHandler.java index 73100277..d074cd27 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/AuthHandler.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/AuthHandler.java @@ -3,6 +3,7 @@ import dev.objz.commandbridge.logging.Log; import dev.objz.commandbridge.security.AuthService; import dev.objz.commandbridge.security.AuthStatus; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.InboundHandler; import dev.objz.commandbridge.net.InNode; import dev.objz.commandbridge.net.payloads.util.AuthRequestPayload; @@ -10,10 +11,9 @@ import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.net.proto.MessageType; import dev.objz.commandbridge.scripting.model.enums.Location; -import dev.objz.commandbridge.velocity.net.WsServer; +import dev.objz.commandbridge.velocity.net.EndpointServer; import dev.objz.commandbridge.velocity.net.session.ClientSession; import dev.objz.commandbridge.velocity.net.session.SessionHub; -import io.undertow.websockets.core.WebSocketChannel; import java.util.Objects; import java.util.UUID; @@ -23,14 +23,14 @@ public final class AuthHandler extends InboundHandler { private final AuthService auth; private final SessionHub sessions; - private final WsServer ws; + private final EndpointServer endpointServer; private volatile Consumer onAuthed; - public AuthHandler(AuthService auth, SessionHub sessions, WsServer ws) { + public AuthHandler(AuthService auth, SessionHub sessions, EndpointServer endpointServer) { this.auth = Objects.requireNonNull(auth); this.sessions = Objects.requireNonNull(sessions); - this.ws = Objects.requireNonNull(ws); + this.endpointServer = Objects.requireNonNull(endpointServer); } public void register(InNode router) { @@ -42,9 +42,9 @@ public void onAuthenticated(Consumer listener) { } @Override - public void accept(WebSocketChannel ch, Envelope env) { + public void accept(Endpoint endpoint, Envelope env) { if (env.type() != MessageType.AUTH_REQUEST || env.payload() == null) { - ws.close(ch); + endpointServer.close(endpoint); return; } @@ -53,56 +53,62 @@ public void accept(WebSocketChannel ch, Envelope env) { ap = Envelope.MAPPER.treeToValue(env.payload(), AuthRequestPayload.class); } catch (Exception e) { Log.error(e, "Failed to handle AUTH_REQUEST from {}", env.from()); - reply(ch, env, MessageType.AUTH_FAIL, null) + reply(endpoint, env, MessageType.AUTH_FAIL, null) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send auth response: {}", ex.toString()); return null; }); - ws.close(ch); + endpointServer.close(endpoint); return; } if (ap == null || env.from() == null || ap.clientNonce() == null || ap.hmac() == null) { - reply(ch, env, MessageType.AUTH_FAIL, null) + reply(endpoint, env, MessageType.AUTH_FAIL, null) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send auth response: {}", ex.toString()); return null; }); - ws.close(ch); - Log.error("Authentication failed (malformed payload) from '{}'", ch.getSourceAddress()); + endpointServer.close(endpoint); + Log.error("Authentication failed (malformed payload) from '{}'", endpoint.describe()); return; } if (!auth.verify(env.from(), ap.clientNonce(), ap.hmac())) { // HMAC(clientId:clientNonce) - reply(ch, env, MessageType.AUTH_FAIL, null) + reply(endpoint, env, MessageType.AUTH_FAIL, null) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send auth response: {}", ex.toString()); return null; }); - ws.close(ch); - Log.error("Authentication failed for '{}' from '{}'", env.from(), ch.getSourceAddress()); + endpointServer.close(endpoint); + Log.error("Authentication failed for '{}' from '{}'", env.from(), endpoint.describe()); return; } final String serverNonce = UUID.randomUUID().toString().replace("-", ""); final String serverMac = auth.signServerProof(env.from(), ap.clientNonce(), serverNonce); - ClientSession s = sessions.add(ch, env.from()); + sessions.get(env.from()).ifPresent(existing -> { + if (existing.endpoint() != endpoint) { + endpointServer.close(existing.endpoint()); + } + }); + + ClientSession s = sessions.add(env.from(), endpoint); s.status(AuthStatus.AUTH_FAIL); s.location(ap.location() != null ? ap.location() : Location.BACKEND); - reply(ch, env, MessageType.AUTH_OK, new AuthResponsePayload(serverNonce, serverMac)) + reply(endpoint, env, MessageType.AUTH_OK, new AuthResponsePayload(serverNonce, serverMac)) .dispatch() .exceptionally(ex -> { Log.warn("Failed to send auth response: {}", ex.toString()); return null; }); s.status(AuthStatus.AUTH_OK); - Log.success(true, "Authentication succeeded for '{}' from '{}'", env.from(), ch.getSourceAddress()); + Log.success(true, "Authentication succeeded for '{}' from '{}'", env.from(), endpoint.describe()); var cb = onAuthed; if (cb != null) { diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/ExecuteCommandHandler.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/ExecuteCommandHandler.java index 7074972a..0842ad35 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/ExecuteCommandHandler.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/ExecuteCommandHandler.java @@ -3,13 +3,13 @@ import com.velocitypowered.api.proxy.Player; import com.velocitypowered.api.proxy.ProxyServer; import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.InboundHandler; import dev.objz.commandbridge.net.payloads.cmd.ExecuteCommandResult; import dev.objz.commandbridge.net.payloads.feedback.Feedback; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.logging.Summary; import dev.objz.commandbridge.util.MM; -import io.undertow.websockets.core.WebSocketChannel; import java.util.List; import java.util.Objects; @@ -25,7 +25,7 @@ public ExecuteCommandHandler(ProxyServer proxy) { } @Override - public void accept(WebSocketChannel ch, Envelope env) { + public void accept(Endpoint endpoint, Envelope env) { if (env.payload() == null) { Log.warn("Received EXECUTE_COMMAND_RESULT with null payload from '{}'", env.from()); return; diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/InvokedCommandHandler.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/InvokedCommandHandler.java index 1d9d8868..83b86eb5 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/InvokedCommandHandler.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/in/InvokedCommandHandler.java @@ -1,13 +1,13 @@ package dev.objz.commandbridge.velocity.net.in; import dev.objz.commandbridge.logging.Log; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.net.InboundHandler; import dev.objz.commandbridge.net.payloads.cmd.InvokedCommand; import dev.objz.commandbridge.net.proto.Envelope; import dev.objz.commandbridge.velocity.dispatch.CommandEntry; import dev.objz.commandbridge.velocity.net.session.ClientSession; import dev.objz.commandbridge.velocity.net.session.SessionHub; -import io.undertow.websockets.core.WebSocketChannel; import java.util.Objects; @@ -22,7 +22,7 @@ public InvokedCommandHandler(SessionHub sessions, CommandEntry executor) { } @Override - public void accept(WebSocketChannel ch, Envelope env) { + public void accept(Endpoint endpoint, Envelope env) { if (env.payload() == null) { Log.warn("Received INVOKED_COMMAND with null payload from '{}'", env.from()); return; @@ -41,9 +41,11 @@ public void accept(WebSocketChannel ch, Envelope env) { return; } - ClientSession session = sessions.get(ch); + ClientSession session = sessions.get(endpoint); if (session == null) { - Log.warn("No session found for channel when handling INVOKED_COMMAND"); + String source = endpoint != null ? endpoint.describe() : "unknown"; + Log.warn("No session found for endpoint '{}' while handling INVOKED_COMMAND from '{}'", source, + env.from()); return; } diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/ExecuteCommandRequest.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/ExecuteCommandRequest.java index e487f9dd..59396790 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/ExecuteCommandRequest.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/ExecuteCommandRequest.java @@ -31,7 +31,7 @@ public SendOperation accept(ExecuteCommandContext ctx) { ctx.session().id(), Envelope.MAPPER.valueToTree(payload)); - SendOperation op = send(ctx.session().ch(), env); + SendOperation op = send(ctx.session().endpoint(), env); op.dispatch() .exceptionally(ex -> { Log.error(ex, "Failed to send EXECUTE_COMMAND to '{}'", ctx.session().id()); diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/PingRequest.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/PingRequest.java index 1c045be9..8741e3e6 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/PingRequest.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/PingRequest.java @@ -21,7 +21,7 @@ public SendOperation accept(PingRequestContext ctx) { ObjectNode payloadNode = Envelope.MAPPER.valueToTree(payload); final Envelope env = Envelope.make(MessageType.PING, serverId, clientId, payloadNode); - SendOperation op = send(ctx.session.ch(), env) + SendOperation op = send(ctx.session.endpoint(), env) .expect(MessageType.PONG) .timeout(ctx.timeout); diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/RegistrationRequest.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/RegistrationRequest.java index 1f407e86..630b6a02 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/RegistrationRequest.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/out/RegistrationRequest.java @@ -49,7 +49,7 @@ public SendOperation accept(RegistrationRequestContext ctx) { ObjectNode payloadNode = Envelope.MAPPER.valueToTree(payload); final Envelope env = Envelope.make(MessageType.REGISTER_COMMANDS, serverId, clientId, payloadNode); - SendOperation op = send(ctx.session.ch(), env) + SendOperation op = send(ctx.session.endpoint(), env) .expect(MessageType.REGISTER_COMMANDS_RESULT) .timeout(ctx.timeout); diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/ClientSession.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/ClientSession.java index 95159e27..fc8030db 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/ClientSession.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/ClientSession.java @@ -1,24 +1,20 @@ package dev.objz.commandbridge.velocity.net.session; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.scripting.model.enums.Location; import dev.objz.commandbridge.security.AuthStatus; -import io.undertow.websockets.core.WebSocketChannel; public final class ClientSession { - private final WebSocketChannel ch; - private volatile String id = "unknown"; - private volatile AuthStatus status = AuthStatus.AUTH_OK; + private volatile Endpoint endpoint; + private volatile String id; + private volatile AuthStatus status = AuthStatus.AUTH_FAIL; private volatile Location location = Location.BACKEND; - public ClientSession(WebSocketChannel ch, String clientId) { - this.ch = ch; + public ClientSession(Endpoint endpoint, String clientId) { + this.endpoint = endpoint; this.id = clientId; } - public WebSocketChannel ch() { - return ch; - } - public String id() { return id; } @@ -31,6 +27,14 @@ public void status(AuthStatus status) { this.status = status; } + public Endpoint endpoint() { + return endpoint; + } + + public void endpoint(Endpoint endpoint) { + this.endpoint = endpoint; + } + public Location location() { return location; } diff --git a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/SessionHub.java b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/SessionHub.java index 24b20020..76357640 100644 --- a/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/SessionHub.java +++ b/velocity/src/main/java/dev/objz/commandbridge/velocity/net/session/SessionHub.java @@ -1,78 +1,103 @@ package dev.objz.commandbridge.velocity.net.session; +import dev.objz.commandbridge.net.Endpoint; import dev.objz.commandbridge.scripting.model.enums.Location; import dev.objz.commandbridge.security.AuthStatus; -import io.undertow.websockets.core.WebSocketChannel; import java.util.Iterator; import java.util.Objects; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiConsumer; public final class SessionHub implements Iterable { - private final ConcurrentHashMap clients = new ConcurrentHashMap<>(); + private final ConcurrentHashMap clientsById = new ConcurrentHashMap<>(); + private final ConcurrentHashMap clientsByEndpoint = new ConcurrentHashMap<>(); + + public ClientSession add(String clientId, Endpoint endpoint) { + Objects.requireNonNull(endpoint); - public ClientSession add(WebSocketChannel ch, String clientId) { - Objects.requireNonNull(ch); if (clientId == null || clientId.isBlank()) - clientId = "unknown"; - var s = new ClientSession(ch, clientId); - s.status(AuthStatus.AUTH_OK); - clients.put(ch, s); - ch.getCloseSetter().set(c -> remove((WebSocketChannel) c)); + clientId = "unknown-" + UUID.randomUUID(); + + var s = new ClientSession(endpoint, clientId); + var previous = clientsById.put(clientId, s); + if (previous != null && previous.endpoint() != null) { + clientsByEndpoint.remove(previous.endpoint()); + } + clientsByEndpoint.put(endpoint, s); return s; } - public ClientSession get(WebSocketChannel ch) { - return clients.get(ch); + public Optional get(String clientId) { + if (clientId == null || clientId.isBlank()) + return Optional.empty(); + return Optional.ofNullable(clientsById.get(clientId)); } - public boolean contains(WebSocketChannel ch) { - return clients.containsKey(ch); + public ClientSession get(Endpoint endpoint) { + if (endpoint == null) + return null; + return clientsByEndpoint.get(endpoint); } - public ClientSession remove(WebSocketChannel ch) { - if (ch == null) + public ClientSession remove(String clientId) { + if (clientId == null) return null; - return clients.remove(ch); + var removed = clientsById.remove(clientId); + if (removed != null && removed.endpoint() != null) { + clientsByEndpoint.remove(removed.endpoint()); + } + return removed; + } + + public ClientSession remove(Endpoint endpoint) { + if (endpoint == null) + return null; + + var removed = clientsByEndpoint.remove(endpoint); + if (removed != null && removed.id() != null) { + clientsById.remove(removed.id(), removed); + } + return removed; } public void clear() { - clients.clear(); + clientsById.clear(); + clientsByEndpoint.clear(); } public int size() { - return clients.size(); + return clientsById.size(); } @Override public Iterator iterator() { - return clients.values().iterator(); - } - - public ClientSession set(WebSocketChannel ch, BiConsumer setter, T value) { - Objects.requireNonNull(setter); - return clients.computeIfPresent(ch, (k, s) -> { - setter.accept(s, value); - return s; - }); + return clientsById.values().iterator(); } public Optional findSession(String id, Location location) { if (id == null || location == null) return Optional.empty(); - for (ClientSession session : clients.values()) { - if (id.equals(session.id()) - && session.location() == location - && session.status() == AuthStatus.AUTH_OK - && session.ch() != null - && session.ch().isOpen()) { - return Optional.of(session); - } - } - return Optional.empty(); + var session = clientsById.get(id); + if (session == null) + return Optional.empty(); + + if (!id.equals(session.id())) + return Optional.empty(); + + if (session.location() != location) + return Optional.empty(); + + if (session.status() != AuthStatus.AUTH_OK) + return Optional.empty(); + + var ep = session.endpoint(); + if (ep == null || !ep.isOpen()) + return Optional.empty(); + + return Optional.of(session); } }