Add WebSocket support with routing, origin validation, session management, and broadcasting

This commit is contained in:
CodingPhoenixx
2026-05-28 13:23:23 +02:00
parent b75e1e5c6e
commit 994e7fa80c
11 changed files with 799 additions and 4 deletions
@@ -0,0 +1,138 @@
package dev.coph.nextusweb.server.websocket;
import java.time.Duration;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
public final class WebSocketConfig {
private final int maxFramePayloadLength;
private final int maxAggregatedMessageSize;
private final Duration idleTimeout;
private final Set<String> allowedOrigins;
private final boolean allowAnyOrigin;
private final Set<String> subprotocols;
private final boolean compression;
private final boolean checkStartsWith;
private WebSocketConfig(Builder b) {
this.maxFramePayloadLength = b.maxFramePayloadLength;
this.maxAggregatedMessageSize = b.maxAggregatedMessageSize;
this.idleTimeout = b.idleTimeout;
this.allowedOrigins = Set.copyOf(b.allowedOrigins);
this.allowAnyOrigin = b.allowAnyOrigin;
this.subprotocols = Set.copyOf(b.subprotocols);
this.compression = b.compression;
this.checkStartsWith = b.checkStartsWith;
}
public static WebSocketConfig defaults() {
return builder().build();
}
public static Builder builder() {
return new Builder();
}
public boolean isOriginAllowed(String origin) {
if (allowAnyOrigin) return true;
if (origin == null) return false;
return allowedOrigins.contains(origin);
}
public int maxFramePayloadLength() {
return maxFramePayloadLength;
}
public int maxAggregatedMessageSize() {
return maxAggregatedMessageSize;
}
public Duration idleTimeout() {
return idleTimeout;
}
public boolean allowAnyOrigin() {
return allowAnyOrigin;
}
public Set<String> allowedOrigins() {
return allowedOrigins;
}
public String subprotocolsCsv() {
if (subprotocols.isEmpty()) return null;
return String.join(",", subprotocols);
}
public boolean compression() {
return compression;
}
public boolean checkStartsWith() {
return checkStartsWith;
}
public static final class Builder {
private int maxFramePayloadLength = 65_536;
private int maxAggregatedMessageSize = 1_048_576;
private Duration idleTimeout = Duration.ofSeconds(60);
private final Set<String> allowedOrigins = new LinkedHashSet<>();
private boolean allowAnyOrigin = false;
private final Set<String> subprotocols = new LinkedHashSet<>();
private boolean compression = true;
private boolean checkStartsWith = false;
public Builder maxFramePayloadLength(int bytes) {
if (bytes <= 0) throw new IllegalArgumentException("maxFramePayloadLength must be > 0");
this.maxFramePayloadLength = bytes;
return this;
}
public Builder maxAggregatedMessageSize(int bytes) {
if (bytes <= 0) throw new IllegalArgumentException("maxAggregatedMessageSize must be > 0");
this.maxAggregatedMessageSize = bytes;
return this;
}
public Builder idleTimeout(Duration timeout) {
this.idleTimeout = timeout;
return this;
}
public Builder noIdleTimeout() {
this.idleTimeout = null;
return this;
}
public Builder allowedOrigins(String... origins) {
Collections.addAll(this.allowedOrigins, origins);
return this;
}
public Builder anyOrigin() {
this.allowAnyOrigin = true;
return this;
}
public Builder subprotocols(String... protocols) {
Collections.addAll(this.subprotocols, protocols);
return this;
}
public Builder compression(boolean enabled) {
this.compression = enabled;
return this;
}
public Builder checkStartsWith(boolean v) {
this.checkStartsWith = v;
return this;
}
public WebSocketConfig build() {
return new WebSocketConfig(this);
}
}
}
@@ -0,0 +1,117 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.CharsetUtil;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
private static final Executor VT_EXECUTOR =
Executors.newVirtualThreadPerTaskExecutor();
private final WebSocketHandler handler;
private final String path;
private final Map<String, String> pathParams;
WebSocketFrameHandler(WebSocketHandler handler, String path, Map<String, String> pathParams) {
this.handler = handler;
this.path = path;
this.pathParams = pathParams;
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
WebSocketSession session = new WebSocketSession(ctx.channel(), path, pathParams);
ctx.channel().attr(WebSocketSession.SESSION_KEY).set(session);
VT_EXECUTOR.execute(() -> {
try {
handler.onOpen(session);
} catch (Throwable t) {
safeError(session, t);
}
});
return;
}
if (evt instanceof IdleStateEvent) {
ctx.close();
return;
}
super.userEventTriggered(ctx, evt);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
if (session == null) return;
if (frame instanceof TextWebSocketFrame text) {
String content = text.text();
VT_EXECUTOR.execute(() -> {
try {
handler.onMessage(session, content);
} catch (Throwable t) {
safeError(session, t);
}
});
} else if (frame instanceof BinaryWebSocketFrame bin) {
int readable = bin.content().readableBytes();
byte[] data = new byte[readable];
bin.content().getBytes(bin.content().readerIndex(), data);
VT_EXECUTOR.execute(() -> {
try {
handler.onBinary(session, data);
} catch (Throwable t) {
safeError(session, t);
}
});
} else if (frame instanceof CloseWebSocketFrame close) {
int code = close.statusCode();
String reason = close.reasonText() == null ? "" : close.reasonText();
VT_EXECUTOR.execute(() -> {
try {
handler.onClose(session, code, reason);
} catch (Throwable t) {
safeError(session, t);
}
});
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).getAndSet(null);
if (session == null) return;
VT_EXECUTOR.execute(() -> {
try {
handler.onClose(session, 1006, "Connection closed");
} catch (Throwable t) {
safeError(session, t);
}
});
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
if (session != null) safeError(session, cause);
ctx.close();
}
private void safeError(WebSocketSession session, Throwable cause) {
try {
handler.onError(session, cause);
} catch (Throwable ignored) {
}
}
}
@@ -0,0 +1,16 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.ChannelHandler;
import java.util.Map;
public final class WebSocketFrameHandlerFactory {
private WebSocketFrameHandlerFactory() {
}
public static ChannelHandler create(WebSocketHandler handler, String path,
Map<String, String> pathParams) {
return new WebSocketFrameHandler(handler, path, pathParams);
}
}
@@ -0,0 +1,83 @@
package dev.coph.nextusweb.server.websocket;
import dev.coph.nextusweb.server.json.JsonMapper;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import tools.jackson.core.JacksonException;
public final class WebSocketGroup {
private final ChannelGroup channels;
private final String name;
public WebSocketGroup() {
this("anonymous");
}
public WebSocketGroup(String name) {
this.name = name;
this.channels = new DefaultChannelGroup(name, GlobalEventExecutor.INSTANCE);
}
public String name() {
return name;
}
public WebSocketGroup add(WebSocketSession session) {
channels.add(session.channel());
return this;
}
public WebSocketGroup remove(WebSocketSession session) {
channels.remove(session.channel());
return this;
}
public int size() {
return channels.size();
}
public WebSocketGroup broadcast(String text) {
channels.writeAndFlush(new TextWebSocketFrame(text));
return this;
}
public WebSocketGroup broadcastJson(Object value) {
try {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
String text = new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
channels.writeAndFlush(new TextWebSocketFrame(text));
} catch (JacksonException e) {
throw new RuntimeException("JSON serialization failed", e);
}
return this;
}
public WebSocketGroup broadcastBinary(byte[] data) {
for (var ch : channels) {
if (ch.isActive()) {
var buf = ch.alloc().buffer(data.length).writeBytes(data);
ch.writeAndFlush(new BinaryWebSocketFrame(buf));
}
}
return this;
}
public WebSocketGroup broadcastExcept(WebSocketSession exclude, String text) {
var excludeCh = exclude == null ? null : exclude.channel();
for (var ch : channels) {
if (ch.isActive() && ch != excludeCh) {
ch.writeAndFlush(new TextWebSocketFrame(text));
}
}
return this;
}
public WebSocketGroup closeAll() {
channels.close();
return this;
}
}
@@ -0,0 +1,19 @@
package dev.coph.nextusweb.server.websocket;
public interface WebSocketHandler {
default void onOpen(WebSocketSession session) throws Exception {
}
default void onMessage(WebSocketSession session, String message) throws Exception {
}
default void onBinary(WebSocketSession session, byte[] data) throws Exception {
}
default void onClose(WebSocketSession session, int code, String reason) throws Exception {
}
default void onError(WebSocketSession session, Throwable cause) throws Exception {
}
}
@@ -0,0 +1,70 @@
package dev.coph.nextusweb.server.websocket;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public final class WebSocketRouter {
private final Node root = new Node();
public WebSocketRouter on(String path, WebSocketHandler handler) {
Node node = root;
for (String segment : split(path)) {
if (segment.startsWith("{") && segment.endsWith("}")) {
if (node.paramChild == null) {
node.paramChild = new Node();
node.paramName = segment.substring(1, segment.length() - 1);
}
node = node.paramChild;
} else {
node = node.children.computeIfAbsent(segment, k -> new Node());
}
}
node.handler = handler;
return this;
}
public Resolution resolve(String path) {
Map<String, String> params = new HashMap<>(4);
Node node = root;
for (String segment : split(path)) {
Node next = node.children.get(segment);
if (next != null) {
node = next;
} else if (node.paramChild != null) {
params.put(node.paramName, segment);
node = node.paramChild;
} else {
return null;
}
}
if (node.handler == null) return null;
return new Resolution(node.handler, params);
}
private static List<String> split(String path) {
List<String> out = new ArrayList<>();
int start = path.startsWith("/") ? 1 : 0;
for (int i = start; i < path.length(); i++) {
if (path.charAt(i) == '/') {
if (i > start) out.add(path.substring(start, i));
start = i + 1;
}
}
if (start < path.length()) out.add(path.substring(start));
return out;
}
public record Resolution(WebSocketHandler handler, Map<String, String> pathParams) {
}
private static final class Node {
final Map<String, Node> children = new ConcurrentHashMap<>();
Node paramChild;
String paramName;
WebSocketHandler handler;
}
}
@@ -0,0 +1,129 @@
package dev.coph.nextusweb.server.websocket;
import dev.coph.nextusweb.server.json.JsonMapper;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;
import tools.jackson.core.JacksonException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
public final class WebSocketSession {
static final AttributeKey<WebSocketSession> SESSION_KEY =
AttributeKey.valueOf("nexusweb.ws.session");
private final Channel channel;
private final String id;
private final String path;
private final Map<String, String> pathParams;
private final Map<String, Object> attributes = new ConcurrentHashMap<>();
WebSocketSession(Channel channel, String path, Map<String, String> pathParams) {
this.channel = channel;
this.id = UUID.randomUUID().toString();
this.path = path;
this.pathParams = pathParams;
}
public String id() {
return id;
}
public String path() {
return path;
}
public String pathParam(String name) {
return pathParams.get(name);
}
public boolean isOpen() {
return channel.isActive();
}
public String remoteAddress() {
SocketAddress addr = channel.remoteAddress();
if (addr instanceof InetSocketAddress inet) {
return inet.getAddress().getHostAddress();
}
return addr == null ? null : addr.toString();
}
public Channel channel() {
return channel;
}
public WebSocketSession attribute(String name, Object value) {
if (value == null) attributes.remove(name);
else attributes.put(name, value);
return this;
}
@SuppressWarnings("unchecked")
public <T> T attribute(String name) {
return (T) attributes.get(name);
}
public ChannelFuture send(String text) {
if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new TextWebSocketFrame(text));
}
public ChannelFuture sendJson(Object value) {
try {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(bytes.length).writeBytes(bytes);
return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
} catch (JacksonException e) {
throw new RuntimeException("JSON serialization failed", e);
}
}
public ChannelFuture sendBinary(byte[] data) {
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(data);
return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
}
public ChannelFuture ping() {
if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new PingWebSocketFrame());
}
public ChannelFuture close() {
return close(1000, "");
}
public ChannelFuture close(int code, String reason) {
if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new CloseWebSocketFrame(code, reason))
.addListener(ChannelFutureListener.CLOSE);
}
static ChannelFuture sendRaw(Channel channel, String text) {
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer();
buf.writeCharSequence(text, CharsetUtil.UTF_8);
return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
}
static ChannelFuture sendRawBinary(Channel channel, byte[] data) {
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(Unpooled.wrappedBuffer(data));
return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
}
}