Add WebSocket support with routing, origin validation, session management, and broadcasting
This commit is contained in:
@@ -7,15 +7,24 @@ import dev.coph.nextusweb.server.router.Request;
|
||||
import dev.coph.nextusweb.server.router.Response;
|
||||
import dev.coph.nextusweb.server.router.Router;
|
||||
import dev.coph.nextusweb.server.router.exception.BadRequestException;
|
||||
import dev.coph.nextusweb.server.websocket.WebSocketConfig;
|
||||
import dev.coph.nextusweb.server.websocket.WebSocketFrameHandlerFactory;
|
||||
import dev.coph.nextusweb.server.websocket.WebSocketRouter;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.handler.codec.http.*;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
|
||||
import io.netty.handler.timeout.IdleStateHandler;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
|
||||
@@ -26,16 +35,29 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
|
||||
private final Router router;
|
||||
private final CorsHandler cors;
|
||||
private final RateLimitGate rateLimit;
|
||||
private final WebSocketRouter wsRouter;
|
||||
private final WebSocketConfig wsConfig;
|
||||
|
||||
|
||||
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit) {
|
||||
this(router, cors, rateLimit, null, null);
|
||||
}
|
||||
|
||||
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit,
|
||||
WebSocketRouter wsRouter, WebSocketConfig wsConfig) {
|
||||
this.router = router;
|
||||
this.cors = cors;
|
||||
this.rateLimit = rateLimit;
|
||||
this.wsRouter = wsRouter;
|
||||
this.wsConfig = wsConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
|
||||
if (wsRouter != null && isWebSocketUpgrade(req)) {
|
||||
if (handleWebSocketUpgrade(ctx, req)) return;
|
||||
}
|
||||
|
||||
req.retain();
|
||||
VT_EXECUTOR.execute(() -> {
|
||||
try {
|
||||
@@ -46,6 +68,63 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
|
||||
});
|
||||
}
|
||||
|
||||
private static boolean isWebSocketUpgrade(FullHttpRequest req) {
|
||||
if (req.method() != HttpMethod.GET) return false;
|
||||
String upgrade = req.headers().get(HttpHeaderNames.UPGRADE);
|
||||
if (upgrade == null || !"websocket".equalsIgnoreCase(upgrade)) return false;
|
||||
String connection = req.headers().get(HttpHeaderNames.CONNECTION);
|
||||
if (connection == null) return false;
|
||||
for (String token : connection.split(",")) {
|
||||
if ("upgrade".equalsIgnoreCase(token.trim())) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private boolean handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest req) {
|
||||
String path = new QueryStringDecoder(req.uri()).path();
|
||||
WebSocketRouter.Resolution resolution = wsRouter.resolve(path);
|
||||
if (resolution == null) return false;
|
||||
|
||||
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
|
||||
if (!wsConfig.isOriginAllowed(origin)) {
|
||||
FullHttpResponse forbidden = new DefaultFullHttpResponse(
|
||||
HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN);
|
||||
forbidden.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0);
|
||||
ctx.writeAndFlush(forbidden).addListener(ChannelFutureListener.CLOSE);
|
||||
return true;
|
||||
}
|
||||
|
||||
WebSocketServerProtocolConfig protoCfg = WebSocketServerProtocolConfig.newBuilder()
|
||||
.websocketPath(path)
|
||||
.checkStartsWith(false)
|
||||
.subprotocols(wsConfig.subprotocolsCsv())
|
||||
.maxFramePayloadLength(wsConfig.maxFramePayloadLength())
|
||||
.allowExtensions(wsConfig.compression())
|
||||
.build();
|
||||
|
||||
ChannelPipeline pipeline = ctx.pipeline();
|
||||
String myName = ctx.name();
|
||||
|
||||
if (wsConfig.idleTimeout() != null) {
|
||||
long secs = Math.max(1, wsConfig.idleTimeout().toSeconds());
|
||||
pipeline.addBefore(myName, "ws-idle",
|
||||
new IdleStateHandler(0, 0, secs, TimeUnit.SECONDS));
|
||||
}
|
||||
if (wsConfig.compression()) {
|
||||
pipeline.addBefore(myName, "ws-deflate",
|
||||
new WebSocketServerCompressionHandler());
|
||||
}
|
||||
pipeline.addBefore(myName, "ws-proto",
|
||||
new WebSocketServerProtocolHandler(protoCfg));
|
||||
pipeline.addBefore(myName, "ws-frames",
|
||||
WebSocketFrameHandlerFactory.create(resolution.handler(), path, resolution.pathParams()));
|
||||
|
||||
ChannelHandlerContext anchor = pipeline.context(HttpObjectAggregator.class);
|
||||
if (anchor == null) anchor = pipeline.firstContext();
|
||||
anchor.fireChannelRead(req.retain());
|
||||
return true;
|
||||
}
|
||||
|
||||
private void handle(ChannelHandlerContext ctx, FullHttpRequest raw) {
|
||||
String origin = raw.headers().get("Origin");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user