Add WebSocket support with routing, origin validation, session management, and broadcasting

This commit is contained in:
CodingPhoenixx
2026-05-28 13:23:23 +02:00
parent b75e1e5c6e
commit 994e7fa80c
11 changed files with 799 additions and 4 deletions
+11
View File
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="db-forest-configuration">
<data version="2">.
----------------------------------------
1:0:e0f49905-9df6-459a-a57c-731edb2c1607
2:0:74720f71-b717-4c46-a783-e93fc40a8785
3:0:c2ae7de6-543e-4eed-8b31-a13cb00693a8
.</data>
</component>
</project>
+109
View File
@@ -11,6 +11,7 @@ A lightweight, high-performance HTTP server library built on top of Netty. Nexus
- **Middleware chain** — attach cross-cutting logic to all routes
- **CORS support** — configurable origins, methods, headers, credentials, and preflight caching
- **Rate limiting** — four algorithm implementations with per-IP, per-token, or custom key strategies
- **WebSockets** — path-routed handlers with origin validation, idle timeout, frame size limits and permessage-deflate
- **JSON I/O** — built-in Jackson integration for request parsing and response serialization
---
@@ -265,6 +266,104 @@ Every response automatically includes:
---
## WebSockets
WebSocket routes are registered on a `WebSocketRouter` and attached to the server alongside the HTTP `Router`. Upgrade requests (`GET` + `Upgrade: websocket`) are intercepted before the HTTP router runs.
### Handler
Implement `WebSocketHandler`. All callbacks are optional.
```java
public class ChatSocket implements WebSocketHandler {
private final WebSocketGroup room = new WebSocketGroup("chat");
@Override
public void onOpen(WebSocketSession session) {
room.add(session);
session.send("{\"type\":\"welcome\",\"id\":\"" + session.id() + "\"}");
}
@Override
public void onMessage(WebSocketSession session, String message) {
room.broadcastExcept(session, message);
}
@Override
public void onClose(WebSocketSession session, int code, String reason) {
room.remove(session);
}
}
```
### Registration
```java
WebSocketRouter wsRouter = new WebSocketRouter()
.on("/ws/chat", new ChatSocket())
.on("/ws/rooms/{room}", new RoomSocket());
WebSocketConfig wsConfig = WebSocketConfig.builder()
.allowedOrigins("https://app.example.com")
.maxFramePayloadLength(64 * 1024) // 64 KiB per frame
.maxAggregatedMessageSize(1024 * 1024) // 1 MiB upgrade body cap
.idleTimeout(Duration.ofSeconds(60)) // close idle peers
.subprotocols("chat.v1")
.compression(true) // permessage-deflate
.build();
HttpServer.builder(8080, router)
.withWebSockets(wsRouter, wsConfig)
.start();
```
Use `WebSocketConfig.defaults()` (or `.anyOrigin()` on the builder) only for local development — production deployments should always allow-list origins explicitly.
### Session API
```java
session.id(); // stable UUID for this connection
session.path(); // matched path
session.pathParam("room"); // path parameter, e.g. from /ws/rooms/{room}
session.remoteAddress(); // client IP
session.attribute("userId", id); // attach state to the session
session.attribute("userId"); // read it back
session.send("text"); // text frame
session.sendJson(dto); // serialized via Jackson
session.sendBinary(bytes); // binary frame
session.ping(); // ping frame
session.close(); // normal close (1000)
session.close(1011, "internal"); // close with code + reason
```
### Broadcasting
`WebSocketGroup` is a thin fluent wrapper around Netty's `ChannelGroup` — joining a session is cheap and removal happens automatically when the channel closes.
```java
WebSocketGroup group = new WebSocketGroup("lobby")
.add(sessionA)
.add(sessionB);
group.broadcast("hello everyone");
group.broadcastJson(eventDto);
group.broadcastExcept(sessionA, "everyone but A");
```
### Security & limits
| Concern | How it's handled |
|---|---|
| Cross-origin upgrades | `Origin` header validated against `WebSocketConfig.allowedOrigins(...)`; mismatched origins are rejected with `403` |
| Memory exhaustion | `maxFramePayloadLength` caps a single frame; `maxAggregatedMessageSize` caps the upgrade request body |
| Idle / zombie connections | `idleTimeout` triggers a server-side close when no read **and** no write happen within the window |
| User code isolation | All callbacks dispatch onto Java virtual threads, never the Netty event loop |
| Subprotocol negotiation | Server advertises the configured `subprotocols(...)` list; clients that ask for an unsupported subprotocol fail the handshake |
---
## Full Example
```java
@@ -304,9 +403,19 @@ RateLimitConfig rlConfig = RateLimitConfig.builder()
.build();
RateLimitGate gate = new RateLimitGate(rlConfig);
// WebSockets
WebSocketRouter wsRouter = new WebSocketRouter()
.on("/ws/chat", new ChatSocket());
WebSocketConfig wsConfig = WebSocketConfig.builder()
.allowedOrigins("https://app.example.com")
.idleTimeout(Duration.ofSeconds(60))
.build();
HttpServer.builder(8080, router)
.withCorsHandler(cors)
.withRateLimitGate(gate)
.withWebSockets(wsRouter, wsConfig)
.start();
```
@@ -7,15 +7,24 @@ import dev.coph.nextusweb.server.router.Request;
import dev.coph.nextusweb.server.router.Response;
import dev.coph.nextusweb.server.router.Router;
import dev.coph.nextusweb.server.router.exception.BadRequestException;
import dev.coph.nextusweb.server.websocket.WebSocketConfig;
import dev.coph.nextusweb.server.websocket.WebSocketFrameHandlerFactory;
import dev.coph.nextusweb.server.websocket.WebSocketRouter;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.handler.timeout.IdleStateHandler;
import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
@@ -26,16 +35,29 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
private final Router router;
private final CorsHandler cors;
private final RateLimitGate rateLimit;
private final WebSocketRouter wsRouter;
private final WebSocketConfig wsConfig;
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit) {
this(router, cors, rateLimit, null, null);
}
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit,
WebSocketRouter wsRouter, WebSocketConfig wsConfig) {
this.router = router;
this.cors = cors;
this.rateLimit = rateLimit;
this.wsRouter = wsRouter;
this.wsConfig = wsConfig;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
if (wsRouter != null && isWebSocketUpgrade(req)) {
if (handleWebSocketUpgrade(ctx, req)) return;
}
req.retain();
VT_EXECUTOR.execute(() -> {
try {
@@ -46,6 +68,63 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
});
}
private static boolean isWebSocketUpgrade(FullHttpRequest req) {
if (req.method() != HttpMethod.GET) return false;
String upgrade = req.headers().get(HttpHeaderNames.UPGRADE);
if (upgrade == null || !"websocket".equalsIgnoreCase(upgrade)) return false;
String connection = req.headers().get(HttpHeaderNames.CONNECTION);
if (connection == null) return false;
for (String token : connection.split(",")) {
if ("upgrade".equalsIgnoreCase(token.trim())) return true;
}
return false;
}
private boolean handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest req) {
String path = new QueryStringDecoder(req.uri()).path();
WebSocketRouter.Resolution resolution = wsRouter.resolve(path);
if (resolution == null) return false;
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
if (!wsConfig.isOriginAllowed(origin)) {
FullHttpResponse forbidden = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN);
forbidden.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0);
ctx.writeAndFlush(forbidden).addListener(ChannelFutureListener.CLOSE);
return true;
}
WebSocketServerProtocolConfig protoCfg = WebSocketServerProtocolConfig.newBuilder()
.websocketPath(path)
.checkStartsWith(false)
.subprotocols(wsConfig.subprotocolsCsv())
.maxFramePayloadLength(wsConfig.maxFramePayloadLength())
.allowExtensions(wsConfig.compression())
.build();
ChannelPipeline pipeline = ctx.pipeline();
String myName = ctx.name();
if (wsConfig.idleTimeout() != null) {
long secs = Math.max(1, wsConfig.idleTimeout().toSeconds());
pipeline.addBefore(myName, "ws-idle",
new IdleStateHandler(0, 0, secs, TimeUnit.SECONDS));
}
if (wsConfig.compression()) {
pipeline.addBefore(myName, "ws-deflate",
new WebSocketServerCompressionHandler());
}
pipeline.addBefore(myName, "ws-proto",
new WebSocketServerProtocolHandler(protoCfg));
pipeline.addBefore(myName, "ws-frames",
WebSocketFrameHandlerFactory.create(resolution.handler(), path, resolution.pathParams()));
ChannelHandlerContext anchor = pipeline.context(HttpObjectAggregator.class);
if (anchor == null) anchor = pipeline.firstContext();
anchor.fireChannelRead(req.retain());
return true;
}
private void handle(ChannelHandlerContext ctx, FullHttpRequest raw) {
String origin = raw.headers().get("Origin");
@@ -3,6 +3,8 @@ package dev.coph.nextusweb.server;
import dev.coph.nextusweb.server.cores.CorsHandler;
import dev.coph.nextusweb.server.ratelimit.RateLimitGate;
import dev.coph.nextusweb.server.router.Router;
import dev.coph.nextusweb.server.websocket.WebSocketConfig;
import dev.coph.nextusweb.server.websocket.WebSocketRouter;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.epoll.Epoll;
@@ -23,6 +25,8 @@ public final class HttpServer {
private final Router router;
private CorsHandler cors;
private RateLimitGate gate;
private WebSocketRouter wsRouter;
private WebSocketConfig wsConfig;
private HttpServer(int port, Router router) {
this.port = port;
@@ -43,12 +47,28 @@ public final class HttpServer {
return this;
}
public HttpServer withWebSockets(WebSocketRouter wsRouter) {
return withWebSockets(wsRouter, WebSocketConfig.defaults());
}
public HttpServer withWebSockets(WebSocketRouter wsRouter, WebSocketConfig wsConfig) {
this.wsRouter = wsRouter;
this.wsConfig = wsConfig;
return this;
}
public void start() throws InterruptedException {
start(port, router, cors, gate);
start(port, router, cors, gate, wsRouter, wsConfig);
}
public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate)
throws InterruptedException {
start(port, router, cors, gate, null, null);
}
public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate,
WebSocketRouter wsRouter, WebSocketConfig wsConfig)
throws InterruptedException {
EventLoopGroup boss, worker;
Class<? extends ServerChannel> channelClass;
@@ -66,6 +86,10 @@ public final class HttpServer {
channelClass = NioServerSocketChannel.class;
}
int maxAggregated = wsConfig != null
? Math.max(1024 * 1024, wsConfig.maxAggregatedMessageSize())
: 1024 * 1024;
try {
new ServerBootstrap()
.group(boss, worker)
@@ -78,8 +102,8 @@ public final class HttpServer {
protected void initChannel(SocketChannel ch) {
ch.pipeline()
.addLast(new HttpServerCodec())
.addLast(new HttpObjectAggregator(1024 * 1024))
.addLast(new HttpRequestHandler(router, cors, gate));
.addLast(new HttpObjectAggregator(maxAggregated))
.addLast(new HttpRequestHandler(router, cors, gate, wsRouter, wsConfig));
}
})
.bind(port).sync().channel().closeFuture().sync();
@@ -88,4 +112,4 @@ public final class HttpServer {
worker.shutdownGracefully();
}
}
}
}
@@ -0,0 +1,138 @@
package dev.coph.nextusweb.server.websocket;
import java.time.Duration;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
public final class WebSocketConfig {
private final int maxFramePayloadLength;
private final int maxAggregatedMessageSize;
private final Duration idleTimeout;
private final Set<String> allowedOrigins;
private final boolean allowAnyOrigin;
private final Set<String> subprotocols;
private final boolean compression;
private final boolean checkStartsWith;
private WebSocketConfig(Builder b) {
this.maxFramePayloadLength = b.maxFramePayloadLength;
this.maxAggregatedMessageSize = b.maxAggregatedMessageSize;
this.idleTimeout = b.idleTimeout;
this.allowedOrigins = Set.copyOf(b.allowedOrigins);
this.allowAnyOrigin = b.allowAnyOrigin;
this.subprotocols = Set.copyOf(b.subprotocols);
this.compression = b.compression;
this.checkStartsWith = b.checkStartsWith;
}
public static WebSocketConfig defaults() {
return builder().build();
}
public static Builder builder() {
return new Builder();
}
public boolean isOriginAllowed(String origin) {
if (allowAnyOrigin) return true;
if (origin == null) return false;
return allowedOrigins.contains(origin);
}
public int maxFramePayloadLength() {
return maxFramePayloadLength;
}
public int maxAggregatedMessageSize() {
return maxAggregatedMessageSize;
}
public Duration idleTimeout() {
return idleTimeout;
}
public boolean allowAnyOrigin() {
return allowAnyOrigin;
}
public Set<String> allowedOrigins() {
return allowedOrigins;
}
public String subprotocolsCsv() {
if (subprotocols.isEmpty()) return null;
return String.join(",", subprotocols);
}
public boolean compression() {
return compression;
}
public boolean checkStartsWith() {
return checkStartsWith;
}
public static final class Builder {
private int maxFramePayloadLength = 65_536;
private int maxAggregatedMessageSize = 1_048_576;
private Duration idleTimeout = Duration.ofSeconds(60);
private final Set<String> allowedOrigins = new LinkedHashSet<>();
private boolean allowAnyOrigin = false;
private final Set<String> subprotocols = new LinkedHashSet<>();
private boolean compression = true;
private boolean checkStartsWith = false;
public Builder maxFramePayloadLength(int bytes) {
if (bytes <= 0) throw new IllegalArgumentException("maxFramePayloadLength must be > 0");
this.maxFramePayloadLength = bytes;
return this;
}
public Builder maxAggregatedMessageSize(int bytes) {
if (bytes <= 0) throw new IllegalArgumentException("maxAggregatedMessageSize must be > 0");
this.maxAggregatedMessageSize = bytes;
return this;
}
public Builder idleTimeout(Duration timeout) {
this.idleTimeout = timeout;
return this;
}
public Builder noIdleTimeout() {
this.idleTimeout = null;
return this;
}
public Builder allowedOrigins(String... origins) {
Collections.addAll(this.allowedOrigins, origins);
return this;
}
public Builder anyOrigin() {
this.allowAnyOrigin = true;
return this;
}
public Builder subprotocols(String... protocols) {
Collections.addAll(this.subprotocols, protocols);
return this;
}
public Builder compression(boolean enabled) {
this.compression = enabled;
return this;
}
public Builder checkStartsWith(boolean v) {
this.checkStartsWith = v;
return this;
}
public WebSocketConfig build() {
return new WebSocketConfig(this);
}
}
}
@@ -0,0 +1,117 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.CharsetUtil;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
private static final Executor VT_EXECUTOR =
Executors.newVirtualThreadPerTaskExecutor();
private final WebSocketHandler handler;
private final String path;
private final Map<String, String> pathParams;
WebSocketFrameHandler(WebSocketHandler handler, String path, Map<String, String> pathParams) {
this.handler = handler;
this.path = path;
this.pathParams = pathParams;
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
WebSocketSession session = new WebSocketSession(ctx.channel(), path, pathParams);
ctx.channel().attr(WebSocketSession.SESSION_KEY).set(session);
VT_EXECUTOR.execute(() -> {
try {
handler.onOpen(session);
} catch (Throwable t) {
safeError(session, t);
}
});
return;
}
if (evt instanceof IdleStateEvent) {
ctx.close();
return;
}
super.userEventTriggered(ctx, evt);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
if (session == null) return;
if (frame instanceof TextWebSocketFrame text) {
String content = text.text();
VT_EXECUTOR.execute(() -> {
try {
handler.onMessage(session, content);
} catch (Throwable t) {
safeError(session, t);
}
});
} else if (frame instanceof BinaryWebSocketFrame bin) {
int readable = bin.content().readableBytes();
byte[] data = new byte[readable];
bin.content().getBytes(bin.content().readerIndex(), data);
VT_EXECUTOR.execute(() -> {
try {
handler.onBinary(session, data);
} catch (Throwable t) {
safeError(session, t);
}
});
} else if (frame instanceof CloseWebSocketFrame close) {
int code = close.statusCode();
String reason = close.reasonText() == null ? "" : close.reasonText();
VT_EXECUTOR.execute(() -> {
try {
handler.onClose(session, code, reason);
} catch (Throwable t) {
safeError(session, t);
}
});
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).getAndSet(null);
if (session == null) return;
VT_EXECUTOR.execute(() -> {
try {
handler.onClose(session, 1006, "Connection closed");
} catch (Throwable t) {
safeError(session, t);
}
});
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
if (session != null) safeError(session, cause);
ctx.close();
}
private void safeError(WebSocketSession session, Throwable cause) {
try {
handler.onError(session, cause);
} catch (Throwable ignored) {
}
}
}
@@ -0,0 +1,16 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.ChannelHandler;
import java.util.Map;
public final class WebSocketFrameHandlerFactory {
private WebSocketFrameHandlerFactory() {
}
public static ChannelHandler create(WebSocketHandler handler, String path,
Map<String, String> pathParams) {
return new WebSocketFrameHandler(handler, path, pathParams);
}
}
@@ -0,0 +1,83 @@
package dev.coph.nextusweb.server.websocket;
import dev.coph.nextusweb.server.json.JsonMapper;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import tools.jackson.core.JacksonException;
public final class WebSocketGroup {
private final ChannelGroup channels;
private final String name;
public WebSocketGroup() {
this("anonymous");
}
public WebSocketGroup(String name) {
this.name = name;
this.channels = new DefaultChannelGroup(name, GlobalEventExecutor.INSTANCE);
}
public String name() {
return name;
}
public WebSocketGroup add(WebSocketSession session) {
channels.add(session.channel());
return this;
}
public WebSocketGroup remove(WebSocketSession session) {
channels.remove(session.channel());
return this;
}
public int size() {
return channels.size();
}
public WebSocketGroup broadcast(String text) {
channels.writeAndFlush(new TextWebSocketFrame(text));
return this;
}
public WebSocketGroup broadcastJson(Object value) {
try {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
String text = new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
channels.writeAndFlush(new TextWebSocketFrame(text));
} catch (JacksonException e) {
throw new RuntimeException("JSON serialization failed", e);
}
return this;
}
public WebSocketGroup broadcastBinary(byte[] data) {
for (var ch : channels) {
if (ch.isActive()) {
var buf = ch.alloc().buffer(data.length).writeBytes(data);
ch.writeAndFlush(new BinaryWebSocketFrame(buf));
}
}
return this;
}
public WebSocketGroup broadcastExcept(WebSocketSession exclude, String text) {
var excludeCh = exclude == null ? null : exclude.channel();
for (var ch : channels) {
if (ch.isActive() && ch != excludeCh) {
ch.writeAndFlush(new TextWebSocketFrame(text));
}
}
return this;
}
public WebSocketGroup closeAll() {
channels.close();
return this;
}
}
@@ -0,0 +1,19 @@
package dev.coph.nextusweb.server.websocket;
public interface WebSocketHandler {
default void onOpen(WebSocketSession session) throws Exception {
}
default void onMessage(WebSocketSession session, String message) throws Exception {
}
default void onBinary(WebSocketSession session, byte[] data) throws Exception {
}
default void onClose(WebSocketSession session, int code, String reason) throws Exception {
}
default void onError(WebSocketSession session, Throwable cause) throws Exception {
}
}
@@ -0,0 +1,70 @@
package dev.coph.nextusweb.server.websocket;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public final class WebSocketRouter {
private final Node root = new Node();
public WebSocketRouter on(String path, WebSocketHandler handler) {
Node node = root;
for (String segment : split(path)) {
if (segment.startsWith("{") && segment.endsWith("}")) {
if (node.paramChild == null) {
node.paramChild = new Node();
node.paramName = segment.substring(1, segment.length() - 1);
}
node = node.paramChild;
} else {
node = node.children.computeIfAbsent(segment, k -> new Node());
}
}
node.handler = handler;
return this;
}
public Resolution resolve(String path) {
Map<String, String> params = new HashMap<>(4);
Node node = root;
for (String segment : split(path)) {
Node next = node.children.get(segment);
if (next != null) {
node = next;
} else if (node.paramChild != null) {
params.put(node.paramName, segment);
node = node.paramChild;
} else {
return null;
}
}
if (node.handler == null) return null;
return new Resolution(node.handler, params);
}
private static List<String> split(String path) {
List<String> out = new ArrayList<>();
int start = path.startsWith("/") ? 1 : 0;
for (int i = start; i < path.length(); i++) {
if (path.charAt(i) == '/') {
if (i > start) out.add(path.substring(start, i));
start = i + 1;
}
}
if (start < path.length()) out.add(path.substring(start));
return out;
}
public record Resolution(WebSocketHandler handler, Map<String, String> pathParams) {
}
private static final class Node {
final Map<String, Node> children = new ConcurrentHashMap<>();
Node paramChild;
String paramName;
WebSocketHandler handler;
}
}
@@ -0,0 +1,129 @@
package dev.coph.nextusweb.server.websocket;
import dev.coph.nextusweb.server.json.JsonMapper;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;
import tools.jackson.core.JacksonException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
public final class WebSocketSession {
static final AttributeKey<WebSocketSession> SESSION_KEY =
AttributeKey.valueOf("nexusweb.ws.session");
private final Channel channel;
private final String id;
private final String path;
private final Map<String, String> pathParams;
private final Map<String, Object> attributes = new ConcurrentHashMap<>();
WebSocketSession(Channel channel, String path, Map<String, String> pathParams) {
this.channel = channel;
this.id = UUID.randomUUID().toString();
this.path = path;
this.pathParams = pathParams;
}
public String id() {
return id;
}
public String path() {
return path;
}
public String pathParam(String name) {
return pathParams.get(name);
}
public boolean isOpen() {
return channel.isActive();
}
public String remoteAddress() {
SocketAddress addr = channel.remoteAddress();
if (addr instanceof InetSocketAddress inet) {
return inet.getAddress().getHostAddress();
}
return addr == null ? null : addr.toString();
}
public Channel channel() {
return channel;
}
public WebSocketSession attribute(String name, Object value) {
if (value == null) attributes.remove(name);
else attributes.put(name, value);
return this;
}
@SuppressWarnings("unchecked")
public <T> T attribute(String name) {
return (T) attributes.get(name);
}
public ChannelFuture send(String text) {
if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new TextWebSocketFrame(text));
}
public ChannelFuture sendJson(Object value) {
try {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(bytes.length).writeBytes(bytes);
return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
} catch (JacksonException e) {
throw new RuntimeException("JSON serialization failed", e);
}
}
public ChannelFuture sendBinary(byte[] data) {
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(data);
return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
}
public ChannelFuture ping() {
if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new PingWebSocketFrame());
}
public ChannelFuture close() {
return close(1000, "");
}
public ChannelFuture close(int code, String reason) {
if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new CloseWebSocketFrame(code, reason))
.addListener(ChannelFutureListener.CLOSE);
}
static ChannelFuture sendRaw(Channel channel, String text) {
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer();
buf.writeCharSequence(text, CharsetUtil.UTF_8);
return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
}
static ChannelFuture sendRawBinary(Channel channel, byte[] data) {
if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(Unpooled.wrappedBuffer(data));
return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
}
}