Add WebSocket support with routing, origin validation, session management, and broadcasting

This commit is contained in:
CodingPhoenixx
2026-05-28 13:23:23 +02:00
parent b75e1e5c6e
commit 994e7fa80c
11 changed files with 799 additions and 4 deletions
@@ -7,15 +7,24 @@ import dev.coph.nextusweb.server.router.Request;
import dev.coph.nextusweb.server.router.Response;
import dev.coph.nextusweb.server.router.Router;
import dev.coph.nextusweb.server.router.exception.BadRequestException;
import dev.coph.nextusweb.server.websocket.WebSocketConfig;
import dev.coph.nextusweb.server.websocket.WebSocketFrameHandlerFactory;
import dev.coph.nextusweb.server.websocket.WebSocketRouter;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.handler.timeout.IdleStateHandler;
import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
@@ -26,16 +35,29 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
private final Router router;
private final CorsHandler cors;
private final RateLimitGate rateLimit;
private final WebSocketRouter wsRouter;
private final WebSocketConfig wsConfig;
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit) {
this(router, cors, rateLimit, null, null);
}
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit,
WebSocketRouter wsRouter, WebSocketConfig wsConfig) {
this.router = router;
this.cors = cors;
this.rateLimit = rateLimit;
this.wsRouter = wsRouter;
this.wsConfig = wsConfig;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
if (wsRouter != null && isWebSocketUpgrade(req)) {
if (handleWebSocketUpgrade(ctx, req)) return;
}
req.retain();
VT_EXECUTOR.execute(() -> {
try {
@@ -46,6 +68,63 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
});
}
private static boolean isWebSocketUpgrade(FullHttpRequest req) {
if (req.method() != HttpMethod.GET) return false;
String upgrade = req.headers().get(HttpHeaderNames.UPGRADE);
if (upgrade == null || !"websocket".equalsIgnoreCase(upgrade)) return false;
String connection = req.headers().get(HttpHeaderNames.CONNECTION);
if (connection == null) return false;
for (String token : connection.split(",")) {
if ("upgrade".equalsIgnoreCase(token.trim())) return true;
}
return false;
}
private boolean handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest req) {
String path = new QueryStringDecoder(req.uri()).path();
WebSocketRouter.Resolution resolution = wsRouter.resolve(path);
if (resolution == null) return false;
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
if (!wsConfig.isOriginAllowed(origin)) {
FullHttpResponse forbidden = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN);
forbidden.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0);
ctx.writeAndFlush(forbidden).addListener(ChannelFutureListener.CLOSE);
return true;
}
WebSocketServerProtocolConfig protoCfg = WebSocketServerProtocolConfig.newBuilder()
.websocketPath(path)
.checkStartsWith(false)
.subprotocols(wsConfig.subprotocolsCsv())
.maxFramePayloadLength(wsConfig.maxFramePayloadLength())
.allowExtensions(wsConfig.compression())
.build();
ChannelPipeline pipeline = ctx.pipeline();
String myName = ctx.name();
if (wsConfig.idleTimeout() != null) {
long secs = Math.max(1, wsConfig.idleTimeout().toSeconds());
pipeline.addBefore(myName, "ws-idle",
new IdleStateHandler(0, 0, secs, TimeUnit.SECONDS));
}
if (wsConfig.compression()) {
pipeline.addBefore(myName, "ws-deflate",
new WebSocketServerCompressionHandler());
}
pipeline.addBefore(myName, "ws-proto",
new WebSocketServerProtocolHandler(protoCfg));
pipeline.addBefore(myName, "ws-frames",
WebSocketFrameHandlerFactory.create(resolution.handler(), path, resolution.pathParams()));
ChannelHandlerContext anchor = pipeline.context(HttpObjectAggregator.class);
if (anchor == null) anchor = pipeline.firstContext();
anchor.fireChannelRead(req.retain());
return true;
}
private void handle(ChannelHandlerContext ctx, FullHttpRequest raw) {
String origin = raw.headers().get("Origin");