diff --git a/.idea/db-forest-config.xml b/.idea/db-forest-config.xml
new file mode 100644
index 0000000..dcd78db
--- /dev/null
+++ b/.idea/db-forest-config.xml
@@ -0,0 +1,11 @@
+
+
+
+ .
+ ----------------------------------------
+ 1:0:e0f49905-9df6-459a-a57c-731edb2c1607
+ 2:0:74720f71-b717-4c46-a783-e93fc40a8785
+ 3:0:c2ae7de6-543e-4eed-8b31-a13cb00693a8
+ .
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index 96bf21c..91c6cba 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ A lightweight, high-performance HTTP server library built on top of Netty. Nexus
- **Middleware chain** — attach cross-cutting logic to all routes
- **CORS support** — configurable origins, methods, headers, credentials, and preflight caching
- **Rate limiting** — four algorithm implementations with per-IP, per-token, or custom key strategies
+- **WebSockets** — path-routed handlers with origin validation, idle timeout, frame size limits and permessage-deflate
- **JSON I/O** — built-in Jackson integration for request parsing and response serialization
---
@@ -265,6 +266,104 @@ Every response automatically includes:
---
+## WebSockets
+
+WebSocket routes are registered on a `WebSocketRouter` and attached to the server alongside the HTTP `Router`. Upgrade requests (`GET` + `Upgrade: websocket`) are intercepted before the HTTP router runs.
+
+### Handler
+
+Implement `WebSocketHandler`. All callbacks are optional.
+
+```java
+public class ChatSocket implements WebSocketHandler {
+
+ private final WebSocketGroup room = new WebSocketGroup("chat");
+
+ @Override
+ public void onOpen(WebSocketSession session) {
+ room.add(session);
+ session.send("{\"type\":\"welcome\",\"id\":\"" + session.id() + "\"}");
+ }
+
+ @Override
+ public void onMessage(WebSocketSession session, String message) {
+ room.broadcastExcept(session, message);
+ }
+
+ @Override
+ public void onClose(WebSocketSession session, int code, String reason) {
+ room.remove(session);
+ }
+}
+```
+
+### Registration
+
+```java
+WebSocketRouter wsRouter = new WebSocketRouter()
+ .on("/ws/chat", new ChatSocket())
+ .on("/ws/rooms/{room}", new RoomSocket());
+
+WebSocketConfig wsConfig = WebSocketConfig.builder()
+ .allowedOrigins("https://app.example.com")
+ .maxFramePayloadLength(64 * 1024) // 64 KiB per frame
+ .maxAggregatedMessageSize(1024 * 1024) // 1 MiB upgrade body cap
+ .idleTimeout(Duration.ofSeconds(60)) // close idle peers
+ .subprotocols("chat.v1")
+ .compression(true) // permessage-deflate
+ .build();
+
+HttpServer.builder(8080, router)
+ .withWebSockets(wsRouter, wsConfig)
+ .start();
+```
+
+Use `WebSocketConfig.defaults()` (or `.anyOrigin()` on the builder) only for local development — production deployments should always allow-list origins explicitly.
+
+### Session API
+
+```java
+session.id(); // stable UUID for this connection
+session.path(); // matched path
+session.pathParam("room"); // path parameter, e.g. from /ws/rooms/{room}
+session.remoteAddress(); // client IP
+session.attribute("userId", id); // attach state to the session
+session.attribute("userId"); // read it back
+
+session.send("text"); // text frame
+session.sendJson(dto); // serialized via Jackson
+session.sendBinary(bytes); // binary frame
+session.ping(); // ping frame
+session.close(); // normal close (1000)
+session.close(1011, "internal"); // close with code + reason
+```
+
+### Broadcasting
+
+`WebSocketGroup` is a thin fluent wrapper around Netty's `ChannelGroup` — joining a session is cheap and removal happens automatically when the channel closes.
+
+```java
+WebSocketGroup group = new WebSocketGroup("lobby")
+ .add(sessionA)
+ .add(sessionB);
+
+group.broadcast("hello everyone");
+group.broadcastJson(eventDto);
+group.broadcastExcept(sessionA, "everyone but A");
+```
+
+### Security & limits
+
+| Concern | How it's handled |
+|---|---|
+| Cross-origin upgrades | `Origin` header validated against `WebSocketConfig.allowedOrigins(...)`; mismatched origins are rejected with `403` |
+| Memory exhaustion | `maxFramePayloadLength` caps a single frame; `maxAggregatedMessageSize` caps the upgrade request body |
+| Idle / zombie connections | `idleTimeout` triggers a server-side close when no read **and** no write happen within the window |
+| User code isolation | All callbacks dispatch onto Java virtual threads, never the Netty event loop |
+| Subprotocol negotiation | Server advertises the configured `subprotocols(...)` list; clients that ask for an unsupported subprotocol fail the handshake |
+
+---
+
## Full Example
```java
@@ -304,9 +403,19 @@ RateLimitConfig rlConfig = RateLimitConfig.builder()
.build();
RateLimitGate gate = new RateLimitGate(rlConfig);
+// WebSockets
+WebSocketRouter wsRouter = new WebSocketRouter()
+ .on("/ws/chat", new ChatSocket());
+
+WebSocketConfig wsConfig = WebSocketConfig.builder()
+ .allowedOrigins("https://app.example.com")
+ .idleTimeout(Duration.ofSeconds(60))
+ .build();
+
HttpServer.builder(8080, router)
.withCorsHandler(cors)
.withRateLimitGate(gate)
+ .withWebSockets(wsRouter, wsConfig)
.start();
```
diff --git a/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java b/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java
index 0ed8b29..f9451bb 100644
--- a/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java
+++ b/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java
@@ -7,15 +7,24 @@ import dev.coph.nextusweb.server.router.Request;
import dev.coph.nextusweb.server.router.Response;
import dev.coph.nextusweb.server.router.Router;
import dev.coph.nextusweb.server.router.exception.BadRequestException;
+import dev.coph.nextusweb.server.websocket.WebSocketConfig;
+import dev.coph.nextusweb.server.websocket.WebSocketFrameHandlerFactory;
+import dev.coph.nextusweb.server.websocket.WebSocketRouter;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
+import io.netty.handler.timeout.IdleStateHandler;
import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public final class HttpRequestHandler extends SimpleChannelInboundHandler {
@@ -26,16 +35,29 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler {
try {
@@ -46,6 +68,63 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler channelClass;
@@ -66,6 +86,10 @@ public final class HttpServer {
channelClass = NioServerSocketChannel.class;
}
+ int maxAggregated = wsConfig != null
+ ? Math.max(1024 * 1024, wsConfig.maxAggregatedMessageSize())
+ : 1024 * 1024;
+
try {
new ServerBootstrap()
.group(boss, worker)
@@ -78,8 +102,8 @@ public final class HttpServer {
protected void initChannel(SocketChannel ch) {
ch.pipeline()
.addLast(new HttpServerCodec())
- .addLast(new HttpObjectAggregator(1024 * 1024))
- .addLast(new HttpRequestHandler(router, cors, gate));
+ .addLast(new HttpObjectAggregator(maxAggregated))
+ .addLast(new HttpRequestHandler(router, cors, gate, wsRouter, wsConfig));
}
})
.bind(port).sync().channel().closeFuture().sync();
@@ -88,4 +112,4 @@ public final class HttpServer {
worker.shutdownGracefully();
}
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java
new file mode 100644
index 0000000..16c7710
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java
@@ -0,0 +1,138 @@
+package dev.coph.nextusweb.server.websocket;
+
+import java.time.Duration;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.Set;
+
+public final class WebSocketConfig {
+
+ private final int maxFramePayloadLength;
+ private final int maxAggregatedMessageSize;
+ private final Duration idleTimeout;
+ private final Set allowedOrigins;
+ private final boolean allowAnyOrigin;
+ private final Set subprotocols;
+ private final boolean compression;
+ private final boolean checkStartsWith;
+
+ private WebSocketConfig(Builder b) {
+ this.maxFramePayloadLength = b.maxFramePayloadLength;
+ this.maxAggregatedMessageSize = b.maxAggregatedMessageSize;
+ this.idleTimeout = b.idleTimeout;
+ this.allowedOrigins = Set.copyOf(b.allowedOrigins);
+ this.allowAnyOrigin = b.allowAnyOrigin;
+ this.subprotocols = Set.copyOf(b.subprotocols);
+ this.compression = b.compression;
+ this.checkStartsWith = b.checkStartsWith;
+ }
+
+ public static WebSocketConfig defaults() {
+ return builder().build();
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public boolean isOriginAllowed(String origin) {
+ if (allowAnyOrigin) return true;
+ if (origin == null) return false;
+ return allowedOrigins.contains(origin);
+ }
+
+ public int maxFramePayloadLength() {
+ return maxFramePayloadLength;
+ }
+
+ public int maxAggregatedMessageSize() {
+ return maxAggregatedMessageSize;
+ }
+
+ public Duration idleTimeout() {
+ return idleTimeout;
+ }
+
+ public boolean allowAnyOrigin() {
+ return allowAnyOrigin;
+ }
+
+ public Set allowedOrigins() {
+ return allowedOrigins;
+ }
+
+ public String subprotocolsCsv() {
+ if (subprotocols.isEmpty()) return null;
+ return String.join(",", subprotocols);
+ }
+
+ public boolean compression() {
+ return compression;
+ }
+
+ public boolean checkStartsWith() {
+ return checkStartsWith;
+ }
+
+ public static final class Builder {
+ private int maxFramePayloadLength = 65_536;
+ private int maxAggregatedMessageSize = 1_048_576;
+ private Duration idleTimeout = Duration.ofSeconds(60);
+ private final Set allowedOrigins = new LinkedHashSet<>();
+ private boolean allowAnyOrigin = false;
+ private final Set subprotocols = new LinkedHashSet<>();
+ private boolean compression = true;
+ private boolean checkStartsWith = false;
+
+ public Builder maxFramePayloadLength(int bytes) {
+ if (bytes <= 0) throw new IllegalArgumentException("maxFramePayloadLength must be > 0");
+ this.maxFramePayloadLength = bytes;
+ return this;
+ }
+
+ public Builder maxAggregatedMessageSize(int bytes) {
+ if (bytes <= 0) throw new IllegalArgumentException("maxAggregatedMessageSize must be > 0");
+ this.maxAggregatedMessageSize = bytes;
+ return this;
+ }
+
+ public Builder idleTimeout(Duration timeout) {
+ this.idleTimeout = timeout;
+ return this;
+ }
+
+ public Builder noIdleTimeout() {
+ this.idleTimeout = null;
+ return this;
+ }
+
+ public Builder allowedOrigins(String... origins) {
+ Collections.addAll(this.allowedOrigins, origins);
+ return this;
+ }
+
+ public Builder anyOrigin() {
+ this.allowAnyOrigin = true;
+ return this;
+ }
+
+ public Builder subprotocols(String... protocols) {
+ Collections.addAll(this.subprotocols, protocols);
+ return this;
+ }
+
+ public Builder compression(boolean enabled) {
+ this.compression = enabled;
+ return this;
+ }
+
+ public Builder checkStartsWith(boolean v) {
+ this.checkStartsWith = v;
+ return this;
+ }
+
+ public WebSocketConfig build() {
+ return new WebSocketConfig(this);
+ }
+ }
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java
new file mode 100644
index 0000000..97d40b1
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java
@@ -0,0 +1,117 @@
+package dev.coph.nextusweb.server.websocket;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import io.netty.handler.timeout.IdleStateEvent;
+import io.netty.util.CharsetUtil;
+
+import java.util.Map;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+final class WebSocketFrameHandler extends SimpleChannelInboundHandler {
+
+ private static final Executor VT_EXECUTOR =
+ Executors.newVirtualThreadPerTaskExecutor();
+
+ private final WebSocketHandler handler;
+ private final String path;
+ private final Map pathParams;
+
+ WebSocketFrameHandler(WebSocketHandler handler, String path, Map pathParams) {
+ this.handler = handler;
+ this.path = path;
+ this.pathParams = pathParams;
+ }
+
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+ if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
+ WebSocketSession session = new WebSocketSession(ctx.channel(), path, pathParams);
+ ctx.channel().attr(WebSocketSession.SESSION_KEY).set(session);
+ VT_EXECUTOR.execute(() -> {
+ try {
+ handler.onOpen(session);
+ } catch (Throwable t) {
+ safeError(session, t);
+ }
+ });
+ return;
+ }
+ if (evt instanceof IdleStateEvent) {
+ ctx.close();
+ return;
+ }
+ super.userEventTriggered(ctx, evt);
+ }
+
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
+ WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
+ if (session == null) return;
+
+ if (frame instanceof TextWebSocketFrame text) {
+ String content = text.text();
+ VT_EXECUTOR.execute(() -> {
+ try {
+ handler.onMessage(session, content);
+ } catch (Throwable t) {
+ safeError(session, t);
+ }
+ });
+ } else if (frame instanceof BinaryWebSocketFrame bin) {
+ int readable = bin.content().readableBytes();
+ byte[] data = new byte[readable];
+ bin.content().getBytes(bin.content().readerIndex(), data);
+ VT_EXECUTOR.execute(() -> {
+ try {
+ handler.onBinary(session, data);
+ } catch (Throwable t) {
+ safeError(session, t);
+ }
+ });
+ } else if (frame instanceof CloseWebSocketFrame close) {
+ int code = close.statusCode();
+ String reason = close.reasonText() == null ? "" : close.reasonText();
+ VT_EXECUTOR.execute(() -> {
+ try {
+ handler.onClose(session, code, reason);
+ } catch (Throwable t) {
+ safeError(session, t);
+ }
+ });
+ }
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) {
+ WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).getAndSet(null);
+ if (session == null) return;
+ VT_EXECUTOR.execute(() -> {
+ try {
+ handler.onClose(session, 1006, "Connection closed");
+ } catch (Throwable t) {
+ safeError(session, t);
+ }
+ });
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
+ if (session != null) safeError(session, cause);
+ ctx.close();
+ }
+
+ private void safeError(WebSocketSession session, Throwable cause) {
+ try {
+ handler.onError(session, cause);
+ } catch (Throwable ignored) {
+ }
+ }
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java
new file mode 100644
index 0000000..dd19683
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java
@@ -0,0 +1,16 @@
+package dev.coph.nextusweb.server.websocket;
+
+import io.netty.channel.ChannelHandler;
+
+import java.util.Map;
+
+public final class WebSocketFrameHandlerFactory {
+
+ private WebSocketFrameHandlerFactory() {
+ }
+
+ public static ChannelHandler create(WebSocketHandler handler, String path,
+ Map pathParams) {
+ return new WebSocketFrameHandler(handler, path, pathParams);
+ }
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java
new file mode 100644
index 0000000..a8ab512
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java
@@ -0,0 +1,83 @@
+package dev.coph.nextusweb.server.websocket;
+
+import dev.coph.nextusweb.server.json.JsonMapper;
+import io.netty.channel.group.ChannelGroup;
+import io.netty.channel.group.DefaultChannelGroup;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import io.netty.util.concurrent.GlobalEventExecutor;
+import tools.jackson.core.JacksonException;
+
+public final class WebSocketGroup {
+
+ private final ChannelGroup channels;
+ private final String name;
+
+ public WebSocketGroup() {
+ this("anonymous");
+ }
+
+ public WebSocketGroup(String name) {
+ this.name = name;
+ this.channels = new DefaultChannelGroup(name, GlobalEventExecutor.INSTANCE);
+ }
+
+ public String name() {
+ return name;
+ }
+
+ public WebSocketGroup add(WebSocketSession session) {
+ channels.add(session.channel());
+ return this;
+ }
+
+ public WebSocketGroup remove(WebSocketSession session) {
+ channels.remove(session.channel());
+ return this;
+ }
+
+ public int size() {
+ return channels.size();
+ }
+
+ public WebSocketGroup broadcast(String text) {
+ channels.writeAndFlush(new TextWebSocketFrame(text));
+ return this;
+ }
+
+ public WebSocketGroup broadcastJson(Object value) {
+ try {
+ byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
+ String text = new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
+ channels.writeAndFlush(new TextWebSocketFrame(text));
+ } catch (JacksonException e) {
+ throw new RuntimeException("JSON serialization failed", e);
+ }
+ return this;
+ }
+
+ public WebSocketGroup broadcastBinary(byte[] data) {
+ for (var ch : channels) {
+ if (ch.isActive()) {
+ var buf = ch.alloc().buffer(data.length).writeBytes(data);
+ ch.writeAndFlush(new BinaryWebSocketFrame(buf));
+ }
+ }
+ return this;
+ }
+
+ public WebSocketGroup broadcastExcept(WebSocketSession exclude, String text) {
+ var excludeCh = exclude == null ? null : exclude.channel();
+ for (var ch : channels) {
+ if (ch.isActive() && ch != excludeCh) {
+ ch.writeAndFlush(new TextWebSocketFrame(text));
+ }
+ }
+ return this;
+ }
+
+ public WebSocketGroup closeAll() {
+ channels.close();
+ return this;
+ }
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketHandler.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketHandler.java
new file mode 100644
index 0000000..7cc3453
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketHandler.java
@@ -0,0 +1,19 @@
+package dev.coph.nextusweb.server.websocket;
+
+public interface WebSocketHandler {
+
+ default void onOpen(WebSocketSession session) throws Exception {
+ }
+
+ default void onMessage(WebSocketSession session, String message) throws Exception {
+ }
+
+ default void onBinary(WebSocketSession session, byte[] data) throws Exception {
+ }
+
+ default void onClose(WebSocketSession session, int code, String reason) throws Exception {
+ }
+
+ default void onError(WebSocketSession session, Throwable cause) throws Exception {
+ }
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketRouter.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketRouter.java
new file mode 100644
index 0000000..9dbe12e
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketRouter.java
@@ -0,0 +1,70 @@
+package dev.coph.nextusweb.server.websocket;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+public final class WebSocketRouter {
+
+ private final Node root = new Node();
+
+ public WebSocketRouter on(String path, WebSocketHandler handler) {
+ Node node = root;
+ for (String segment : split(path)) {
+ if (segment.startsWith("{") && segment.endsWith("}")) {
+ if (node.paramChild == null) {
+ node.paramChild = new Node();
+ node.paramName = segment.substring(1, segment.length() - 1);
+ }
+ node = node.paramChild;
+ } else {
+ node = node.children.computeIfAbsent(segment, k -> new Node());
+ }
+ }
+ node.handler = handler;
+ return this;
+ }
+
+ public Resolution resolve(String path) {
+ Map params = new HashMap<>(4);
+ Node node = root;
+ for (String segment : split(path)) {
+ Node next = node.children.get(segment);
+ if (next != null) {
+ node = next;
+ } else if (node.paramChild != null) {
+ params.put(node.paramName, segment);
+ node = node.paramChild;
+ } else {
+ return null;
+ }
+ }
+ if (node.handler == null) return null;
+ return new Resolution(node.handler, params);
+ }
+
+ private static List split(String path) {
+ List out = new ArrayList<>();
+ int start = path.startsWith("/") ? 1 : 0;
+ for (int i = start; i < path.length(); i++) {
+ if (path.charAt(i) == '/') {
+ if (i > start) out.add(path.substring(start, i));
+ start = i + 1;
+ }
+ }
+ if (start < path.length()) out.add(path.substring(start));
+ return out;
+ }
+
+ public record Resolution(WebSocketHandler handler, Map pathParams) {
+ }
+
+ private static final class Node {
+ final Map children = new ConcurrentHashMap<>();
+ Node paramChild;
+ String paramName;
+ WebSocketHandler handler;
+ }
+}
diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java
new file mode 100644
index 0000000..fe76639
--- /dev/null
+++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java
@@ -0,0 +1,129 @@
+package dev.coph.nextusweb.server.websocket;
+
+import dev.coph.nextusweb.server.json.JsonMapper;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import io.netty.util.AttributeKey;
+import io.netty.util.CharsetUtil;
+import tools.jackson.core.JacksonException;
+
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+
+public final class WebSocketSession {
+
+ static final AttributeKey SESSION_KEY =
+ AttributeKey.valueOf("nexusweb.ws.session");
+
+ private final Channel channel;
+ private final String id;
+ private final String path;
+ private final Map pathParams;
+ private final Map attributes = new ConcurrentHashMap<>();
+
+ WebSocketSession(Channel channel, String path, Map pathParams) {
+ this.channel = channel;
+ this.id = UUID.randomUUID().toString();
+ this.path = path;
+ this.pathParams = pathParams;
+ }
+
+ public String id() {
+ return id;
+ }
+
+ public String path() {
+ return path;
+ }
+
+ public String pathParam(String name) {
+ return pathParams.get(name);
+ }
+
+ public boolean isOpen() {
+ return channel.isActive();
+ }
+
+ public String remoteAddress() {
+ SocketAddress addr = channel.remoteAddress();
+ if (addr instanceof InetSocketAddress inet) {
+ return inet.getAddress().getHostAddress();
+ }
+ return addr == null ? null : addr.toString();
+ }
+
+ public Channel channel() {
+ return channel;
+ }
+
+ public WebSocketSession attribute(String name, Object value) {
+ if (value == null) attributes.remove(name);
+ else attributes.put(name, value);
+ return this;
+ }
+
+ @SuppressWarnings("unchecked")
+ public T attribute(String name) {
+ return (T) attributes.get(name);
+ }
+
+ public ChannelFuture send(String text) {
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ return channel.writeAndFlush(new TextWebSocketFrame(text));
+ }
+
+ public ChannelFuture sendJson(Object value) {
+ try {
+ byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ ByteBuf buf = channel.alloc().buffer(bytes.length).writeBytes(bytes);
+ return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
+ } catch (JacksonException e) {
+ throw new RuntimeException("JSON serialization failed", e);
+ }
+ }
+
+ public ChannelFuture sendBinary(byte[] data) {
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(data);
+ return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
+ }
+
+ public ChannelFuture ping() {
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ return channel.writeAndFlush(new PingWebSocketFrame());
+ }
+
+ public ChannelFuture close() {
+ return close(1000, "");
+ }
+
+ public ChannelFuture close(int code, String reason) {
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ return channel.writeAndFlush(new CloseWebSocketFrame(code, reason))
+ .addListener(ChannelFutureListener.CLOSE);
+ }
+
+ static ChannelFuture sendRaw(Channel channel, String text) {
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ ByteBuf buf = channel.alloc().buffer();
+ buf.writeCharSequence(text, CharsetUtil.UTF_8);
+ return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
+ }
+
+ static ChannelFuture sendRawBinary(Channel channel, byte[] data) {
+ if (!channel.isActive()) return channel.newSucceededFuture();
+ ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(Unpooled.wrappedBuffer(data));
+ return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
+ }
+}