8 Commits

Author SHA1 Message Date
CodingPhoenixx ac2d1efec7 Consolidate and streamline workflows: merge test.yml and publish.yml into a unified ci.yml for testing, publishing, and release automation.
CI - Test, Publish and Release / run-tests (push) Successful in 19s
CI - Test, Publish and Release / create-release (push) Successful in 12s
CI - Test, Publish and Release / check-and-publish (push) Successful in 13s
2026-05-29 09:14:39 +02:00
CodingPhoenixx d2ce4592d4 Bump project version to 0.0.3 in build.gradle
Auto Publish on Version Change / check-and-publish (push) Successful in 18s
Run Tests on Push and Pull Request / run-tests (push) Successful in 18s
2026-05-29 09:04:55 +02:00
CodingPhoenixx a7b65c031d Add constructors and Javadoc comments to improve clarity and completeness across server components, including WebSocket and routing classes.
Auto Publish on Version Change / check-and-publish (push) Successful in 14s
Run Tests on Push and Pull Request / run-tests (push) Successful in 18s
2026-05-29 09:00:31 +02:00
CodingPhoenixx 5d6e8622bf Add comprehensive Javadoc documentation to server components, including annotations, request/response handling, routing, and WebSocket support. 2026-05-29 08:50:05 +02:00
CodingPhoenixx f00a1098b4 Revert "Streamline test report packaging and uploading in test.yml: archive reports, handle missing directories, and upload to Gitea packages."
Auto Publish on Version Change / check-and-publish (push) Successful in 13s
Run Tests on Push and Pull Request / run-tests (push) Successful in 18s
This reverts commit 0d8ee099a0.
2026-05-28 13:59:41 +02:00
CodingPhoenixx 0d8ee099a0 Streamline test report packaging and uploading in test.yml: archive reports, handle missing directories, and upload to Gitea packages.
Auto Publish on Version Change / check-and-publish (push) Successful in 13s
Run Tests on Push and Pull Request / run-tests (push) Successful in 18s
2026-05-28 13:56:52 +02:00
CodingPhoenixx efd302f625 Expand test coverage for routing and annotation scanning: validate distinct handlers for same path with different methods, ensure correct MethodNotAllowed responses, and handle overwriting/parameterized paths.
Auto Publish on Version Change / check-and-publish (push) Successful in 14s
Run Tests on Push and Pull Request / run-tests (push) Successful in 18s
2026-05-28 13:48:05 +02:00
CodingPhoenixx 78d90855c5 Add test coverage for core server components: annotation scanning, routing, rate limiting, CORS, and JSON handling
Auto Publish on Version Change / check-and-publish (push) Successful in 14s
Run Tests on Push and Pull Request / run-tests (push) Successful in 19s
2026-05-28 13:40:24 +02:00
60 changed files with 3862 additions and 116 deletions
+188
View File
@@ -0,0 +1,188 @@
name: CI - Test, Publish and Release
on:
push:
branches:
- master
pull_request:
jobs:
run-tests:
runs-on: java26
steps:
- name: Checkout Code
run: |
SERVER_DOMAIN=$(echo "${{ github.server_url }}" | sed 's/https:\/\///')
rm -rf "$GITHUB_WORKSPACE"/*
git clone "https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@${SERVER_DOMAIN}/${{ github.repository }}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout ${{ github.sha }}
- name: Make gradlew executable
run: |
cd "$GITHUB_WORKSPACE"
chmod +x ./gradlew
- name: Run JUnit tests
id: run_tests
run: |
cd "$GITHUB_WORKSPACE"
./gradlew test --no-daemon --stacktrace
- name: Upload test reports
if: always()
run: |
cd "$GITHUB_WORKSPACE"
if [ -d "build/reports/tests/test" ]; then
echo "Test reports available in build/reports/tests/test"
ls -la build/reports/tests/test || true
fi
if [ -d "build/test-results/test" ]; then
echo "Test result XMLs:"
ls -la build/test-results/test || true
fi
check-and-publish:
runs-on: java26
# Publish to the Maven repo after the Gitea release has been created
needs: create-release
if: github.event_name == 'push' && github.ref == 'refs/heads/master'
env:
MAVEN_REPO_URL: ${{ secrets.MAVEN_REPO_URL }}
MAVEN_REPO_USER: ${{ secrets.MAVEN_REPO_USER }}
MAVEN_REPO_PASS: ${{ secrets.MAVEN_REPO_PASS }}
steps:
- name: Checkout Code
run: |
SERVER_DOMAIN=$(echo "${{ github.server_url }}" | sed 's/https:\/\///')
rm -rf "$GITHUB_WORKSPACE"/*
git clone "https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@${SERVER_DOMAIN}/${{ github.repository }}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout ${{ github.sha }}
- name: Read version from Gradle
id: get_version
run: |
cd "$GITHUB_WORKSPACE"
chmod +x ./gradlew
VERSION=$(./gradlew properties | grep "^version:" | awk '{print $2}')
echo "Found local project version: $VERSION"
echo "version=$VERSION" >> $GITHUB_OUTPUT
- name: Check if version exists on repository
id: check_repo
run: |
cd "$GITHUB_WORKSPACE"
RAW_GROUP=$(./gradlew properties | grep "^group:" | awk '{print $2}')
GROUP_PATH=$(echo "$RAW_GROUP" | tr '.' '/')
ARTIFACT_ID=$(./gradlew properties | grep "^name:" | awk '{print $2}')
LOCAL_VERSION="${{ steps.get_version.outputs.version }}"
echo "Detected project: $RAW_GROUP:$ARTIFACT_ID:$LOCAL_VERSION"
CHECK_URL="${{ env.MAVEN_REPO_URL }}/${GROUP_PATH}/${ARTIFACT_ID}/${LOCAL_VERSION}/${ARTIFACT_ID}-${LOCAL_VERSION}.pom"
echo "Check url: $CHECK_URL"
STATUS=$(curl -o /dev/null -s -w "%{http_code}" -u "${{ env.MAVEN_REPO_USER }}:${{ env.MAVEN_REPO_PASS }}" "$CHECK_URL")
if [ "$STATUS" = "200" ]; then
echo "Version $LOCAL_VERSION already exists in repository. Skipping publishing."
echo "is_new=false" >> $GITHUB_OUTPUT
else
echo "Version $LOCAL_VERSION not found (Status $STATUS). Start deployment..."
echo "is_new=true" >> $GITHUB_OUTPUT
fi
- name: Push to Maven Repository
if: steps.check_repo.outputs.is_new == 'true'
run: |
cd "$GITHUB_WORKSPACE"
echo "Publishing version ${{ steps.get_version.outputs.version }} zu Repository..."
./gradlew publish
create-release:
runs-on: java26
# Create the Gitea tag/release after tests pass, before publishing
needs: run-tests
if: github.event_name == 'push' && github.ref == 'refs/heads/master'
env:
API_BASE: ${{ github.server_url }}/api/v1/repos/${{ github.repository }}
steps:
- name: Checkout Code
run: |
SERVER_DOMAIN=$(echo "${{ github.server_url }}" | sed 's/https:\/\///')
rm -rf "$GITHUB_WORKSPACE"/*
git clone "https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@${SERVER_DOMAIN}/${{ github.repository }}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout ${{ github.sha }}
- name: Read version from Gradle
id: get_version
run: |
cd "$GITHUB_WORKSPACE"
chmod +x ./gradlew
VERSION=$(./gradlew properties | grep "^version:" | awk '{print $2}')
echo "Found local project version: $VERSION"
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "tag=v$VERSION" >> $GITHUB_OUTPUT
- name: Check if tag/release already exists
id: check_tag
run: |
TAG="${{ steps.get_version.outputs.tag }}"
CHECK_URL="${{ env.API_BASE }}/releases/tags/${TAG}"
echo "Checking for existing release: $CHECK_URL"
STATUS=$(curl -o /dev/null -s -w "%{http_code}" \
-H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
"$CHECK_URL")
if [ "$STATUS" = "200" ]; then
echo "Release for tag $TAG already exists. Skipping."
echo "is_new=false" >> $GITHUB_OUTPUT
else
echo "No release found for tag $TAG (Status $STATUS). Creating release..."
echo "is_new=true" >> $GITHUB_OUTPUT
fi
- name: Create tag and release on Gitea
if: steps.check_tag.outputs.is_new == 'true'
run: |
TAG="${{ steps.get_version.outputs.tag }}"
VERSION="${{ steps.get_version.outputs.version }}"
CREATE_URL="${{ env.API_BASE }}/releases"
echo "Creating release $TAG at $CREATE_URL"
# Gitea creates the tag automatically from target_commitish when it
# does not yet exist, so a separate tag-creation call is not needed.
BODY=$(cat <<EOF
{
"tag_name": "${TAG}",
"target_commitish": "${{ github.sha }}",
"name": "Release ${VERSION}",
"body": "Automated release for version ${VERSION}.",
"draft": false,
"prerelease": false
}
EOF
)
HTTP_STATUS=$(curl -s -o /tmp/release_response.json -w "%{http_code}" \
-X POST \
-H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
-H "Content-Type: application/json" \
-d "$BODY" \
"$CREATE_URL")
echo "Response ($HTTP_STATUS):"
cat /tmp/release_response.json
if [ "$HTTP_STATUS" != "201" ]; then
echo "Failed to create release (HTTP $HTTP_STATUS)"
exit 1
fi
echo "Release $TAG created successfully."
-62
View File
@@ -1,62 +0,0 @@
name: Auto Publish on Version Change
on:
push:
branches:
- master
jobs:
check-and-publish:
runs-on: java26
env:
MAVEN_REPO_URL: ${{ secrets.MAVEN_REPO_URL }}
MAVEN_REPO_USER: ${{ secrets.MAVEN_REPO_USER }}
MAVEN_REPO_PASS: ${{ secrets.MAVEN_REPO_PASS }}
steps:
- name: Checkout Code
run: |
SERVER_DOMAIN=$(echo "${{ github.server_url }}" | sed 's/https:\/\///')
rm -rf "$GITHUB_WORKSPACE"/*
git clone "https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@${SERVER_DOMAIN}/${{ github.repository }}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout ${{ github.sha }}
- name: Read version from Gradle
id: get_version
run: |
cd "$GITHUB_WORKSPACE"
chmod +x ./gradlew
VERSION=$(./gradlew properties | grep "^version:" | awk '{print $2}')
echo "Found local project version: $VERSION"
echo "version=$VERSION" >> $GITHUB_OUTPUT
- name: Check if version exists on repository
id: check_repo
run: |
cd "$GITHUB_WORKSPACE"
RAW_GROUP=$(./gradlew properties | grep "^group:" | awk '{print $2}')
GROUP_PATH=$(echo "$RAW_GROUP" | tr '.' '/')
ARTIFACT_ID=$(./gradlew properties | grep "^name:" | awk '{print $2}')
LOCAL_VERSION="${{ steps.get_version.outputs.version }}"
echo "Detected project: $RAW_GROUP:$ARTIFACT_ID:$LOCAL_VERSION"
CHECK_URL="${{ env.MAVEN_REPO_URL }}/${GROUP_PATH}/${ARTIFACT_ID}/${LOCAL_VERSION}/${ARTIFACT_ID}-${LOCAL_VERSION}.pom"
echo "Check url: $CHECK_URL"
STATUS=$(curl -o /dev/null -s -w "%{http_code}" -u "${{ env.MAVEN_REPO_USER }}:${{ env.MAVEN_REPO_PASS }}" "$CHECK_URL")
if [ "$STATUS" = "200" ]; then
echo "Version $LOCAL_VERSION already exists in repository. Skipping publishing."
echo "is_new=false" >> $GITHUB_OUTPUT
else
echo "Version $LOCAL_VERSION not found (Status $STATUS). Start deployment..."
echo "is_new=true" >> $GITHUB_OUTPUT
fi
- name: Push to Maven Repository
if: steps.check_repo.outputs.is_new == 'true'
run: |
echo "Publishing version ${{ steps.get_version.outputs.version }} zu Repository..."
./gradlew publish
+13 -1
View File
@@ -4,7 +4,7 @@ plugins {
} }
group = 'dev.coph' group = 'dev.coph'
version = '0.0.2' version = '0.0.3'
repositories { repositories {
mavenCentral() mavenCentral()
@@ -13,6 +13,10 @@ repositories {
dependencies { dependencies {
implementation 'io.netty:netty-all:4.2.14.Final' implementation 'io.netty:netty-all:4.2.14.Final'
implementation 'tools.jackson.core:jackson-databind:3.1.3' implementation 'tools.jackson.core:jackson-databind:3.1.3'
testImplementation platform('org.junit:junit-bom:5.11.4')
testImplementation 'org.junit.jupiter:junit-jupiter'
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
} }
java { java {
@@ -23,6 +27,14 @@ java {
targetCompatibility = JavaVersion.VERSION_26 targetCompatibility = JavaVersion.VERSION_26
} }
test {
useJUnitPlatform()
testLogging {
events "passed", "skipped", "failed"
showStandardStreams = false
}
}
publishing { publishing {
publications { publications {
@@ -27,22 +27,56 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/**
* The core inbound channel handler that processes every aggregated HTTP request.
*
* <p>For each request it, in order: detects and performs WebSocket upgrades (when a WebSocket
* router is configured), answers CORS preflight requests, enforces rate limits, resolves the
* route via the {@link Router}, runs middlewares and the matched handler, and finally writes the
* response with CORS and rate-limit headers applied.</p>
*
* <p>Blocking handler logic runs on a virtual-thread executor rather than on the Netty event
* loop, so handlers may perform blocking work without stalling I/O. WebSocket upgrades, by
* contrast, mutate the pipeline and are handled inline on the event loop.</p>
*/
public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> { public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
/** Executor running one virtual thread per task, used to offload blocking handler work. */
private static final Executor VT_EXECUTOR = private static final Executor VT_EXECUTOR =
Executors.newVirtualThreadPerTaskExecutor(); Executors.newVirtualThreadPerTaskExecutor();
/** Router resolving requests to handlers. */
private final Router router; private final Router router;
/** CORS handler, or {@code null} if CORS is disabled. */
private final CorsHandler cors; private final CorsHandler cors;
/** Rate-limit gate, or {@code null} if rate limiting is disabled. */
private final RateLimitGate rateLimit; private final RateLimitGate rateLimit;
/** WebSocket router, or {@code null} if WebSocket support is disabled. */
private final WebSocketRouter wsRouter; private final WebSocketRouter wsRouter;
/** WebSocket configuration; only consulted when {@link #wsRouter} is non-null. */
private final WebSocketConfig wsConfig; private final WebSocketConfig wsConfig;
/**
* Creates a handler without WebSocket support.
*
* @param router the router resolving requests
* @param cors the CORS handler, or {@code null} to disable CORS
* @param rateLimit the rate-limit gate, or {@code null} to disable rate limiting
*/
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit) { public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit) {
this(router, cors, rateLimit, null, null); this(router, cors, rateLimit, null, null);
} }
/**
* Creates a handler, optionally with WebSocket support.
*
* @param router the router resolving requests
* @param cors the CORS handler, or {@code null} to disable CORS
* @param rateLimit the rate-limit gate, or {@code null} to disable rate limiting
* @param wsRouter the WebSocket router, or {@code null} to disable WebSocket support
* @param wsConfig the WebSocket configuration, used only when {@code wsRouter} is non-null
*/
public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit, public HttpRequestHandler(Router router, CorsHandler cors, RateLimitGate rateLimit,
WebSocketRouter wsRouter, WebSocketConfig wsConfig) { WebSocketRouter wsRouter, WebSocketConfig wsConfig) {
this.router = router; this.router = router;
@@ -52,6 +86,14 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
this.wsConfig = wsConfig; this.wsConfig = wsConfig;
} }
/**
* Entry point invoked by Netty for each fully aggregated request. WebSocket upgrade requests
* are handled inline; all other requests are retained and dispatched to a virtual thread for
* processing, with the request released once handling completes.
*
* @param ctx the channel context
* @param req the aggregated HTTP request
*/
@Override @Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
if (wsRouter != null && isWebSocketUpgrade(req)) { if (wsRouter != null && isWebSocketUpgrade(req)) {
@@ -68,6 +110,14 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
}); });
} }
/**
* Determines whether a request is a WebSocket upgrade handshake, i.e. a {@code GET} carrying
* {@code Upgrade: websocket} and a {@code Connection} header that includes the
* {@code upgrade} token.
*
* @param req the request to inspect
* @return {@code true} if the request is a WebSocket upgrade
*/
private static boolean isWebSocketUpgrade(FullHttpRequest req) { private static boolean isWebSocketUpgrade(FullHttpRequest req) {
if (req.method() != HttpMethod.GET) return false; if (req.method() != HttpMethod.GET) return false;
String upgrade = req.headers().get(HttpHeaderNames.UPGRADE); String upgrade = req.headers().get(HttpHeaderNames.UPGRADE);
@@ -80,6 +130,20 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
return false; return false;
} }
/**
* Attempts to upgrade the connection to WebSocket for the request's path.
*
* <p>Resolves the path against the WebSocket router; if no handler matches the upgrade is
* declined. Otherwise the origin is validated, the WebSocket protocol/compression/idle
* handlers and the application frame handler are inserted into the pipeline, and the request
* is re-fired so Netty performs the handshake.</p>
*
* @param ctx the channel context
* @param req the upgrade request
* @return {@code true} if the request was consumed (handshake started or rejected),
* {@code false} if no WebSocket route matched and normal HTTP handling should
* continue
*/
private boolean handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest req) { private boolean handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest req) {
String path = new QueryStringDecoder(req.uri()).path(); String path = new QueryStringDecoder(req.uri()).path();
WebSocketRouter.Resolution resolution = wsRouter.resolve(path); WebSocketRouter.Resolution resolution = wsRouter.resolve(path);
@@ -125,6 +189,18 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
return true; return true;
} }
/**
* Processes a normal (non-WebSocket) HTTP request: applies CORS preflight handling and rate
* limiting, resolves the route, runs middlewares and the handler, and sends the response.
*
* <p>Exceptions from the handler are mapped to responses: a {@link BadRequestException}
* becomes a {@code 400}, any other exception a {@code 500}. Routing misses become
* {@code 404}, and method mismatches a {@code 405} with an {@code Allow} header. CORS and
* rate-limit headers are applied to the final response in all cases.</p>
*
* @param ctx the channel context
* @param raw the aggregated request being handled
*/
private void handle(ChannelHandlerContext ctx, FullHttpRequest raw) { private void handle(ChannelHandlerContext ctx, FullHttpRequest raw) {
String origin = raw.headers().get("Origin"); String origin = raw.headers().get("Origin");
@@ -182,6 +258,13 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
send(ctx, res); send(ctx, res);
} }
/**
* Converts the framework {@link Response} into a Netty {@link FullHttpResponse}, sets the
* {@code Content-Length}, writes it and closes the connection afterwards.
*
* @param ctx the channel context
* @param res the response to send
*/
private void send(ChannelHandlerContext ctx, Response res) { private void send(ChannelHandlerContext ctx, Response res) {
var nettyRes = new DefaultFullHttpResponse( var nettyRes = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpVersion.HTTP_1_1,
@@ -193,8 +276,14 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler<FullHt
ctx.writeAndFlush(nettyRes).addListener(ChannelFutureListener.CLOSE); ctx.writeAndFlush(nettyRes).addListener(ChannelFutureListener.CLOSE);
} }
/**
* Closes the channel on any unhandled pipeline exception.
*
* @param ctx the channel context
* @param cause the exception that propagated up the pipeline
*/
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
ctx.close(); ctx.close();
} }
} }
@@ -19,53 +19,140 @@ import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.HttpServerCodec;
/**
* Bootstraps and runs the Netty-based HTTP (and optionally WebSocket) server.
*
* <p>The class doubles as a small fluent builder: {@link #builder(int, Router)} creates an
* instance bound to a port and {@link Router}, and the {@code withXxx} methods attach optional
* features (CORS, rate limiting, WebSockets) before {@link #start()} launches the server.</p>
*
* <p>At start-up it selects the most efficient available transport &mdash; {@code epoll} on
* Linux, {@code kqueue} on macOS/BSD, or the portable NIO transport otherwise &mdash; and wires
* up the Netty channel pipeline (codec, aggregator and the {@link HttpRequestHandler}). The
* {@link #start()} call blocks until the server channel is closed.</p>
*/
public final class HttpServer { public final class HttpServer {
/** TCP port the server binds to. */
private final int port; private final int port;
/** Router resolving requests to handlers. */
private final Router router; private final Router router;
/** Optional CORS handler; {@code null} disables CORS handling. */
private CorsHandler cors; private CorsHandler cors;
/** Optional rate-limit gate; {@code null} disables rate limiting. */
private RateLimitGate gate; private RateLimitGate gate;
/** Optional WebSocket router; {@code null} disables WebSocket support. */
private WebSocketRouter wsRouter; private WebSocketRouter wsRouter;
/** WebSocket configuration; only used when {@link #wsRouter} is set. */
private WebSocketConfig wsConfig; private WebSocketConfig wsConfig;
/**
* Creates a server bound to a port and router. Use {@link #builder(int, Router)} instead of
* calling this directly.
*
* @param port the TCP port to bind
* @param router the router resolving requests
*/
private HttpServer(int port, Router router) { private HttpServer(int port, Router router) {
this.port = port; this.port = port;
this.router = router; this.router = router;
} }
/**
* Starts building a server for the given port and router.
*
* @param port the TCP port to bind
* @param router the router resolving requests
* @return a new, configurable {@code HttpServer} instance
*/
public static HttpServer builder(int port, Router router) { public static HttpServer builder(int port, Router router) {
return new HttpServer(port, router); return new HttpServer(port, router);
} }
/**
* Attaches a CORS handler that decorates responses and answers preflight requests.
*
* @param cors the CORS handler to use
* @return this instance, for fluent chaining
*/
public HttpServer withCorsHandler(CorsHandler cors) { public HttpServer withCorsHandler(CorsHandler cors) {
this.cors = cors; this.cors = cors;
return this; return this;
} }
/**
* Attaches a rate-limit gate that throttles incoming requests.
*
* @param gate the rate-limit gate to use
* @return this instance, for fluent chaining
*/
public HttpServer withRateLimitGate(RateLimitGate gate) { public HttpServer withRateLimitGate(RateLimitGate gate) {
this.gate = gate; this.gate = gate;
return this; return this;
} }
/**
* Enables WebSocket support with default configuration.
*
* @param wsRouter the WebSocket router resolving upgrade paths to handlers
* @return this instance, for fluent chaining
* @see #withWebSockets(WebSocketRouter, WebSocketConfig)
*/
public HttpServer withWebSockets(WebSocketRouter wsRouter) { public HttpServer withWebSockets(WebSocketRouter wsRouter) {
return withWebSockets(wsRouter, WebSocketConfig.defaults()); return withWebSockets(wsRouter, WebSocketConfig.defaults());
} }
/**
* Enables WebSocket support with explicit configuration.
*
* @param wsRouter the WebSocket router resolving upgrade paths to handlers
* @param wsConfig the WebSocket configuration (frame sizes, timeouts, origins, ...)
* @return this instance, for fluent chaining
*/
public HttpServer withWebSockets(WebSocketRouter wsRouter, WebSocketConfig wsConfig) { public HttpServer withWebSockets(WebSocketRouter wsRouter, WebSocketConfig wsConfig) {
this.wsRouter = wsRouter; this.wsRouter = wsRouter;
this.wsConfig = wsConfig; this.wsConfig = wsConfig;
return this; return this;
} }
/**
* Starts the server using the configuration accumulated on this instance and blocks until
* the server channel closes.
*
* @throws InterruptedException if the binding or close-future wait is interrupted
*/
public void start() throws InterruptedException { public void start() throws InterruptedException {
start(port, router, cors, gate, wsRouter, wsConfig); start(port, router, cors, gate, wsRouter, wsConfig);
} }
/**
* Starts a server without WebSocket support. Convenience overload of
* {@link #start(int, Router, CorsHandler, RateLimitGate, WebSocketRouter, WebSocketConfig)}.
*
* @param port the TCP port to bind
* @param router the router resolving requests
* @param cors the CORS handler, or {@code null} to disable CORS
* @param gate the rate-limit gate, or {@code null} to disable rate limiting
* @throws InterruptedException if the binding or close-future wait is interrupted
*/
public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate) public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate)
throws InterruptedException { throws InterruptedException {
start(port, router, cors, gate, null, null); start(port, router, cors, gate, null, null);
} }
/**
* Starts the server, selecting the best transport for the platform, configuring the Netty
* channel pipeline and binding the port. The call blocks until the server channel is closed,
* after which the event-loop groups are shut down gracefully.
*
* @param port the TCP port to bind
* @param router the router resolving requests
* @param cors the CORS handler, or {@code null} to disable CORS
* @param gate the rate-limit gate, or {@code null} to disable rate limiting
* @param wsRouter the WebSocket router, or {@code null} to disable WebSocket support
* @param wsConfig the WebSocket configuration, used only when {@code wsRouter} is non-null
* @throws InterruptedException if the binding or close-future wait is interrupted
*/
public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate, public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate,
WebSocketRouter wsRouter, WebSocketConfig wsConfig) WebSocketRouter wsRouter, WebSocketConfig wsConfig)
throws InterruptedException { throws InterruptedException {
@@ -9,15 +9,63 @@ import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.lang.reflect.Method; import java.lang.reflect.Method;
/**
* Reflective registrar that wires the routing annotations on a controller object into a
* {@link Router}.
*
* <p>Given a controller instance, the scanner reads the optional {@link Controller} annotation
* to determine a path prefix, then walks every declared method looking for one of the
* supported route annotations ({@link Route}, {@link GET}, {@link POST}, {@link PUT},
* {@link DELETE}, {@link PATCH} or {@link CUSTOM}). For each matching method it:</p>
* <ol>
* <li>validates that the method has the required {@code (Request, Response)} signature and
* a {@code void} return type;</li>
* <li>creates a {@link MethodHandle} bound to the controller instance for fast,
* reflection-free invocation;</li>
* <li>registers a {@link Router.Handler} that delegates to that handle under the resolved
* HTTP method and full path.</li>
* </ol>
*
* <p>This class is a stateless utility and cannot be instantiated.</p>
*
* @see Controller
* @see Router
*/
public final class AnnotationScanner { public final class AnnotationScanner {
/**
* Shared lookup used to unreflect controller methods into {@link MethodHandle}s. A single
* lookup is sufficient because the scanner forces accessibility on each method before
* unreflecting it.
*/
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
/**
* The exact method type every handler must conform to: {@code void (Request, Response)}.
* Used as documentation of the contract enforced by {@link #validateSignature(Method)}.
*/
private static final MethodType HANDLER_TYPE = private static final MethodType HANDLER_TYPE =
MethodType.methodType(void.class, Request.class, Response.class); MethodType.methodType(void.class, Request.class, Response.class);
/**
* Private constructor preventing instantiation of this stateless utility class.
*/
private AnnotationScanner() { private AnnotationScanner() {
} }
/**
* Scans the given controller for route annotations and registers every discovered handler
* with the supplied router.
*
* <p>If the controller class is annotated with {@link Controller}, its value is used as a
* path prefix for all routes. Methods without a recognised route annotation are ignored.
* A line describing each registered route is printed to standard output.</p>
*
* @param router the router to register the discovered handlers with
* @param controller the controller instance whose annotated methods should be registered
* @throws IllegalArgumentException if an annotated method has an invalid signature
* @throws RuntimeException if a method cannot be made accessible or unreflected
*/
public static void register(Router router, Object controller) { public static void register(Router router, Object controller) {
Class<?> clazz = controller.getClass(); Class<?> clazz = controller.getClass();
Controller ctrlAnno = clazz.getAnnotation(Controller.class); Controller ctrlAnno = clazz.getAnnotation(Controller.class);
@@ -55,11 +103,27 @@ public final class AnnotationScanner {
} }
} }
/**
* Normalizes a controller-level path prefix by ensuring it starts with a single leading
* slash.
*
* @param p the raw prefix from the {@link Controller} annotation, may be {@code null} or empty
* @return the normalized prefix, or an empty string if {@code p} is {@code null} or empty
*/
private static String normalizePrefix(String p) { private static String normalizePrefix(String p) {
if (p == null || p.isEmpty()) return ""; if (p == null || p.isEmpty()) return "";
return p.startsWith("/") ? p : "/" + p; return p.startsWith("/") ? p : "/" + p;
} }
/**
* Extracts route metadata (HTTP method and path) from a method by inspecting the supported
* route annotations in priority order. {@link Route} is checked first, followed by the
* verb-specific annotations and finally {@link CUSTOM}.
*
* @param m the method to inspect
* @return a {@link RouteInfo} describing the route, or {@code null} if the method carries
* no recognised route annotation
*/
private static RouteInfo extractRoute(Method m) { private static RouteInfo extractRoute(Method m) {
Route r = m.getAnnotation(Route.class); Route r = m.getAnnotation(Route.class);
if (r != null) return new RouteInfo(r.method(), r.path()); if (r != null) return new RouteInfo(r.method(), r.path());
@@ -75,16 +139,24 @@ public final class AnnotationScanner {
DELETE del = m.getAnnotation(DELETE.class); DELETE del = m.getAnnotation(DELETE.class);
if (del != null) return new RouteInfo("DELETE", del.value()); if (del != null) return new RouteInfo("DELETE", del.value());
PATCH patch = m.getAnnotation(PATCH.class); PATCH patch = m.getAnnotation(PATCH.class);
if (patch != null) return new RouteInfo("PATCH", patch.value()); if (patch != null) return new RouteInfo("PATCH", patch.value());
CUSTOM custom = m.getAnnotation(CUSTOM.class); CUSTOM custom = m.getAnnotation(CUSTOM.class);
if (custom != null) return new RouteInfo(custom.method(), custom.value()); if (custom != null) return new RouteInfo(custom.method(), custom.value());
return null; return null;
} }
/**
* Validates that a handler method conforms to the required {@code void (Request, Response)}
* contract.
*
* @param m the method to validate
* @throws IllegalArgumentException if the method does not take exactly a {@link Request}
* and a {@link Response}, or does not return {@code void}
*/
private static void validateSignature(Method m) { private static void validateSignature(Method m) {
Class<?>[] params = m.getParameterTypes(); Class<?>[] params = m.getParameterTypes();
if (params.length != 2 || params[0] != Request.class || params[1] != Response.class) { if (params.length != 2 || params[0] != Request.class || params[1] != Response.class) {
@@ -95,11 +167,23 @@ public final class AnnotationScanner {
} }
} }
/**
* Normalizes a route-level path by ensuring it starts with a single leading slash.
*
* @param p the raw path from a route annotation, may be {@code null} or empty
* @return the normalized path, or an empty string if {@code p} is {@code null} or empty
*/
private static String normalizePath(String p) { private static String normalizePath(String p) {
if (p == null || p.isEmpty()) return ""; if (p == null || p.isEmpty()) return "";
return p.startsWith("/") ? p : "/" + p; return p.startsWith("/") ? p : "/" + p;
} }
/**
* Immutable carrier for the HTTP method and path extracted from a route annotation.
*
* @param method the HTTP method name (e.g. {@code "GET"})
* @param path the route path relative to the controller prefix
*/
private record RouteInfo(String method, String path) { private record RouteInfo(String method, String path) {
} }
} }
@@ -5,9 +5,39 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Binds a controller method to a route using a custom or non-standard HTTP method. Whereas
* {@link GET}, {@link POST} and the other verb annotations hard-code the verb, {@code @CUSTOM}
* lets the caller name the verb explicitly via {@link #method()} (for example {@code "HEAD"},
* {@code "OPTIONS"} or a WebDAV-style verb).
*
* <p>The annotated method must have the signature {@code void handler(Request, Response)},
* which the {@link AnnotationScanner} verifies during registration. The route path given by
* {@link #value()} is combined with any {@link Controller#value() controller prefix}.</p>
*
* <p>Retained at {@link RetentionPolicy#RUNTIME runtime} for reflective scanning and only
* applicable to {@link ElementType#METHOD methods}.</p>
*
* @see Route
* @see AnnotationScanner
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface CUSTOM { public @interface CUSTOM {
/**
* The HTTP method name this route responds to. Must be a value accepted by
* {@link io.netty.handler.codec.http.HttpMethod#valueOf(String)}.
*
* @return the HTTP method name
*/
String method(); String method();
/**
* The path this route is mounted at, relative to any controller prefix. Supports
* {@code {param}} path parameters and {@code *} wildcards.
*
* @return the route path
*/
String value(); String value();
} }
@@ -5,8 +5,36 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Marks a class as a <em>controller</em>: a container for HTTP handler methods that are
* discovered and wired into the {@link dev.coph.nextusweb.server.router.Router Router} at
* runtime by the {@link AnnotationScanner}.
*
* <p>The optional {@link #value()} acts as a common path prefix that is prepended to every
* route declared inside the annotated class. For example, a controller annotated with
* {@code @Controller("/api")} whose method is annotated with {@code @GET("/users")} will be
* registered under {@code /api/users}.</p>
*
* <p>This annotation is retained at {@link RetentionPolicy#RUNTIME runtime} because the
* scanner inspects it reflectively while the application is running, and it may only be
* placed on {@link ElementType#TYPE types} (classes).</p>
*
* @see AnnotationScanner
* @see Route
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE) @Target(ElementType.TYPE)
public @interface Controller { public @interface Controller {
/**
* The base path prefix that is prepended to every route declared in the annotated
* controller class.
*
* <p>A leading slash is optional; the scanner normalizes the value so that
* {@code "api"} and {@code "/api"} behave identically. The default empty string means
* the controller contributes no prefix and its routes are registered as-is.</p>
*
* @return the path prefix, or an empty string for no prefix
*/
String value() default ""; String value() default "";
} }
@@ -5,8 +5,29 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Binds a controller method to an HTTP {@code DELETE} route. This is a convenience shorthand
* for {@link Route @Route(method = "DELETE", path = ...)}.
*
* <p>The annotated method must have the signature {@code void handler(Request, Response)},
* which the {@link AnnotationScanner} verifies during registration. The route path given by
* {@link #value()} is combined with any {@link Controller#value() controller prefix}.</p>
*
* <p>Retained at {@link RetentionPolicy#RUNTIME runtime} for reflective scanning and only
* applicable to {@link ElementType#METHOD methods}.</p>
*
* @see Route
* @see AnnotationScanner
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface DELETE { public @interface DELETE {
/**
* The path this {@code DELETE} route is mounted at, relative to any controller prefix.
* Supports {@code {param}} path parameters and {@code *} wildcards.
*
* @return the route path
*/
String value(); String value();
} }
@@ -5,8 +5,29 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Binds a controller method to an HTTP {@code GET} route. This is a convenience shorthand for
* {@link Route @Route(method = "GET", path = ...)}.
*
* <p>The annotated method must have the signature {@code void handler(Request, Response)},
* which the {@link AnnotationScanner} verifies during registration. The route path given by
* {@link #value()} is combined with any {@link Controller#value() controller prefix}.</p>
*
* <p>Retained at {@link RetentionPolicy#RUNTIME runtime} for reflective scanning and only
* applicable to {@link ElementType#METHOD methods}.</p>
*
* @see Route
* @see AnnotationScanner
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface GET { public @interface GET {
/**
* The path this {@code GET} route is mounted at, relative to any controller prefix.
* Supports {@code {param}} path parameters and {@code *} wildcards.
*
* @return the route path
*/
String value(); String value();
} }
@@ -5,8 +5,29 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Binds a controller method to an HTTP {@code PATCH} route. This is a convenience shorthand
* for {@link Route @Route(method = "PATCH", path = ...)}.
*
* <p>The annotated method must have the signature {@code void handler(Request, Response)},
* which the {@link AnnotationScanner} verifies during registration. The route path given by
* {@link #value()} is combined with any {@link Controller#value() controller prefix}.</p>
*
* <p>Retained at {@link RetentionPolicy#RUNTIME runtime} for reflective scanning and only
* applicable to {@link ElementType#METHOD methods}.</p>
*
* @see Route
* @see AnnotationScanner
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface PATCH { public @interface PATCH {
/**
* The path this {@code PATCH} route is mounted at, relative to any controller prefix.
* Supports {@code {param}} path parameters and {@code *} wildcards.
*
* @return the route path
*/
String value(); String value();
} }
@@ -5,8 +5,29 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Binds a controller method to an HTTP {@code POST} route. This is a convenience shorthand for
* {@link Route @Route(method = "POST", path = ...)}.
*
* <p>The annotated method must have the signature {@code void handler(Request, Response)},
* which the {@link AnnotationScanner} verifies during registration. The route path given by
* {@link #value()} is combined with any {@link Controller#value() controller prefix}.</p>
*
* <p>Retained at {@link RetentionPolicy#RUNTIME runtime} for reflective scanning and only
* applicable to {@link ElementType#METHOD methods}.</p>
*
* @see Route
* @see AnnotationScanner
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface POST { public @interface POST {
/**
* The path this {@code POST} route is mounted at, relative to any controller prefix.
* Supports {@code {param}} path parameters and {@code *} wildcards.
*
* @return the route path
*/
String value(); String value();
} }
@@ -5,8 +5,29 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Binds a controller method to an HTTP {@code PUT} route. This is a convenience shorthand for
* {@link Route @Route(method = "PUT", path = ...)}.
*
* <p>The annotated method must have the signature {@code void handler(Request, Response)},
* which the {@link AnnotationScanner} verifies during registration. The route path given by
* {@link #value()} is combined with any {@link Controller#value() controller prefix}.</p>
*
* <p>Retained at {@link RetentionPolicy#RUNTIME runtime} for reflective scanning and only
* applicable to {@link ElementType#METHOD methods}.</p>
*
* @see Route
* @see AnnotationScanner
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface PUT { public @interface PUT {
/**
* The path this {@code PUT} route is mounted at, relative to any controller prefix.
* Supports {@code {param}} path parameters and {@code *} wildcards.
*
* @return the route path
*/
String value(); String value();
} }
@@ -5,10 +5,41 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Generic route declaration that binds a controller method to an arbitrary HTTP method and
* path. This is the most flexible of the routing annotations: where {@link GET}, {@link POST}
* and friends hard-code the HTTP verb, {@code @Route} lets the verb be specified explicitly
* via {@link #method()}.
*
* <p>Handler methods carrying this annotation must follow the signature
* {@code void handler(Request, Response)}; this is enforced by the {@link AnnotationScanner}
* when the route is registered.</p>
*
* <p>The annotation is retained at {@link RetentionPolicy#RUNTIME runtime} so the scanner can
* read it reflectively, and may only be placed on {@link ElementType#METHOD methods}.</p>
*
* @see AnnotationScanner
* @see Controller
*/
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
public @interface Route { public @interface Route {
/**
* The HTTP method (verb) this route responds to, for example {@code "GET"} or
* {@code "POST"}. The value must match a name accepted by
* {@link io.netty.handler.codec.http.HttpMethod#valueOf(String)}.
*
* @return the HTTP method name
*/
String method(); String method();
/**
* The path this route is mounted at, relative to any {@link Controller#value() controller
* prefix}. Path segments wrapped in braces (e.g. {@code /users/{id}}) denote path
* parameters, and a {@code *} segment denotes a wildcard.
*
* @return the route path
*/
String path(); String path();
} }
@@ -6,16 +6,42 @@ import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
/**
* Immutable configuration describing the Cross-Origin Resource Sharing (CORS) policy the
* server enforces. Instances are created through the nested {@link Builder} and consumed by
* {@link CorsHandler} to decide which origins, methods and headers are permitted.
*
* <p>As a safety measure the configuration forbids combining a wildcard origin
* ({@link #allowAnyOrigin()}) with {@link #allowCredentials() credentialed requests}, which
* the CORS specification disallows.</p>
*
* @see CorsHandler
*/
public final class CorsConfig { public final class CorsConfig {
/** Explicit set of allowed origins; ignored when {@link #allowAnyOrigin} is {@code true}. */
private final Set<String> allowedOrigins; private final Set<String> allowedOrigins;
/** HTTP methods advertised as allowed in preflight responses. */
private final Set<HttpMethod> allowedMethods; private final Set<HttpMethod> allowedMethods;
/** Request headers advertised as allowed in preflight responses. */
private final Set<String> allowedHeaders; private final Set<String> allowedHeaders;
/** Response headers exposed to the browser via {@code Access-Control-Expose-Headers}. */
private final Set<String> exposedHeaders; private final Set<String> exposedHeaders;
/** Whether credentialed (cookie/authorization) requests are permitted. */
private final boolean allowCredentials; private final boolean allowCredentials;
/** How long (in seconds) a preflight response may be cached by the browser. */
private final long maxAgeSeconds; private final long maxAgeSeconds;
/** Whether any origin is allowed (the {@code *} wildcard). */
private final boolean allowAnyOrigin; private final boolean allowAnyOrigin;
/**
* Builds an immutable configuration from a {@link Builder}, defensively copying its
* collections.
*
* @param b the builder carrying the configured values
* @throws IllegalStateException if a wildcard origin is combined with
* {@code allowCredentials = true}
*/
private CorsConfig(Builder b) { private CorsConfig(Builder b) {
this.allowedOrigins = Set.copyOf(b.allowedOrigins); this.allowedOrigins = Set.copyOf(b.allowedOrigins);
this.allowedMethods = Set.copyOf(b.allowedMethods); this.allowedMethods = Set.copyOf(b.allowedMethods);
@@ -32,6 +58,14 @@ public final class CorsConfig {
} }
} }
/**
* Creates a permissive, development-friendly configuration that allows any origin, the
* common HTTP methods, a handful of common headers and a one-hour preflight cache.
*
* <p>Because it allows any origin it intentionally does not enable credentials.</p>
*
* @return a ready-to-use permissive configuration
*/
public static CorsConfig permissive() { public static CorsConfig permissive() {
return builder() return builder()
.anyOrigin() .anyOrigin()
@@ -42,86 +76,194 @@ public final class CorsConfig {
.build(); .build();
} }
/**
* Creates a new, empty {@link Builder}.
*
* @return a fresh builder
*/
public static Builder builder() { public static Builder builder() {
return new Builder(); return new Builder();
} }
/**
* Tests whether a given request origin is permitted by this policy.
*
* @param origin the {@code Origin} header value, may be {@code null}
* @return {@code true} if the origin is allowed; {@code false} for a {@code null} origin or
* one not in the allow-list (unless any origin is permitted)
*/
public boolean isOriginAllowed(String origin) { public boolean isOriginAllowed(String origin) {
if (origin == null) return false; if (origin == null) return false;
if (allowAnyOrigin) return true; if (allowAnyOrigin) return true;
return allowedOrigins.contains(origin); return allowedOrigins.contains(origin);
} }
/**
* Returns the HTTP methods advertised as allowed in preflight responses.
*
* @return the immutable set of allowed HTTP methods
*/
public Set<HttpMethod> allowedMethods() { public Set<HttpMethod> allowedMethods() {
return allowedMethods; return allowedMethods;
} }
/**
* Returns the request headers advertised as allowed in preflight responses.
*
* @return the immutable set of allowed request headers
*/
public Set<String> allowedHeaders() { public Set<String> allowedHeaders() {
return allowedHeaders; return allowedHeaders;
} }
/**
* Returns the response headers that browsers are permitted to read.
*
* @return the immutable set of response headers exposed to the browser
*/
public Set<String> exposedHeaders() { public Set<String> exposedHeaders() {
return exposedHeaders; return exposedHeaders;
} }
/**
* Indicates whether credentialed (cookie/authorization) requests are permitted.
*
* @return {@code true} if credentialed requests are permitted
*/
public boolean allowCredentials() { public boolean allowCredentials() {
return allowCredentials; return allowCredentials;
} }
/**
* Returns how long a preflight response may be cached by the browser.
*
* @return the preflight cache lifetime in seconds ({@code 0} disables the header)
*/
public long maxAgeSeconds() { public long maxAgeSeconds() {
return maxAgeSeconds; return maxAgeSeconds;
} }
/**
* Indicates whether requests from any origin are permitted.
*
* @return {@code true} if any origin is permitted
*/
public boolean allowAnyOrigin() { public boolean allowAnyOrigin() {
return allowAnyOrigin; return allowAnyOrigin;
} }
/**
* Fluent builder for {@link CorsConfig}. All collection setters are additive, so they may
* be called multiple times to accumulate values.
*/
public static final class Builder { public static final class Builder {
/** Accumulated explicit origins. */
private final Set<String> allowedOrigins = new HashSet<>(); private final Set<String> allowedOrigins = new HashSet<>();
/** Accumulated allowed methods. */
private final Set<HttpMethod> allowedMethods = new HashSet<>(); private final Set<HttpMethod> allowedMethods = new HashSet<>();
/** Accumulated allowed request headers. */
private final Set<String> allowedHeaders = new HashSet<>(); private final Set<String> allowedHeaders = new HashSet<>();
/** Accumulated exposed response headers. */
private final Set<String> exposedHeaders = new HashSet<>(); private final Set<String> exposedHeaders = new HashSet<>();
/** Whether credentialed requests are permitted; defaults to {@code false}. */
private boolean allowCredentials = false; private boolean allowCredentials = false;
/** Preflight cache lifetime in seconds; defaults to {@code 0} (disabled). */
private long maxAgeSeconds = 0; private long maxAgeSeconds = 0;
/** Whether any origin is permitted; defaults to {@code false}. */
private boolean allowAnyOrigin = false; private boolean allowAnyOrigin = false;
/**
* Creates a builder with no origins, methods or headers configured and all flags at
* their defaults. Obtain instances via {@link CorsConfig#builder()}.
*/
public Builder() {
}
/**
* Adds one or more explicit origins to the allow-list.
*
* @param origins the origins to allow
* @return this builder, for fluent chaining
*/
public Builder allowedOrigins(String... origins) { public Builder allowedOrigins(String... origins) {
Collections.addAll(allowedOrigins, origins); Collections.addAll(allowedOrigins, origins);
return this; return this;
} }
/**
* Allows requests from any origin (the {@code *} wildcard). Cannot be combined with
* {@link #allowCredentials(boolean) credentials}.
*
* @return this builder, for fluent chaining
*/
public Builder anyOrigin() { public Builder anyOrigin() {
this.allowAnyOrigin = true; this.allowAnyOrigin = true;
return this; return this;
} }
/**
* Adds one or more allowed HTTP methods.
*
* @param ms the methods to allow
* @return this builder, for fluent chaining
*/
public Builder allowedMethods(HttpMethod... ms) { public Builder allowedMethods(HttpMethod... ms) {
Collections.addAll(allowedMethods, ms); Collections.addAll(allowedMethods, ms);
return this; return this;
} }
/**
* Adds one or more allowed request headers.
*
* @param hs the request headers to allow
* @return this builder, for fluent chaining
*/
public Builder allowedHeaders(String... hs) { public Builder allowedHeaders(String... hs) {
Collections.addAll(allowedHeaders, hs); Collections.addAll(allowedHeaders, hs);
return this; return this;
} }
/**
* Adds one or more response headers to expose to the browser.
*
* @param hs the response headers to expose
* @return this builder, for fluent chaining
*/
public Builder exposedHeaders(String... hs) { public Builder exposedHeaders(String... hs) {
Collections.addAll(exposedHeaders, hs); Collections.addAll(exposedHeaders, hs);
return this; return this;
} }
/**
* Sets whether credentialed requests are permitted.
*
* @param v {@code true} to allow credentials
* @return this builder, for fluent chaining
*/
public Builder allowCredentials(boolean v) { public Builder allowCredentials(boolean v) {
this.allowCredentials = v; this.allowCredentials = v;
return this; return this;
} }
/**
* Sets the preflight cache lifetime in seconds.
*
* @param s the max-age in seconds ({@code 0} disables the header)
* @return this builder, for fluent chaining
*/
public Builder maxAgeSeconds(long s) { public Builder maxAgeSeconds(long s) {
this.maxAgeSeconds = s; this.maxAgeSeconds = s;
return this; return this;
} }
/**
* Builds the immutable {@link CorsConfig}.
*
* @return the configured instance
* @throws IllegalStateException if any origin is combined with credentials
*/
public CorsConfig build() { public CorsConfig build() {
return new CorsConfig(this); return new CorsConfig(this);
} }
} }
} }
@@ -5,13 +5,38 @@ import io.netty.handler.codec.http.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/**
* Applies a {@link CorsConfig} to outgoing responses and handles CORS preflight requests.
*
* <p>The handler pre-computes the comma-separated header strings derived from the
* configuration (allowed methods, allowed headers, exposed headers) once at construction time
* so they need not be rebuilt for every request. It then offers two entry points:</p>
* <ul>
* <li>{@link #applyHeaders(String, Response)} decorates a normal response with the
* appropriate {@code Access-Control-*} headers;</li>
* <li>{@link #handlePreflight(String, HttpHeaders)} produces a complete response for an
* {@code OPTIONS} preflight request.</li>
* </ul>
*
* @see CorsConfig
*/
public final class CorsHandler { public final class CorsHandler {
/** The policy this handler enforces. */
private final CorsConfig config; private final CorsConfig config;
/** Pre-joined {@code Access-Control-Allow-Methods} value. */
private final String allowedMethodsHeader; private final String allowedMethodsHeader;
/** Pre-joined {@code Access-Control-Allow-Headers} value. */
private final String allowedHeadersHeader; private final String allowedHeadersHeader;
/** Pre-joined {@code Access-Control-Expose-Headers} value. */
private final String exposedHeadersHeader; private final String exposedHeadersHeader;
/**
* Creates a handler for the given configuration, pre-computing the header strings it will
* emit.
*
* @param config the CORS policy to enforce
*/
public CorsHandler(CorsConfig config) { public CorsHandler(CorsConfig config) {
this.config = config; this.config = config;
this.allowedMethodsHeader = config.allowedMethods().stream().map(HttpMethod::name).collect(Collectors.joining(", ")); this.allowedMethodsHeader = config.allowedMethods().stream().map(HttpMethod::name).collect(Collectors.joining(", "));
@@ -19,9 +44,21 @@ public final class CorsHandler {
this.exposedHeadersHeader = String.join(", ", config.exposedHeaders()); this.exposedHeadersHeader = String.join(", ", config.exposedHeaders());
} }
/**
* Adds the {@code Access-Control-Allow-Origin} (and related) headers to a response, if and
* only if the request carried an allowed {@code Origin}.
*
* <p>For wildcard, credential-less policies a literal {@code *} is emitted; otherwise the
* concrete origin is echoed back together with a {@code Vary: Origin} header so caches key
* on the origin. Requests without an origin or with a disallowed origin are left
* untouched.</p>
*
* @param origin the request's {@code Origin} header, may be {@code null}
* @param res the response to decorate
*/
public void applyHeaders(String origin, Response res) { public void applyHeaders(String origin, Response res) {
if (origin == null) return; if (origin == null) return;
if (!config.isOriginAllowed(origin)) return; if (!config.isOriginAllowed(origin)) return;
@@ -40,12 +77,33 @@ public final class CorsHandler {
res.header("Access-Control-Expose-Headers", exposedHeadersHeader); res.header("Access-Control-Expose-Headers", exposedHeadersHeader);
} }
} }
/**
* Determines whether a request is a CORS preflight request, i.e. an {@code OPTIONS}
* request carrying an {@code Access-Control-Request-Method} header.
*
* @param method the request method
* @param headers the request headers
* @return {@code true} if the request is a preflight request
*/
public boolean isPreflight(HttpMethod method, HttpHeaders headers) { public boolean isPreflight(HttpMethod method, HttpHeaders headers) {
return method.equals(HttpMethod.OPTIONS) return method.equals(HttpMethod.OPTIONS)
&& headers.contains("Access-Control-Request-Method"); && headers.contains("Access-Control-Request-Method");
} }
/**
* Builds the response to a CORS preflight request.
*
* <p>If the origin is missing or disallowed the response is a {@code 403 Forbidden};
* otherwise it is a {@code 204 No Content} carrying the allowed methods and headers, the
* requested headers echoed back when no explicit allow-list is configured, and the
* {@code Access-Control-Max-Age} cache hint when configured.</p>
*
* @param origin the request's {@code Origin} header, may be {@code null}
* @param requestHeaders the request's headers (used to read
* {@code Access-Control-Request-Headers})
* @return the fully populated preflight response
*/
public Response handlePreflight(String origin, HttpHeaders requestHeaders) { public Response handlePreflight(String origin, HttpHeaders requestHeaders) {
Response res = new Response().status(204); Response res = new Response().status(204);
@@ -69,4 +127,4 @@ public final class CorsHandler {
return res; return res;
} }
} }
@@ -3,10 +3,28 @@ package dev.coph.nextusweb.server.json;
import tools.jackson.databind.ObjectMapper; import tools.jackson.databind.ObjectMapper;
/**
* Holder for the application-wide Jackson {@link ObjectMapper}.
*
* <p>A single, pre-configured mapper instance is shared across the whole server because
* {@code ObjectMapper} is thread-safe once configured and is relatively expensive to build.
* Centralizing it here ensures every component (request parsing, response serialization,
* WebSocket payloads) uses identical serialization settings.</p>
*
* <p>This class is a static holder and cannot be instantiated.</p>
*/
public final class JsonMapper { public final class JsonMapper {
/**
* The shared, thread-safe Jackson mapper used throughout the server for all JSON reading
* and writing.
*/
public static final ObjectMapper MAPPER = tools.jackson.databind.json.JsonMapper.builder() public static final ObjectMapper MAPPER = tools.jackson.databind.json.JsonMapper.builder()
// .addModule(new JavaTimeModule()) // .addModule(new JavaTimeModule())
.build(); .build();
/**
* Private constructor preventing instantiation of this static holder class.
*/
private JsonMapper() {} private JsonMapper() {}
} }
@@ -3,37 +3,89 @@ package dev.coph.nextusweb.server.ratelimit;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
/**
* A {@link RateLimiter} implementing the <em>fixed window</em> counter algorithm.
*
* <p>Time is divided into consecutive windows of {@code windowMillis} length. Each key may make
* up to {@code limit} requests within a window; the counter resets to zero when a new window
* begins. This is the simplest counting strategy but can permit up to twice the limit across a
* window boundary (the "burst at the edge" problem) &mdash; see {@link SlidingWindowLimiter}
* for a smoother variant.</p>
*
* <p>Window state is held in {@link AtomicLong}s, making the limiter safe for concurrent
* use.</p>
*/
public final class FixedWindowLimiter implements RateLimiter { public final class FixedWindowLimiter implements RateLimiter {
/** Maximum number of requests permitted per window. */
private final long limit; private final long limit;
/** Window length in nanoseconds. */
private final long windowNanos; private final long windowNanos;
/** Per-key windows, created on demand. */
private final ConcurrentHashMap<String, Window> windows = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, Window> windows = new ConcurrentHashMap<>();
/**
* Creates a fixed-window limiter.
*
* @param limit the maximum number of requests per window
* @param windowMillis the window length in milliseconds
*/
public FixedWindowLimiter(long limit, long windowMillis) { public FixedWindowLimiter(long limit, long windowMillis) {
this.limit = limit; this.limit = limit;
this.windowNanos = windowMillis * 1_000_000L; this.windowNanos = windowMillis * 1_000_000L;
} }
/**
* {@inheritDoc}
*
* <p>Lazily creates the window for {@code key} and counts this request against it.</p>
*/
@Override @Override
public Result tryAcquire(String key, long nowNanos) { public Result tryAcquire(String key, long nowNanos) {
Window w = windows.computeIfAbsent(key, k -> new Window(nowNanos)); Window w = windows.computeIfAbsent(key, k -> new Window(nowNanos));
return w.tryAcquire(nowNanos, limit, windowNanos); return w.tryAcquire(nowNanos, limit, windowNanos);
} }
/**
* Evicts windows whose start time is older than the given age.
*
* @param olderThanNanos maximum age in nanoseconds before a window is removed
*/
public void cleanup(long olderThanNanos) { public void cleanup(long olderThanNanos) {
long now = System.nanoTime(); long now = System.nanoTime();
windows.entrySet().removeIf(e -> now - e.getValue().windowStart.get() > olderThanNanos); windows.entrySet().removeIf(e -> now - e.getValue().windowStart.get() > olderThanNanos);
} }
/**
* A single client's fixed window, tracking the window start time and the request count
* within it.
*/
private static final class Window { private static final class Window {
/** Start timestamp of the current window, in nanoseconds. */
final AtomicLong windowStart; final AtomicLong windowStart;
/** Number of requests counted in the current window. */
final AtomicLong count; final AtomicLong count;
/**
* Creates a window starting at the given time with a zero count.
*
* @param now the window start timestamp in nanoseconds
*/
Window(long now) { Window(long now) {
this.windowStart = new AtomicLong(now); this.windowStart = new AtomicLong(now);
this.count = new AtomicLong(0); this.count = new AtomicLong(0);
} }
/**
* Rolls the window over if it has expired, then counts this request and decides whether
* it stays within the limit.
*
* @param now the current time in nanoseconds
* @param limit the per-window request limit
* @param windowNanos the window length in nanoseconds
* @return an allow result with the remaining quota, or a deny result with the time until
* the window resets
*/
Result tryAcquire(long now, long limit, long windowNanos) { Result tryAcquire(long now, long limit, long windowNanos) {
long start = windowStart.get(); long start = windowStart.get();
if (now - start >= windowNanos) { if (now - start >= windowNanos) {
@@ -50,4 +102,4 @@ public final class FixedWindowLimiter implements RateLimiter {
return Result.allow(limit - current, limit); return Result.allow(limit - current, limit);
} }
} }
} }
@@ -2,10 +2,33 @@ package dev.coph.nextusweb.server.ratelimit;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
/**
* Strategy for deriving the logical key under which a request is rate limited. The key
* determines which bucket a request counts against — for example one bucket per client IP, or
* one per authenticated user.
*
* <p>Two ready-made resolvers are provided as factory methods: {@link #clientIp()} and
* {@link #userOrIp()}.</p>
*/
@FunctionalInterface @FunctionalInterface
public interface KeyResolver { public interface KeyResolver {
/**
* Resolves the rate-limit key for a request.
*
* @param req the incoming HTTP request, used to inspect headers
* @param remoteAddress the transport-level remote address, used as a fallback
* @return the key the request should be counted against
*/
String resolve(HttpRequest req, String remoteAddress); String resolve(HttpRequest req, String remoteAddress);
/**
* Returns a resolver that keys on the client IP address. It prefers the first entry of the
* {@code X-Forwarded-For} header (so it works behind a reverse proxy) and falls back to the
* transport-level remote address when that header is absent.
*
* @return a client-IP key resolver
*/
static KeyResolver clientIp() { static KeyResolver clientIp() {
return (req, remote) -> { return (req, remote) -> {
String forwarded = req.headers().get("X-Forwarded-For"); String forwarded = req.headers().get("X-Forwarded-For");
@@ -17,6 +40,14 @@ public interface KeyResolver {
}; };
} }
/**
* Returns a resolver that keys on the authenticated user when possible, falling back to the
* client IP otherwise. A {@code Bearer} token from the {@code Authorization} header yields a
* {@code "u:<token>"} key; otherwise the {@code "ip:<address>"} key from {@link #clientIp()}
* is used.
*
* @return a user-or-IP key resolver
*/
static KeyResolver userOrIp() { static KeyResolver userOrIp() {
return (req, remote) -> { return (req, remote) -> {
String auth = req.headers().get("Authorization"); String auth = req.headers().get("Authorization");
@@ -26,4 +57,4 @@ public interface KeyResolver {
return "ip:" + clientIp().resolve(req, remote); return "ip:" + clientIp().resolve(req, remote);
}; };
} }
} }
@@ -3,37 +3,88 @@ package dev.coph.nextusweb.server.ratelimit;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
/**
* A {@link RateLimiter} implementing the <em>leaky bucket</em> algorithm.
*
* <p>Each key owns a bucket whose "water level" rises by one with every request and "leaks"
* back down at a fixed rate of {@code requestsPerSecond} units per second. A request is allowed
* while the (post-leak) level is below {@code capacity}; once full, requests are denied until
* enough has leaked away. Compared to the token bucket this smooths bursts into a steady
* outflow rather than allowing them through up front.</p>
*
* <p>State is held in {@link AtomicLong}s and updated with a lock-free compare-and-set loop, so
* the limiter is safe for concurrent use.</p>
*/
public final class LeakyBucketLimiter implements RateLimiter { public final class LeakyBucketLimiter implements RateLimiter {
/** Maximum water level (number of queued units) the bucket tolerates. */
private final long capacity; private final long capacity;
private final long leakIntervalNanos; /** Nanoseconds it takes for exactly one unit to leak out. */
private final long leakIntervalNanos;
/** Per-key buckets, created on demand. */
private final ConcurrentHashMap<String, LeakyBucket> buckets = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, LeakyBucket> buckets = new ConcurrentHashMap<>();
/**
* Creates a leaky-bucket limiter.
*
* @param requestsPerSecond the steady leak (drain) rate in units per second
* @param capacity the bucket capacity, i.e. the maximum tolerated backlog
*/
public LeakyBucketLimiter(long requestsPerSecond, long capacity) { public LeakyBucketLimiter(long requestsPerSecond, long capacity) {
this.capacity = capacity; this.capacity = capacity;
this.leakIntervalNanos = 1_000_000_000L / Math.max(1, requestsPerSecond); this.leakIntervalNanos = 1_000_000_000L / Math.max(1, requestsPerSecond);
} }
/**
* {@inheritDoc}
*
* <p>Lazily creates the bucket for {@code key} and attempts to add one unit of water.</p>
*/
@Override @Override
public Result tryAcquire(String key, long nowNanos) { public Result tryAcquire(String key, long nowNanos) {
LeakyBucket b = buckets.computeIfAbsent(key, k -> new LeakyBucket(nowNanos)); LeakyBucket b = buckets.computeIfAbsent(key, k -> new LeakyBucket(nowNanos));
return b.tryAcquire(nowNanos, capacity, leakIntervalNanos); return b.tryAcquire(nowNanos, capacity, leakIntervalNanos);
} }
/**
* Evicts buckets that have not leaked/been accessed within the given age.
*
* @param olderThanNanos maximum idle age in nanoseconds before a bucket is removed
*/
public void cleanup(long olderThanNanos) { public void cleanup(long olderThanNanos) {
long now = System.nanoTime(); long now = System.nanoTime();
buckets.entrySet().removeIf(e -> now - e.getValue().lastLeakNanos.get() > olderThanNanos); buckets.entrySet().removeIf(e -> now - e.getValue().lastLeakNanos.get() > olderThanNanos);
} }
/**
* A single client's leaky bucket, tracking the current water level and the timestamp up to
* which leakage has been accounted for.
*/
private static final class LeakyBucket { private static final class LeakyBucket {
/** Current water level (number of units in the bucket). */
final AtomicLong waterLevel; final AtomicLong waterLevel;
/** Timestamp, in nanoseconds, up to which leakage has been applied. */
final AtomicLong lastLeakNanos; final AtomicLong lastLeakNanos;
/**
* Creates an empty bucket.
*
* @param now the creation timestamp in nanoseconds
*/
LeakyBucket(long now) { LeakyBucket(long now) {
this.waterLevel = new AtomicLong(0); this.waterLevel = new AtomicLong(0);
this.lastLeakNanos = new AtomicLong(now); this.lastLeakNanos = new AtomicLong(now);
} }
/**
* Applies elapsed leakage and, if there is room, adds one unit of water.
*
* @param now the current time in nanoseconds
* @param capacity the bucket capacity
* @param leakIntervalNanos the nanoseconds per leaked unit
* @return an allow result with the remaining headroom, or a deny result with a retry
* hint when the bucket is full
*/
Result tryAcquire(long now, long capacity, long leakIntervalNanos) { Result tryAcquire(long now, long capacity, long leakIntervalNanos) {
while (true) { while (true) {
long lastLeak = lastLeakNanos.get(); long lastLeak = lastLeakNanos.get();
@@ -56,4 +107,4 @@ public final class LeakyBucketLimiter implements RateLimiter {
} }
} }
} }
} }
@@ -5,11 +5,35 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
/**
* Immutable mapping from request paths to the {@link Rule rate-limit rules} that apply to them.
*
* <p>Three kinds of rules can be configured, resolved with the following precedence by
* {@link #rulesFor(String)}:</p>
* <ol>
* <li>an optional <strong>global</strong> rule that applies to every request;</li>
* <li><strong>exact-path</strong> rules matched by exact path equality;</li>
* <li><strong>prefix</strong> rules matched by path prefix, evaluated longest-prefix-first.</li>
* </ol>
*
* <p>A request is subject to the global rule (if any) plus the single most specific path rule
* that matches. Instances are built through the nested {@link Builder}.</p>
*/
public final class RateLimitConfig { public final class RateLimitConfig {
/** Rule applied to every request, or {@code null} if no global rule is configured. */
private final Rule globalRule; private final Rule globalRule;
/** Rules matched by exact path equality, keyed by path. */
private final Map<String, Rule> exactPathRules; private final Map<String, Rule> exactPathRules;
/** Prefix rules, pre-sorted longest-prefix-first so the most specific match wins. */
private final List<PrefixRule> prefixRules; private final List<PrefixRule> prefixRules;
/**
* Builds an immutable configuration from a {@link Builder}, copying the exact-path rules
* and sorting the prefix rules by descending prefix length.
*
* @param b the builder carrying the configured rules
*/
private RateLimitConfig(Builder b) { private RateLimitConfig(Builder b) {
this.globalRule = b.globalRule; this.globalRule = b.globalRule;
this.exactPathRules = Map.copyOf(b.exactPathRules); this.exactPathRules = Map.copyOf(b.exactPathRules);
@@ -18,10 +42,25 @@ public final class RateLimitConfig {
.toList(); .toList();
} }
/**
* Creates a new, empty {@link Builder}.
*
* @return a fresh builder
*/
public static Builder builder() { public static Builder builder() {
return new Builder(); return new Builder();
} }
/**
* Returns the ordered list of rules that apply to the given path.
*
* <p>The list contains the global rule first (if configured) followed by at most one
* path-specific rule: the exact-path rule if one matches, otherwise the longest matching
* prefix rule. The returned list may be empty if no rule applies.</p>
*
* @param path the request path
* @return the applicable rules, in evaluation order
*/
public List<Rule> rulesFor(String path) { public List<Rule> rulesFor(String path) {
List<Rule> rules = new ArrayList<>(2); List<Rule> rules = new ArrayList<>(2);
if (globalRule != null) rules.add(globalRule); if (globalRule != null) rules.add(globalRule);
@@ -40,34 +79,90 @@ public final class RateLimitConfig {
return rules; return rules;
} }
/**
* A single rate-limit rule: a limiter, the key resolver feeding it, and a name used to
* namespace keys and aid diagnostics.
*
* @param limiter the limiter that enforces the quota
* @param keyResolver resolves the per-request key the limiter buckets on
* @param name a human-readable label (e.g. {@code "global"} or a path/prefix)
*/
public record Rule(RateLimiter limiter, KeyResolver keyResolver, String name) { public record Rule(RateLimiter limiter, KeyResolver keyResolver, String name) {
} }
/**
* Internal pairing of a path prefix with the rule that applies to paths starting with it.
*
* @param prefix the path prefix
* @param rule the rule to apply for matching paths
*/
private record PrefixRule(String prefix, Rule rule) { private record PrefixRule(String prefix, Rule rule) {
} }
/**
* Fluent builder for {@link RateLimitConfig}.
*/
public static final class Builder { public static final class Builder {
/** Accumulated exact-path rules, keyed by path. */
private final Map<String, Rule> exactPathRules = new HashMap<>(); private final Map<String, Rule> exactPathRules = new HashMap<>();
/** Accumulated prefix rules. */
private final List<PrefixRule> prefixRules = new ArrayList<>(); private final List<PrefixRule> prefixRules = new ArrayList<>();
/** The global rule, if configured. */
private Rule globalRule; private Rule globalRule;
/**
* Creates a builder with no rules configured. Obtain instances via
* {@link RateLimitConfig#builder()}.
*/
public Builder() {
}
/**
* Sets the global rule applied to every request.
*
* @param limiter the limiter enforcing the global quota
* @param keys the key resolver for the global rule
* @return this builder, for fluent chaining
*/
public Builder global(RateLimiter limiter, KeyResolver keys) { public Builder global(RateLimiter limiter, KeyResolver keys) {
this.globalRule = new Rule(limiter, keys, "global"); this.globalRule = new Rule(limiter, keys, "global");
return this; return this;
} }
/**
* Adds a rule that applies only to requests whose path equals {@code path} exactly.
*
* @param path the exact request path
* @param limiter the limiter enforcing the quota
* @param keys the key resolver for this rule
* @return this builder, for fluent chaining
*/
public Builder forPath(String path, RateLimiter limiter, KeyResolver keys) { public Builder forPath(String path, RateLimiter limiter, KeyResolver keys) {
exactPathRules.put(path, new Rule(limiter, keys, path)); exactPathRules.put(path, new Rule(limiter, keys, path));
return this; return this;
} }
/**
* Adds a rule that applies to requests whose path starts with {@code prefix}. When
* several prefixes match, the longest one wins.
*
* @param prefix the path prefix
* @param limiter the limiter enforcing the quota
* @param keys the key resolver for this rule
* @return this builder, for fluent chaining
*/
public Builder forPrefix(String prefix, RateLimiter limiter, KeyResolver keys) { public Builder forPrefix(String prefix, RateLimiter limiter, KeyResolver keys) {
prefixRules.add(new PrefixRule(prefix, new Rule(limiter, keys, prefix + "*"))); prefixRules.add(new PrefixRule(prefix, new Rule(limiter, keys, prefix + "*")));
return this; return this;
} }
/**
* Builds the immutable {@link RateLimitConfig}.
*
* @return the configured instance
*/
public RateLimitConfig build() { public RateLimitConfig build() {
return new RateLimitConfig(this); return new RateLimitConfig(this);
} }
} }
} }
@@ -8,11 +8,31 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
/**
* Request-pipeline entry point that applies a {@link RateLimitConfig} to incoming requests and
* surfaces the outcome as standard {@code X-RateLimit-*} response headers.
*
* <p>For each request the gate evaluates every {@link RateLimitConfig.Rule rule} that applies
* to the request path. If any rule denies the request, evaluation stops and that denial is
* returned; otherwise the strictest (lowest remaining) allowance is returned so the headers
* reflect the tightest applicable budget.</p>
*
* <p>A daemon background thread periodically triggers cleanup of stale limiter state. The gate
* should be {@link #shutdown() shut down} when the server stops.</p>
*/
public final class RateLimitGate { public final class RateLimitGate {
/** The rule set this gate enforces. */
private final RateLimitConfig config; private final RateLimitConfig config;
/** Single-threaded scheduler driving periodic cleanup of stale buckets. */
private final ScheduledExecutorService cleanup; private final ScheduledExecutorService cleanup;
/**
* Creates a gate for the given configuration and starts a background cleanup task that runs
* every five minutes on a daemon thread.
*
* @param config the rate-limit rules to enforce
*/
public RateLimitGate(RateLimitConfig config) { public RateLimitGate(RateLimitConfig config) {
this.config = config; this.config = config;
this.cleanup = Executors.newSingleThreadScheduledExecutor(r -> { this.cleanup = Executors.newSingleThreadScheduledExecutor(r -> {
@@ -22,8 +42,21 @@ public final class RateLimitGate {
}); });
cleanup.scheduleAtFixedRate(this::doCleanup, 5, 5, TimeUnit.MINUTES); cleanup.scheduleAtFixedRate(this::doCleanup, 5, 5, TimeUnit.MINUTES);
} }
/**
* Evaluates all rules applicable to the given path and decides whether the request may
* proceed.
*
* <p>Each rule's key is namespaced with the rule name to keep buckets from different rules
* independent. The first denial short-circuits and is returned immediately; if every rule
* allows the request, the result with the least remaining quota is returned.</p>
*
* @param req the incoming request, used by key resolvers
* @param path the request path used to select rules
* @param remoteAddress the client's remote address, used as a key-resolver fallback
* @return the limiting result, or {@code null} if no rule applies to the path
*/
public RateLimiter.Result check(HttpRequest req, String path, String remoteAddress) { public RateLimiter.Result check(HttpRequest req, String path, String remoteAddress) {
List<RateLimitConfig.Rule> rules = config.rulesFor(path); List<RateLimitConfig.Rule> rules = config.rulesFor(path);
if (rules.isEmpty()) return null; if (rules.isEmpty()) return null;
@@ -35,7 +68,7 @@ public final class RateLimitGate {
String key = rule.name() + ":" + rule.keyResolver().resolve(req, remoteAddress); String key = rule.name() + ":" + rule.keyResolver().resolve(req, remoteAddress);
RateLimiter.Result result = rule.limiter().tryAcquire(key, now); RateLimiter.Result result = rule.limiter().tryAcquire(key, now);
if (!result.allowed()) return result; if (!result.allowed()) return result;
if (strictest == null || result.remaining() < strictest.remaining()) { if (strictest == null || result.remaining() < strictest.remaining()) {
strictest = result; strictest = result;
@@ -44,6 +77,16 @@ public final class RateLimitGate {
return strictest; return strictest;
} }
/**
* Writes the standard rate-limit headers ({@code X-RateLimit-Limit},
* {@code X-RateLimit-Remaining}, and {@code Retry-After} when denied) onto a response.
*
* <p>Does nothing when {@code result} is {@code null} (no rule applied). The retry hint is
* rounded up to whole seconds as required by the {@code Retry-After} header.</p>
*
* @param result the limiting result, may be {@code null}
* @param res the response to decorate
*/
public static void applyHeaders(RateLimiter.Result result, Response res) { public static void applyHeaders(RateLimiter.Result result, Response res) {
if (result == null) return; if (result == null) return;
res.header("X-RateLimit-Limit", String.valueOf(result.limit())); res.header("X-RateLimit-Limit", String.valueOf(result.limit()));
@@ -53,9 +96,16 @@ public final class RateLimitGate {
} }
} }
/**
* Periodic cleanup hook invoked by the background scheduler to evict limiter state that has
* not been touched recently (older than roughly ten minutes).
*/
private void doCleanup() { private void doCleanup() {
long threshold = 10L * 60 * 1_000_000_000L; long threshold = 10L * 60 * 1_000_000_000L;
} }
/**
* Stops the background cleanup scheduler. Should be called when the server shuts down.
*/
public void shutdown() { cleanup.shutdown(); } public void shutdown() { cleanup.shutdown(); }
} }
@@ -1,21 +1,61 @@
package dev.coph.nextusweb.server.ratelimit; package dev.coph.nextusweb.server.ratelimit;
/**
* Strategy interface for rate limiting. An implementation decides, per logical key, whether a
* single request may proceed right now.
*
* <p>Concrete strategies in this package include {@link TokenBucketLimiter},
* {@link LeakyBucketLimiter}, {@link FixedWindowLimiter} and {@link SlidingWindowLimiter}.
* Implementations are expected to be thread-safe, since the same limiter is shared across all
* request-handling threads.</p>
*/
public interface RateLimiter { public interface RateLimiter {
/**
* Attempts to consume one unit of quota for the given key at the given timestamp.
*
* @param key the logical bucket key (for example a client IP or user identifier)
* @param nowNanos the current time in nanoseconds, typically {@link System#nanoTime()}
* @return a {@link Result} describing whether the request was allowed and the remaining
* quota
*/
Result tryAcquire(String key, long nowNanos); Result tryAcquire(String key, long nowNanos);
/**
* Immutable outcome of a {@link #tryAcquire(String, long)} call.
*
* @param allowed whether the request may proceed
* @param remaining the remaining quota in the current window/bucket
* @param limit the configured limit, surfaced as {@code X-RateLimit-Limit}
* @param retryAfterMillis when denied, how long the caller should wait before retrying, in
* milliseconds (0 when allowed)
*/
record Result( record Result(
boolean allowed, boolean allowed,
long remaining, long remaining,
long limit, long limit,
long retryAfterMillis long retryAfterMillis
) { ) {
/**
* Creates a result representing an allowed request.
*
* @param remaining the remaining quota after this request
* @param limit the configured limit
* @return an "allowed" result with no retry delay
*/
public static Result allow(long remaining, long limit) { public static Result allow(long remaining, long limit) {
return new Result(true, remaining, limit, 0); return new Result(true, remaining, limit, 0);
} }
/**
* Creates a result representing a denied (rate-limited) request.
*
* @param limit the configured limit
* @param retryAfterMillis how long to wait before retrying, in milliseconds
* @return a "denied" result with zero remaining quota
*/
public static Result deny(long limit, long retryAfterMillis) { public static Result deny(long limit, long retryAfterMillis) {
return new Result(false, 0, limit, retryAfterMillis); return new Result(false, 0, limit, retryAfterMillis);
} }
} }
} }
@@ -3,39 +3,99 @@ package dev.coph.nextusweb.server.ratelimit;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
/**
* A {@link RateLimiter} implementing the <em>sliding window counter</em> algorithm.
*
* <p>This refines {@link FixedWindowLimiter} by smoothing the boundary between adjacent
* windows. It keeps the count for the current window and the previous window, and estimates the
* effective rate by weighting the previous window's count by how much of the current window has
* not yet elapsed. This avoids the burst-doubling that a plain fixed window allows at window
* boundaries, at the cost of a little extra state.</p>
*
* <p>Because the weighted calculation must read and update several fields atomically together,
* the per-key update is guarded by {@code synchronized}; the per-key state objects are stored
* in a {@link ConcurrentHashMap}.</p>
*/
public final class SlidingWindowLimiter implements RateLimiter { public final class SlidingWindowLimiter implements RateLimiter {
/** Maximum effective (weighted) number of requests per window. */
private final long limit; private final long limit;
/** Window length in nanoseconds. */
private final long windowNanos; private final long windowNanos;
/** Per-key sliding windows, created on demand. */
private final ConcurrentHashMap<String, SlidingWindow> windows = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, SlidingWindow> windows = new ConcurrentHashMap<>();
/**
* Creates a sliding-window limiter.
*
* @param limit the maximum effective number of requests per window
* @param windowMillis the window length in milliseconds
*/
public SlidingWindowLimiter(long limit, long windowMillis) { public SlidingWindowLimiter(long limit, long windowMillis) {
this.limit = limit; this.limit = limit;
this.windowNanos = windowMillis * 1_000_000L; this.windowNanos = windowMillis * 1_000_000L;
} }
/**
* {@inheritDoc}
*
* <p>Lazily creates the sliding window for {@code key} and counts this request against
* it.</p>
*/
@Override @Override
public Result tryAcquire(String key, long nowNanos) { public Result tryAcquire(String key, long nowNanos) {
SlidingWindow w = windows.computeIfAbsent(key, k -> new SlidingWindow(nowNanos)); SlidingWindow w = windows.computeIfAbsent(key, k -> new SlidingWindow(nowNanos));
return w.tryAcquire(nowNanos, limit, windowNanos); return w.tryAcquire(nowNanos, limit, windowNanos);
} }
/**
* Evicts windows whose start time is older than the given age.
*
* @param olderThanNanos maximum age in nanoseconds before a window is removed
*/
public void cleanup(long olderThanNanos) { public void cleanup(long olderThanNanos) {
long now = System.nanoTime(); long now = System.nanoTime();
windows.entrySet().removeIf(e -> now - e.getValue().windowStart.get() > olderThanNanos); windows.entrySet().removeIf(e -> now - e.getValue().windowStart.get() > olderThanNanos);
} }
/**
* A single client's sliding window, tracking the current window start plus the current and
* previous window counts.
*/
private static final class SlidingWindow { private static final class SlidingWindow {
/** Start timestamp of the current window, in nanoseconds. */
final AtomicLong windowStart; final AtomicLong windowStart;
/** Request count accumulated in the current window. */
final AtomicLong currentCount; final AtomicLong currentCount;
/** Request count carried over from the immediately preceding window. */
final AtomicLong previousCount; final AtomicLong previousCount;
/**
* Creates a sliding window starting at the given time with zero counts.
*
* @param now the window start timestamp in nanoseconds
*/
SlidingWindow(long now) { SlidingWindow(long now) {
this.windowStart = new AtomicLong(now); this.windowStart = new AtomicLong(now);
this.currentCount = new AtomicLong(0); this.currentCount = new AtomicLong(0);
this.previousCount = new AtomicLong(0); this.previousCount = new AtomicLong(0);
} }
/**
* Advances the window(s) as time has passed, computes the weighted request count and
* decides whether this request stays within the limit.
*
* <p>If two or more full windows have elapsed the counters are reset; if exactly one has
* elapsed the current count becomes the previous count and a fresh window starts. The
* weighted count blends the previous window's count (scaled by the fraction of the
* current window still remaining) with the current count.</p>
*
* @param now the current time in nanoseconds
* @param limit the per-window effective limit
* @param windowNanos the window length in nanoseconds
* @return an allow result with the remaining quota, or a deny result with the time until
* the window slides far enough to admit the request
*/
synchronized Result tryAcquire(long now, long limit, long windowNanos) { synchronized Result tryAcquire(long now, long limit, long windowNanos) {
long start = windowStart.get(); long start = windowStart.get();
long elapsed = now - start; long elapsed = now - start;
@@ -64,4 +124,4 @@ public final class SlidingWindowLimiter implements RateLimiter {
return Result.allow(limit - weightedCount - 1, limit); return Result.allow(limit - weightedCount - 1, limit);
} }
} }
} }
@@ -3,59 +3,122 @@ package dev.coph.nextusweb.server.ratelimit;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
/**
* A {@link RateLimiter} implementing the <em>token bucket</em> algorithm.
*
* <p>Each key owns a bucket that holds up to {@code burstCapacity} tokens and refills
* continuously at {@code requestsPerSecond} tokens per second. Every request consumes one
* token; if at least one token is available the request is allowed, otherwise it is denied
* with a retry hint computed from the refill rate. This permits short bursts (up to the bucket
* capacity) while bounding the sustained rate.</p>
*
* <p>Token counts are stored in fixed-point form (scaled by 1e9) inside {@link AtomicLong}s and
* updated with a lock-free compare-and-set loop, so the limiter is safe for concurrent use.</p>
*/
public final class TokenBucketLimiter implements RateLimiter { public final class TokenBucketLimiter implements RateLimiter {
private final long capacity; /** Maximum number of tokens a bucket can hold (the burst allowance). */
private final double tokensPerNano; private final long capacity;
/** Refill rate expressed as tokens added per nanosecond. */
private final double tokensPerNano;
/** Approximate nanoseconds between single-token refills, used for retry hints. */
private final long refillIntervalNs; private final long refillIntervalNs;
/** Per-key buckets, created on demand. */
private final ConcurrentHashMap<String, Bucket> buckets = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, Bucket> buckets = new ConcurrentHashMap<>();
/**
* Creates a token-bucket limiter.
*
* @param requestsPerSecond the sustained refill rate in tokens (requests) per second
* @param burstCapacity the maximum burst size, i.e. the bucket capacity in tokens
*/
public TokenBucketLimiter(long requestsPerSecond, long burstCapacity) { public TokenBucketLimiter(long requestsPerSecond, long burstCapacity) {
this.capacity = burstCapacity; this.capacity = burstCapacity;
this.tokensPerNano = (double) requestsPerSecond / 1_000_000_000.0; this.tokensPerNano = (double) requestsPerSecond / 1_000_000_000.0;
this.refillIntervalNs = 1_000_000_000L / Math.max(1, requestsPerSecond); this.refillIntervalNs = 1_000_000_000L / Math.max(1, requestsPerSecond);
} }
/**
* {@inheritDoc}
*
* <p>Lazily creates the bucket for {@code key} (initially full) and attempts to consume one
* token from it.</p>
*/
@Override @Override
public Result tryAcquire(String key, long nowNanos) { public Result tryAcquire(String key, long nowNanos) {
Bucket b = buckets.computeIfAbsent(key, k -> new Bucket(capacity, nowNanos)); Bucket b = buckets.computeIfAbsent(key, k -> new Bucket(capacity, nowNanos));
return b.tryAcquire(nowNanos, capacity, tokensPerNano, refillIntervalNs); return b.tryAcquire(nowNanos, capacity, tokensPerNano, refillIntervalNs);
} }
/**
* Evicts buckets that have not been accessed within the given age, bounding memory use.
*
* @param olderThanNanos maximum idle age in nanoseconds before a bucket is removed
*/
public void cleanup(long olderThanNanos) { public void cleanup(long olderThanNanos) {
long now = System.nanoTime(); long now = System.nanoTime();
buckets.entrySet().removeIf(e -> now - e.getValue().lastAccess() > olderThanNanos); buckets.entrySet().removeIf(e -> now - e.getValue().lastAccess() > olderThanNanos);
} }
/**
* A single client's token bucket. Tokens are stored in fixed-point form (multiplied by
* 1e9) to retain sub-token precision while using integer atomics.
*
* @param tokensFixed current token count in fixed-point (tokens &times; 1e9)
* @param lastRefillNanos timestamp of the last refill/consume, in nanoseconds
*/
private record Bucket(AtomicLong tokensFixed, AtomicLong lastRefillNanos) { private record Bucket(AtomicLong tokensFixed, AtomicLong lastRefillNanos) {
/**
* Creates a full bucket.
*
* @param tokensFixed initial token count (in whole tokens, scaled internally)
* @param lastRefillNanos the creation timestamp in nanoseconds
*/
private Bucket(long tokensFixed, long lastRefillNanos) { private Bucket(long tokensFixed, long lastRefillNanos) {
this(new AtomicLong(tokensFixed * 1_000_000_000L), new AtomicLong(lastRefillNanos)); this(new AtomicLong(tokensFixed * 1_000_000_000L), new AtomicLong(lastRefillNanos));
} }
/**
* Returns the timestamp of the last access, used by {@link #cleanup(long)}.
*
* @return the last-refill timestamp in nanoseconds
*/
long lastAccess() { long lastAccess() {
return lastRefillNanos.get(); return lastRefillNanos.get();
} }
/**
* Refills the bucket according to elapsed time and attempts to consume one token,
* retrying via compare-and-set on contention.
*
* @param now the current time in nanoseconds
* @param capacity the bucket capacity in whole tokens
* @param tokensPerNano the refill rate in tokens per nanosecond
* @param refillIntervalNs the nominal nanoseconds per token (unused in the hot path
* but kept for symmetry/retry computation)
* @return an allow result with the remaining tokens, or a deny result with a retry
* hint when fewer than one token is available
*/
Result tryAcquire(long now, long capacity, double tokensPerNano, long refillIntervalNs) { Result tryAcquire(long now, long capacity, double tokensPerNano, long refillIntervalNs) {
while (true) { while (true) {
long lastRefill = lastRefillNanos.get(); long lastRefill = lastRefillNanos.get();
long currentTokens = tokensFixed.get(); long currentTokens = tokensFixed.get();
long elapsed = now - lastRefill; long elapsed = now - lastRefill;
long refilled = currentTokens; long refilled = currentTokens;
if (elapsed > 0) { if (elapsed > 0) {
long addedFixed = (long) (elapsed * tokensPerNano * 1_000_000_000.0); long addedFixed = (long) (elapsed * tokensPerNano * 1_000_000_000.0);
refilled = Math.min(currentTokens + addedFixed, capacity * 1_000_000_000L); refilled = Math.min(currentTokens + addedFixed, capacity * 1_000_000_000L);
} }
long oneTokenFixed = 1_000_000_000L; long oneTokenFixed = 1_000_000_000L;
if (refilled < oneTokenFixed) { if (refilled < oneTokenFixed) {
long deficitFixed = oneTokenFixed - refilled; long deficitFixed = oneTokenFixed - refilled;
long retryNs = (long) (deficitFixed / (tokensPerNano * 1_000_000_000.0)); long retryNs = (long) (deficitFixed / (tokensPerNano * 1_000_000_000.0));
return Result.deny(capacity, Math.max(1, retryNs / 1_000_000)); return Result.deny(capacity, Math.max(1, retryNs / 1_000_000));
} }
long newTokens = refilled - oneTokenFixed; long newTokens = refilled - oneTokenFixed;
if (tokensFixed.compareAndSet(currentTokens, newTokens)) { if (tokensFixed.compareAndSet(currentTokens, newTokens)) {
lastRefillNanos.set(now); lastRefillNanos.set(now);
@@ -64,4 +127,4 @@ public final class TokenBucketLimiter implements RateLimiter {
} }
} }
} }
} }
@@ -9,22 +9,57 @@ import tools.jackson.databind.JsonNode;
import java.util.*; import java.util.*;
/**
* A convenience wrapper around a Netty {@link FullHttpRequest} that exposes the parts of an
* HTTP request handlers typically need: path parameters, query parameters, headers and the
* request body (raw, as a parsed JSON tree, or deserialized into a type).
*
* <p>Query parameters and the parsed JSON body are computed lazily and cached, so repeated
* accessors do not re-parse the request. A single {@code Request} instance is not intended to
* be shared across threads.</p>
*/
public final class Request { public final class Request {
private final FullHttpRequest raw;
private final Map<String, String> pathParams;
private Map<String, List<String>> queryParams;
private JsonNode jsonCache;
/** The underlying Netty request this wrapper delegates to. */
private final FullHttpRequest raw;
/** Path parameters captured by the router while matching, keyed by name. */
private final Map<String, String> pathParams;
/** Lazily decoded query-string parameters; {@code null} until first accessed. */
private Map<String, List<String>> queryParams;
/** Lazily parsed JSON body; {@code null} until {@link #json()} is first called. */
private JsonNode jsonCache;
/**
* Creates a request wrapper.
*
* @param raw the underlying Netty request
* @param pathParams the path parameters captured during routing, keyed by name
*/
public Request(FullHttpRequest raw, Map<String, String> pathParams) { public Request(FullHttpRequest raw, Map<String, String> pathParams) {
this.raw = raw; this.raw = raw;
this.pathParams = pathParams; this.pathParams = pathParams;
} }
/**
* Returns the value of a path parameter captured during routing.
*
* @param name the parameter name as declared in the route (without braces)
* @return the captured value, or {@code null} if no such parameter was matched
*/
public String pathParam(String name) { public String pathParam(String name) {
return pathParams.get(name); return pathParams.get(name);
} }
/**
* Returns the first value of a query-string parameter, decoding the query string on first
* access.
*
* @param name the query parameter name
* @return the first value, or {@code null} if the parameter is absent or has no value
*/
public String queryParam(String name) { public String queryParam(String name) {
if (queryParams == null) { if (queryParams == null) {
queryParams = new QueryStringDecoder(raw.uri()).parameters(); queryParams = new QueryStringDecoder(raw.uri()).parameters();
@@ -33,6 +68,13 @@ public final class Request {
return values == null || values.isEmpty() ? null : values.getFirst(); return values == null || values.isEmpty() ? null : values.getFirst();
} }
/**
* Returns all values of a query-string parameter, decoding the query string on first
* access.
*
* @param name the query parameter name
* @return the (possibly empty) list of values for the parameter; never {@code null}
*/
public List<String> queryParams(String name) { public List<String> queryParams(String name) {
if (queryParams == null) { if (queryParams == null) {
queryParams = new QueryStringDecoder(raw.uri()).parameters(); queryParams = new QueryStringDecoder(raw.uri()).parameters();
@@ -40,14 +82,32 @@ public final class Request {
return queryParams.getOrDefault(name, List.of()); return queryParams.getOrDefault(name, List.of());
} }
/**
* Returns the value of a request header.
*
* @param name the (case-insensitive) header name
* @return the header value, or {@code null} if not present
*/
public String header(String name) { public String header(String name) {
return raw.headers().get(name); return raw.headers().get(name);
} }
/**
* Returns the request body decoded as a UTF-8 string.
*
* @return the body as text (empty if there is no body)
*/
public String body() { public String body() {
return raw.content().toString(CharsetUtil.UTF_8); return raw.content().toString(CharsetUtil.UTF_8);
} }
/**
* Parses the request body as a JSON tree, caching the result for subsequent calls. An
* empty body resolves to a JSON {@code null} node rather than an error.
*
* @return the parsed JSON tree
* @throws BadRequestException if the body is not valid JSON
*/
public JsonNode json() { public JsonNode json() {
if (jsonCache == null) { if (jsonCache == null) {
try { try {
@@ -65,6 +125,17 @@ public final class Request {
return jsonCache; return jsonCache;
} }
/**
* Deserializes the request body directly into an instance of the given type.
*
* <p>Unlike {@link #json()}, the result is not cached and the body is read fresh on each
* call.</p>
*
* @param type the target type to deserialize into
* @param <T> the target type
* @return the deserialized value
* @throws BadRequestException if the body cannot be deserialized into {@code type}
*/
public <T> T jsonAs(Class<T> type) { public <T> T jsonAs(Class<T> type) {
try { try {
byte[] bytes = new byte[raw.content().readableBytes()]; byte[] bytes = new byte[raw.content().readableBytes()];
@@ -76,11 +147,21 @@ public final class Request {
} }
} }
/**
* Returns the request's HTTP method.
*
* @return the HTTP method
*/
public HttpMethod method() { public HttpMethod method() {
return raw.method(); return raw.method();
} }
/**
* Returns the request's path, with any query string stripped off.
*
* @return the decoded request path
*/
public String path() { public String path() {
return new QueryStringDecoder(raw.uri()).path(); return new QueryStringDecoder(raw.uri()).path();
} }
} }
@@ -5,31 +5,87 @@ import io.netty.handler.codec.http.*;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import tools.jackson.core.JacksonException; import tools.jackson.core.JacksonException;
/**
* A mutable builder for the HTTP response a handler produces. Handlers set the status code,
* headers and body fluently; the request pipeline later reads these back via the accessor
* methods to construct the actual Netty response on the wire.
*
* <p>The status defaults to {@code 200 OK} and the body to an empty byte array. The body
* setters ({@link #text(String)}, {@link #json(String)}, {@link #json(Object)}) also set an
* appropriate {@code Content-Type} header.</p>
*/
public final class Response { public final class Response {
/** HTTP status code; defaults to {@code 200}. */
private int status = 200; private int status = 200;
/** Response headers accumulated by the handler. */
private final HttpHeaders headers = new DefaultHttpHeaders(); private final HttpHeaders headers = new DefaultHttpHeaders();
/** Response body bytes; defaults to an empty array. */
private byte[] body = new byte[0]; private byte[] body = new byte[0];
/**
* Creates an empty response with status {@code 200}, no headers and an empty body, ready to
* be populated fluently by a handler.
*/
public Response() {
}
/**
* Sets the HTTP status code.
*
* @param s the status code
* @return this response, for fluent chaining
*/
public Response status(int s) { this.status = s; return this; } public Response status(int s) { this.status = s; return this; }
/**
* Sets a response header, replacing any existing value for the same name.
*
* @param name the header name
* @param value the header value
* @return this response, for fluent chaining
*/
public Response header(String name, String value) { public Response header(String name, String value) {
headers.set(name, value); headers.set(name, value);
return this; return this;
} }
/**
* Sets the body to the given text encoded as UTF-8 and sets the {@code Content-Type} to
* {@code text/plain; charset=utf-8}.
*
* @param s the text body
* @return this response, for fluent chaining
*/
public Response text(String s) { public Response text(String s) {
this.body = s.getBytes(CharsetUtil.UTF_8); this.body = s.getBytes(CharsetUtil.UTF_8);
headers.set(HttpHeaderNames.CONTENT_TYPE, "text/plain; charset=utf-8"); headers.set(HttpHeaderNames.CONTENT_TYPE, "text/plain; charset=utf-8");
return this; return this;
} }
/**
* Sets the body to an already-serialized JSON string and sets the {@code Content-Type} to
* {@code application/json; charset=utf-8}.
*
* @param json the raw JSON string
* @return this response, for fluent chaining
*/
public Response json(String json) { public Response json(String json) {
this.body = json.getBytes(CharsetUtil.UTF_8); this.body = json.getBytes(CharsetUtil.UTF_8);
headers.set(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=utf-8"); headers.set(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=utf-8");
return this; return this;
} }
/**
* Serializes the given value to JSON, sets it as the body and sets the {@code Content-Type}
* to {@code application/json; charset=utf-8}.
*
* @param value the object to serialize
* @return this response, for fluent chaining
* @throws RuntimeException if JSON serialization fails
*/
public Response json(Object value) { public Response json(Object value) {
try { try {
this.body = JsonMapper.MAPPER.writeValueAsBytes(value); this.body = JsonMapper.MAPPER.writeValueAsBytes(value);
@@ -40,7 +96,24 @@ public final class Response {
return this; return this;
} }
/**
* Returns the configured HTTP status code.
*
* @return the status code
*/
public int status() { return status; } public int status() { return status; }
/**
* Returns the accumulated response headers.
*
* @return the headers
*/
public HttpHeaders headers() { return headers; } public HttpHeaders headers() { return headers; }
/**
* Returns the response body bytes.
*
* @return the body bytes
*/
public byte[] body() { return body; } public byte[] body() { return body; }
} }
@@ -6,20 +6,71 @@ import java.util.*;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
/**
* A trie-based HTTP router that maps {@code (method, path)} pairs to {@link Handler handlers}.
*
* <p>Routes are stored in a prefix tree (radix-style {@link Node} tree) keyed by path segment.
* Three kinds of segments are supported:</p>
* <ul>
* <li><strong>static</strong> segments such as {@code users}, matched literally;</li>
* <li><strong>path parameters</strong> written as {@code {name}}, which match any single
* segment and capture its value under {@code name};</li>
* <li><strong>wildcards</strong> written as {@code *}, which match any single segment
* without capturing it.</li>
* </ul>
*
* <p>The router also holds an ordered list of {@link BiConsumer middlewares} that the request
* pipeline runs against every matched request before the handler executes.</p>
*
* <p>Registration mutates the shared trie and is intended to happen during start-up;
* {@link #resolve(HttpMethod, String)} is safe to call concurrently afterwards because the
* per-node maps are {@link ConcurrentHashMap}s.</p>
*/
public final class Router { public final class Router {
/** Root of the routing trie; every registered path descends from here. */
private final Node root = new Node(); private final Node root = new Node();
/** Middlewares executed in insertion order for every matched request. */
private final List<BiConsumer<Request, Response>> middlewares = new ArrayList<>(); private final List<BiConsumer<Request, Response>> middlewares = new ArrayList<>();
/**
* Creates an empty router with no registered routes and no middlewares.
*/
public Router() {
}
/**
* Registers a middleware that runs against every matched request before its handler.
*
* @param middleware a callback receiving the request and the response being built
* @return this router, for fluent chaining
*/
public Router use(BiConsumer<Request, Response> middleware) { public Router use(BiConsumer<Request, Response> middleware) {
middlewares.add(middleware); middlewares.add(middleware);
return this; return this;
} }
/**
* Registers a handler for the {@code GET} method at the given path.
*
* @param path the route path (supports {@code {param}} and {@code *} segments)
* @param h the handler to invoke
* @return this router, for fluent chaining
*/
public Router get(String path, Handler h) { public Router get(String path, Handler h) {
return register(HttpMethod.GET, path, h); return register(HttpMethod.GET, path, h);
} }
/**
* Registers a handler for an arbitrary HTTP method at the given path, creating any missing
* trie nodes along the way.
*
* @param method the HTTP method to bind the handler to
* @param path the route path (supports {@code {param}} and {@code *} segments)
* @param h the handler to invoke
* @return this router, for fluent chaining
*/
public Router register(HttpMethod method, String path, Handler h) { public Router register(HttpMethod method, String path, Handler h) {
Node node = root; Node node = root;
for (String segment : split(path)) { for (String segment : split(path)) {
@@ -40,6 +91,13 @@ public final class Router {
return this; return this;
} }
/**
* Splits a path into its non-empty segments, ignoring leading and collapsing internal
* slashes. For example {@code "/a/b/"} yields {@code ["a", "b"]}.
*
* @param path the raw path
* @return the ordered list of path segments
*/
private static List<String> split(String path) { private static List<String> split(String path) {
List<String> out = new ArrayList<>(); List<String> out = new ArrayList<>();
int start = path.startsWith("/") ? 1 : 0; int start = path.startsWith("/") ? 1 : 0;
@@ -53,18 +111,53 @@ public final class Router {
return out; return out;
} }
/**
* Registers a handler for the {@code POST} method at the given path.
*
* @param path the route path
* @param h the handler to invoke
* @return this router, for fluent chaining
*/
public Router post(String path, Handler h) { public Router post(String path, Handler h) {
return register(HttpMethod.POST, path, h); return register(HttpMethod.POST, path, h);
} }
/**
* Registers a handler for the {@code PUT} method at the given path.
*
* @param path the route path
* @param h the handler to invoke
* @return this router, for fluent chaining
*/
public Router put(String path, Handler h) { public Router put(String path, Handler h) {
return register(HttpMethod.PUT, path, h); return register(HttpMethod.PUT, path, h);
} }
/**
* Registers a handler for the {@code DELETE} method at the given path.
*
* @param path the route path
* @param h the handler to invoke
* @return this router, for fluent chaining
*/
public Router delete(String path, Handler h) { public Router delete(String path, Handler h) {
return register(HttpMethod.DELETE, path, h); return register(HttpMethod.DELETE, path, h);
} }
/**
* Resolves an incoming request against the routing trie.
*
* <p>Static segments are matched first, falling back to a path-parameter child (capturing
* the segment value) and then a wildcard child. If the path cannot be matched a
* {@link Resolution.NotFound} is returned. If the path matches but no handler exists for
* the requested method, a {@link Resolution.MethodNotAllowed} carrying the set of allowed
* methods is returned. Otherwise a {@link Resolution.Match} with the handler and captured
* path parameters is returned.</p>
*
* @param method the request's HTTP method
* @param path the request's path
* @return the resolution outcome, never {@code null}
*/
public Resolution resolve(HttpMethod method, String path) { public Resolution resolve(HttpMethod method, String path) {
Map<String, String> params = new HashMap<>(4); Map<String, String> params = new HashMap<>(4);
Node node = root; Node node = root;
@@ -95,31 +188,74 @@ public final class Router {
return new Resolution.NotFound(); return new Resolution.NotFound();
} }
/**
* Returns the live, ordered list of registered middlewares.
*
* @return the middleware list (modifications affect this router)
*/
public List<BiConsumer<Request, Response>> middlewares() { public List<BiConsumer<Request, Response>> middlewares() {
return middlewares; return middlewares;
} }
/**
* Sealed result type describing the three possible outcomes of {@link #resolve}.
*/
public sealed interface Resolution { public sealed interface Resolution {
/**
* A successful match.
*
* @param handler the handler to invoke for the request
* @param pathParams the path parameters captured while matching, keyed by name
*/
record Match(Handler handler, Map<String, String> pathParams) implements Resolution { record Match(Handler handler, Map<String, String> pathParams) implements Resolution {
} }
/**
* The path matched but no handler is registered for the requested method.
*
* @param allowedMethods the methods that <em>are</em> registered for this path
*/
record MethodNotAllowed(Set<HttpMethod> allowedMethods) implements Resolution { record MethodNotAllowed(Set<HttpMethod> allowedMethods) implements Resolution {
} }
/**
* No route matches the requested path.
*/
record NotFound() implements Resolution { record NotFound() implements Resolution {
} }
} }
/**
* Functional contract for a request handler: consumes the incoming {@link Request} and
* mutates the outgoing {@link Response}.
*/
@FunctionalInterface @FunctionalInterface
public interface Handler { public interface Handler {
/**
* Handles a matched request.
*
* @param req the incoming request
* @param res the response to populate
* @throws Exception if handling fails; the request pipeline translates this into an
* appropriate error response
*/
void handle(Request req, Response res) throws Exception; void handle(Request req, Response res) throws Exception;
} }
/**
* A single node in the routing trie. Holds static children keyed by segment, the handlers
* registered at this node, and optional parameter/wildcard children.
*/
private static final class Node { private static final class Node {
/** Static child nodes keyed by their literal path segment. */
final Map<String, Node> children = new ConcurrentHashMap<>(); final Map<String, Node> children = new ConcurrentHashMap<>();
/** Handlers registered directly at this node, keyed by HTTP method. */
final Map<HttpMethod, Handler> handlers = new ConcurrentHashMap<>(); final Map<HttpMethod, Handler> handlers = new ConcurrentHashMap<>();
Node paramChild; /** Child matching any single segment as a path parameter, or {@code null} if none. */
Node paramChild;
/** Name under which {@link #paramChild} captures the matched segment. */
String paramName; String paramName;
Node wildcardChild; /** Child matching any single segment as a wildcard, or {@code null} if none. */
Node wildcardChild;
} }
} }
@@ -1,5 +1,20 @@
package dev.coph.nextusweb.server.router.exception; package dev.coph.nextusweb.server.router.exception;
/**
* Unchecked exception signalling that an incoming request is malformed and should be answered
* with an HTTP {@code 400 Bad Request}.
*
* <p>It is thrown, for example, when a request body cannot be parsed as JSON or deserialized
* into the expected type. The request pipeline catches it and translates the
* {@linkplain #getMessage() message} into a {@code 400} response, distinguishing it from
* unexpected errors which produce a {@code 500}.</p>
*/
public final class BadRequestException extends RuntimeException { public final class BadRequestException extends RuntimeException {
/**
* Creates a bad-request exception with a human-readable explanation.
*
* @param message the detail message describing why the request is invalid
*/
public BadRequestException(String message) { super(message); } public BadRequestException(String message) { super(message); }
} }
@@ -5,17 +5,38 @@ import java.util.Collections;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
/**
* Immutable configuration for the WebSocket subsystem: frame and message size limits, idle
* timeout, allowed origins, negotiated subprotocols, and compression. Instances are created
* through the nested {@link Builder}.
*
* <p>The values configured here govern how {@code HttpRequestHandler} sets up the WebSocket
* portion of the Netty pipeline during the upgrade handshake.</p>
*/
public final class WebSocketConfig { public final class WebSocketConfig {
/** Maximum size, in bytes, of a single WebSocket frame payload. */
private final int maxFramePayloadLength; private final int maxFramePayloadLength;
/** Maximum size, in bytes, of an aggregated (multi-frame) message. */
private final int maxAggregatedMessageSize; private final int maxAggregatedMessageSize;
/** Idle timeout after which an inactive connection is closed; {@code null} disables it. */
private final Duration idleTimeout; private final Duration idleTimeout;
/** Explicit set of allowed origins; ignored when {@link #allowAnyOrigin} is {@code true}. */
private final Set<String> allowedOrigins; private final Set<String> allowedOrigins;
/** Whether connections from any origin are accepted. */
private final boolean allowAnyOrigin; private final boolean allowAnyOrigin;
/** Subprotocols offered during negotiation. */
private final Set<String> subprotocols; private final Set<String> subprotocols;
/** Whether per-message deflate compression is enabled. */
private final boolean compression; private final boolean compression;
/** Whether the protocol handler matches the path by prefix rather than exact equality. */
private final boolean checkStartsWith; private final boolean checkStartsWith;
/**
* Builds an immutable configuration from a {@link Builder}, defensively copying its sets.
*
* @param b the builder carrying the configured values
*/
private WebSocketConfig(Builder b) { private WebSocketConfig(Builder b) {
this.maxFramePayloadLength = b.maxFramePayloadLength; this.maxFramePayloadLength = b.maxFramePayloadLength;
this.maxAggregatedMessageSize = b.maxAggregatedMessageSize; this.maxAggregatedMessageSize = b.maxAggregatedMessageSize;
@@ -27,110 +48,247 @@ public final class WebSocketConfig {
this.checkStartsWith = b.checkStartsWith; this.checkStartsWith = b.checkStartsWith;
} }
/**
* Creates a configuration with all default values.
*
* @return a default configuration
*/
public static WebSocketConfig defaults() { public static WebSocketConfig defaults() {
return builder().build(); return builder().build();
} }
/**
* Creates a new, empty {@link Builder}.
*
* @return a fresh builder
*/
public static Builder builder() { public static Builder builder() {
return new Builder(); return new Builder();
} }
/**
* Tests whether a WebSocket upgrade from the given origin is permitted.
*
* @param origin the request's {@code Origin} header, may be {@code null}
* @return {@code true} if any origin is allowed, or if the origin is in the allow-list;
* {@code false} for a {@code null} or disallowed origin
*/
public boolean isOriginAllowed(String origin) { public boolean isOriginAllowed(String origin) {
if (allowAnyOrigin) return true; if (allowAnyOrigin) return true;
if (origin == null) return false; if (origin == null) return false;
return allowedOrigins.contains(origin); return allowedOrigins.contains(origin);
} }
/**
* Returns the maximum size of a single WebSocket frame payload.
*
* @return the maximum single-frame payload size in bytes
*/
public int maxFramePayloadLength() { public int maxFramePayloadLength() {
return maxFramePayloadLength; return maxFramePayloadLength;
} }
/**
* Returns the maximum size of an aggregated (multi-frame) message.
*
* @return the maximum aggregated message size in bytes
*/
public int maxAggregatedMessageSize() { public int maxAggregatedMessageSize() {
return maxAggregatedMessageSize; return maxAggregatedMessageSize;
} }
/**
* Returns the idle timeout after which inactive connections are closed.
*
* @return the idle timeout, or {@code null} if idle connections are never closed
*/
public Duration idleTimeout() { public Duration idleTimeout() {
return idleTimeout; return idleTimeout;
} }
/**
* Indicates whether connections from any origin are accepted.
*
* @return {@code true} if connections from any origin are accepted
*/
public boolean allowAnyOrigin() { public boolean allowAnyOrigin() {
return allowAnyOrigin; return allowAnyOrigin;
} }
/**
* Returns the explicitly allowed origins.
*
* @return the immutable set of explicitly allowed origins
*/
public Set<String> allowedOrigins() { public Set<String> allowedOrigins() {
return allowedOrigins; return allowedOrigins;
} }
/**
* Returns the configured subprotocols as a comma-separated string suitable for Netty's
* protocol config.
*
* @return the comma-separated subprotocol list, or {@code null} if none are configured
*/
public String subprotocolsCsv() { public String subprotocolsCsv() {
if (subprotocols.isEmpty()) return null; if (subprotocols.isEmpty()) return null;
return String.join(",", subprotocols); return String.join(",", subprotocols);
} }
/**
* Indicates whether per-message deflate compression is enabled.
*
* @return {@code true} if per-message compression is enabled
*/
public boolean compression() { public boolean compression() {
return compression; return compression;
} }
/**
* Indicates whether the WebSocket path is matched by prefix rather than exact equality.
*
* @return {@code true} if the WebSocket path is matched by prefix rather than exactly
*/
public boolean checkStartsWith() { public boolean checkStartsWith() {
return checkStartsWith; return checkStartsWith;
} }
/**
* Fluent builder for {@link WebSocketConfig}, pre-populated with sensible defaults: 64&nbsp;KiB
* frames, 1&nbsp;MiB aggregated messages, a 60-second idle timeout, no origin restriction
* list, compression enabled, and exact path matching.
*/
public static final class Builder { public static final class Builder {
/** Maximum single-frame payload size in bytes; defaults to 64&nbsp;KiB. */
private int maxFramePayloadLength = 65_536; private int maxFramePayloadLength = 65_536;
/** Maximum aggregated message size in bytes; defaults to 1&nbsp;MiB. */
private int maxAggregatedMessageSize = 1_048_576; private int maxAggregatedMessageSize = 1_048_576;
/** Idle timeout; defaults to 60 seconds. */
private Duration idleTimeout = Duration.ofSeconds(60); private Duration idleTimeout = Duration.ofSeconds(60);
/** Accumulated allowed origins (insertion-ordered). */
private final Set<String> allowedOrigins = new LinkedHashSet<>(); private final Set<String> allowedOrigins = new LinkedHashSet<>();
/** Whether any origin is allowed; defaults to {@code false}. */
private boolean allowAnyOrigin = false; private boolean allowAnyOrigin = false;
/** Accumulated subprotocols (insertion-ordered). */
private final Set<String> subprotocols = new LinkedHashSet<>(); private final Set<String> subprotocols = new LinkedHashSet<>();
/** Whether compression is enabled; defaults to {@code true}. */
private boolean compression = true; private boolean compression = true;
/** Whether path matching uses a prefix check; defaults to {@code false}. */
private boolean checkStartsWith = false; private boolean checkStartsWith = false;
/**
* Creates a builder pre-populated with the default configuration values described
* above. Obtain instances via {@link WebSocketConfig#builder()}.
*/
public Builder() {
}
/**
* Sets the maximum single-frame payload size.
*
* @param bytes the limit in bytes; must be positive
* @return this builder, for fluent chaining
* @throws IllegalArgumentException if {@code bytes <= 0}
*/
public Builder maxFramePayloadLength(int bytes) { public Builder maxFramePayloadLength(int bytes) {
if (bytes <= 0) throw new IllegalArgumentException("maxFramePayloadLength must be > 0"); if (bytes <= 0) throw new IllegalArgumentException("maxFramePayloadLength must be > 0");
this.maxFramePayloadLength = bytes; this.maxFramePayloadLength = bytes;
return this; return this;
} }
/**
* Sets the maximum aggregated message size.
*
* @param bytes the limit in bytes; must be positive
* @return this builder, for fluent chaining
* @throws IllegalArgumentException if {@code bytes <= 0}
*/
public Builder maxAggregatedMessageSize(int bytes) { public Builder maxAggregatedMessageSize(int bytes) {
if (bytes <= 0) throw new IllegalArgumentException("maxAggregatedMessageSize must be > 0"); if (bytes <= 0) throw new IllegalArgumentException("maxAggregatedMessageSize must be > 0");
this.maxAggregatedMessageSize = bytes; this.maxAggregatedMessageSize = bytes;
return this; return this;
} }
/**
* Sets the idle timeout after which inactive connections are closed.
*
* @param timeout the idle timeout
* @return this builder, for fluent chaining
*/
public Builder idleTimeout(Duration timeout) { public Builder idleTimeout(Duration timeout) {
this.idleTimeout = timeout; this.idleTimeout = timeout;
return this; return this;
} }
/**
* Disables the idle timeout, so connections are never closed for inactivity.
*
* @return this builder, for fluent chaining
*/
public Builder noIdleTimeout() { public Builder noIdleTimeout() {
this.idleTimeout = null; this.idleTimeout = null;
return this; return this;
} }
/**
* Adds one or more origins to the allow-list.
*
* @param origins the origins to allow
* @return this builder, for fluent chaining
*/
public Builder allowedOrigins(String... origins) { public Builder allowedOrigins(String... origins) {
Collections.addAll(this.allowedOrigins, origins); Collections.addAll(this.allowedOrigins, origins);
return this; return this;
} }
/**
* Allows WebSocket connections from any origin.
*
* @return this builder, for fluent chaining
*/
public Builder anyOrigin() { public Builder anyOrigin() {
this.allowAnyOrigin = true; this.allowAnyOrigin = true;
return this; return this;
} }
/**
* Adds one or more subprotocols to offer during negotiation.
*
* @param protocols the subprotocol names
* @return this builder, for fluent chaining
*/
public Builder subprotocols(String... protocols) { public Builder subprotocols(String... protocols) {
Collections.addAll(this.subprotocols, protocols); Collections.addAll(this.subprotocols, protocols);
return this; return this;
} }
/**
* Enables or disables per-message compression.
*
* @param enabled {@code true} to enable compression
* @return this builder, for fluent chaining
*/
public Builder compression(boolean enabled) { public Builder compression(boolean enabled) {
this.compression = enabled; this.compression = enabled;
return this; return this;
} }
/**
* Sets whether the WebSocket path is matched by prefix rather than exact equality.
*
* @param v {@code true} to match by prefix
* @return this builder, for fluent chaining
*/
public Builder checkStartsWith(boolean v) { public Builder checkStartsWith(boolean v) {
this.checkStartsWith = v; this.checkStartsWith = v;
return this; return this;
} }
/**
* Builds the immutable {@link WebSocketConfig}.
*
* @return the configured instance
*/
public WebSocketConfig build() { public WebSocketConfig build() {
return new WebSocketConfig(this); return new WebSocketConfig(this);
} }
@@ -14,21 +14,53 @@ import java.util.Map;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
/**
* Netty channel handler that bridges low-level WebSocket frames to the high-level
* {@link WebSocketHandler} callbacks.
*
* <p>It creates a {@link WebSocketSession} when the handshake completes, then translates each
* incoming frame into the matching callback ({@code onMessage}, {@code onBinary},
* {@code onClose}). Callbacks are dispatched on a virtual-thread executor so application code
* may block without stalling the Netty event loop, and any exception they throw is funneled to
* {@link WebSocketHandler#onError}. Idle-timeout events close the channel.</p>
*
* <p>This class is package-private; instances are created via
* {@link WebSocketFrameHandlerFactory}.</p>
*/
final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> { final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
private static final Executor VT_EXECUTOR = /** Executor running one virtual thread per task, used to dispatch handler callbacks. */
Executors.newVirtualThreadPerTaskExecutor(); private static final Executor VT_EXECUTOR = Executors.newVirtualThreadPerTaskExecutor();
/** The application handler receiving lifecycle callbacks. */
private final WebSocketHandler handler; private final WebSocketHandler handler;
/** The path the connection was established on. */
private final String path; private final String path;
/** Path parameters captured during routing, keyed by name. */
private final Map<String, String> pathParams; private final Map<String, String> pathParams;
/**
* Creates a frame handler bound to an application handler and connection metadata.
*
* @param handler the application handler to dispatch to
* @param path the connection path
* @param pathParams the captured path parameters
*/
WebSocketFrameHandler(WebSocketHandler handler, String path, Map<String, String> pathParams) { WebSocketFrameHandler(WebSocketHandler handler, String path, Map<String, String> pathParams) {
this.handler = handler; this.handler = handler;
this.path = path; this.path = path;
this.pathParams = pathParams; this.pathParams = pathParams;
} }
/**
* Handles pipeline user events. On handshake completion it creates and stores the
* {@link WebSocketSession} and dispatches {@link WebSocketHandler#onOpen}; on an idle-state
* event it closes the channel; other events are passed up the pipeline.
*
* @param ctx the channel context
* @param evt the user event
* @throws Exception if the superclass handling of an unrecognized event fails
*/
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
@@ -50,6 +82,15 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketF
super.userEventTriggered(ctx, evt); super.userEventTriggered(ctx, evt);
} }
/**
* Dispatches an incoming frame to the appropriate handler callback. Text, binary and close
* frames are forwarded to {@code onMessage}, {@code onBinary} and {@code onClose}
* respectively, each on a virtual thread. Frames arriving before the session exists are
* ignored.
*
* @param ctx the channel context
* @param frame the received WebSocket frame
*/
@Override @Override
protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) { protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get(); WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
@@ -88,6 +129,13 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketF
} }
} }
/**
* Invoked when the channel goes inactive (the connection dropped without a clean close
* handshake). Clears the stored session and dispatches {@link WebSocketHandler#onClose} with
* the abnormal-closure code {@code 1006}.
*
* @param ctx the channel context
*/
@Override @Override
public void channelInactive(ChannelHandlerContext ctx) { public void channelInactive(ChannelHandlerContext ctx) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).getAndSet(null); WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).getAndSet(null);
@@ -101,6 +149,13 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketF
}); });
} }
/**
* Routes a pipeline exception to {@link WebSocketHandler#onError} (when a session exists)
* and then closes the channel.
*
* @param ctx the channel context
* @param cause the exception that propagated up the pipeline
*/
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get(); WebSocketSession session = ctx.channel().attr(WebSocketSession.SESSION_KEY).get();
@@ -108,6 +163,13 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketF
ctx.close(); ctx.close();
} }
/**
* Invokes {@link WebSocketHandler#onError} while swallowing any secondary exception the
* error callback itself might throw, so error handling can never cascade.
*
* @param session the affected session
* @param cause the original error to report
*/
private void safeError(WebSocketSession session, Throwable cause) { private void safeError(WebSocketSession session, Throwable cause) {
try { try {
handler.onError(session, cause); handler.onError(session, cause);
@@ -4,11 +4,30 @@ import io.netty.channel.ChannelHandler;
import java.util.Map; import java.util.Map;
/**
* Small factory that creates the package-private {@code WebSocketFrameHandler} channel handler.
*
* <p>It exists so that other packages (notably {@code HttpRequestHandler} during the upgrade
* handshake) can insert a frame handler into the pipeline without the handler class itself
* having to be public. The class is a stateless utility and cannot be instantiated.</p>
*/
public final class WebSocketFrameHandlerFactory { public final class WebSocketFrameHandlerFactory {
/**
* Private constructor preventing instantiation of this stateless utility class.
*/
private WebSocketFrameHandlerFactory() { private WebSocketFrameHandlerFactory() {
} }
/**
* Creates a channel handler that bridges Netty WebSocket frames to the given application
* {@link WebSocketHandler}.
*
* @param handler the application handler to dispatch lifecycle events to
* @param path the path the connection was established on
* @param pathParams the path parameters captured during routing
* @return a new channel handler ready to be inserted into the pipeline
*/
public static ChannelHandler create(WebSocketHandler handler, String path, public static ChannelHandler create(WebSocketHandler handler, String path,
Map<String, String> pathParams) { Map<String, String> pathParams) {
return new WebSocketFrameHandler(handler, path, pathParams); return new WebSocketFrameHandler(handler, path, pathParams);
@@ -8,43 +8,97 @@ import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor; import io.netty.util.concurrent.GlobalEventExecutor;
import tools.jackson.core.JacksonException; import tools.jackson.core.JacksonException;
/**
* A named collection of WebSocket connections that supports broadcasting to all members at
* once &mdash; useful for chat rooms, pub/sub topics, presence channels and similar fan-out
* scenarios.
*
* <p>It is backed by a Netty {@link ChannelGroup}, which automatically removes channels as they
* close, so callers do not need to prune disconnected sessions manually. The group is
* thread-safe.</p>
*/
public final class WebSocketGroup { public final class WebSocketGroup {
/** Underlying Netty channel group holding the member connections. */
private final ChannelGroup channels; private final ChannelGroup channels;
/** Human-readable name of this group. */
private final String name; private final String name;
/**
* Creates an unnamed group (named {@code "anonymous"}).
*/
public WebSocketGroup() { public WebSocketGroup() {
this("anonymous"); this("anonymous");
} }
/**
* Creates a named group.
*
* @param name the group name
*/
public WebSocketGroup(String name) { public WebSocketGroup(String name) {
this.name = name; this.name = name;
this.channels = new DefaultChannelGroup(name, GlobalEventExecutor.INSTANCE); this.channels = new DefaultChannelGroup(name, GlobalEventExecutor.INSTANCE);
} }
/**
* Returns the name of this group.
*
* @return the group name
*/
public String name() { public String name() {
return name; return name;
} }
/**
* Adds a session to the group.
*
* @param session the session to add
* @return this group, for fluent chaining
*/
public WebSocketGroup add(WebSocketSession session) { public WebSocketGroup add(WebSocketSession session) {
channels.add(session.channel()); channels.add(session.channel());
return this; return this;
} }
/**
* Removes a session from the group.
*
* @param session the session to remove
* @return this group, for fluent chaining
*/
public WebSocketGroup remove(WebSocketSession session) { public WebSocketGroup remove(WebSocketSession session) {
channels.remove(session.channel()); channels.remove(session.channel());
return this; return this;
} }
/**
* Returns how many connections are currently in the group.
*
* @return the current number of member connections
*/
public int size() { public int size() {
return channels.size(); return channels.size();
} }
/**
* Broadcasts a text message to every member of the group.
*
* @param text the text to send
* @return this group, for fluent chaining
*/
public WebSocketGroup broadcast(String text) { public WebSocketGroup broadcast(String text) {
channels.writeAndFlush(new TextWebSocketFrame(text)); channels.writeAndFlush(new TextWebSocketFrame(text));
return this; return this;
} }
/**
* Serializes the given value to JSON and broadcasts it as a text message to every member.
*
* @param value the object to serialize and broadcast
* @return this group, for fluent chaining
* @throws RuntimeException if JSON serialization fails
*/
public WebSocketGroup broadcastJson(Object value) { public WebSocketGroup broadcastJson(Object value) {
try { try {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value); byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
@@ -56,6 +110,12 @@ public final class WebSocketGroup {
return this; return this;
} }
/**
* Broadcasts a binary message to every active member, allocating a fresh buffer per channel.
*
* @param data the bytes to broadcast
* @return this group, for fluent chaining
*/
public WebSocketGroup broadcastBinary(byte[] data) { public WebSocketGroup broadcastBinary(byte[] data) {
for (var ch : channels) { for (var ch : channels) {
if (ch.isActive()) { if (ch.isActive()) {
@@ -66,6 +126,14 @@ public final class WebSocketGroup {
return this; return this;
} }
/**
* Broadcasts a text message to every active member except one &mdash; typically the sender,
* so a client does not receive its own message echoed back.
*
* @param exclude the session to skip, or {@code null} to broadcast to everyone
* @param text the text to send
* @return this group, for fluent chaining
*/
public WebSocketGroup broadcastExcept(WebSocketSession exclude, String text) { public WebSocketGroup broadcastExcept(WebSocketSession exclude, String text) {
var excludeCh = exclude == null ? null : exclude.channel(); var excludeCh = exclude == null ? null : exclude.channel();
for (var ch : channels) { for (var ch : channels) {
@@ -76,6 +144,11 @@ public final class WebSocketGroup {
return this; return this;
} }
/**
* Closes every connection in the group.
*
* @return this group, for fluent chaining
*/
public WebSocketGroup closeAll() { public WebSocketGroup closeAll() {
channels.close(); channels.close();
return this; return this;
@@ -1,19 +1,67 @@
package dev.coph.nextusweb.server.websocket; package dev.coph.nextusweb.server.websocket;
/**
* Application-facing callback interface for a WebSocket endpoint. Implementations react to the
* lifecycle events of a single connection: opening, incoming text and binary messages, closing,
* and errors.
*
* <p>Every method has an empty default implementation, so handlers need only override the
* events they care about. Callbacks are dispatched on virtual threads by the framework, so they
* may perform blocking work, and they are allowed to throw &mdash; any thrown exception is
* routed to {@link #onError(WebSocketSession, Throwable)}.</p>
*
* @see WebSocketSession
* @see WebSocketRouter
*/
public interface WebSocketHandler { public interface WebSocketHandler {
/**
* Invoked once the WebSocket handshake has completed and the session is ready for use.
*
* @param session the newly opened session
* @throws Exception if the handler fails; routed to {@link #onError}
*/
default void onOpen(WebSocketSession session) throws Exception { default void onOpen(WebSocketSession session) throws Exception {
} }
/**
* Invoked when a text message is received.
*
* @param session the session the message arrived on
* @param message the decoded text payload
* @throws Exception if the handler fails; routed to {@link #onError}
*/
default void onMessage(WebSocketSession session, String message) throws Exception { default void onMessage(WebSocketSession session, String message) throws Exception {
} }
/**
* Invoked when a binary message is received.
*
* @param session the session the message arrived on
* @param data the raw binary payload
* @throws Exception if the handler fails; routed to {@link #onError}
*/
default void onBinary(WebSocketSession session, byte[] data) throws Exception { default void onBinary(WebSocketSession session, byte[] data) throws Exception {
} }
/**
* Invoked when the connection closes, whether initiated by the peer or the server.
*
* @param session the session being closed
* @param code the WebSocket close status code
* @param reason the close reason text (empty if none was provided)
* @throws Exception if the handler fails; routed to {@link #onError}
*/
default void onClose(WebSocketSession session, int code, String reason) throws Exception { default void onClose(WebSocketSession session, int code, String reason) throws Exception {
} }
/**
* Invoked when an error occurs on the connection or when another callback throws.
*
* @param session the affected session
* @param cause the error that occurred
* @throws Exception if the error handler itself fails (such failures are swallowed)
*/
default void onError(WebSocketSession session, Throwable cause) throws Exception { default void onError(WebSocketSession session, Throwable cause) throws Exception {
} }
} }
@@ -6,10 +6,33 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
/**
* A trie-based router that maps WebSocket upgrade paths to {@link WebSocketHandler}s.
*
* <p>It mirrors the HTTP {@link dev.coph.nextusweb.server.router.Router Router} but is simpler:
* a path resolves to a single handler (there is no HTTP method dimension) and only static and
* {@code {param}} path-parameter segments are supported (no wildcards). Registration mutates the
* shared trie at start-up; {@link #resolve(String)} is safe to call concurrently afterwards.</p>
*/
public final class WebSocketRouter { public final class WebSocketRouter {
/** Root of the routing trie. */
private final Node root = new Node(); private final Node root = new Node();
/**
* Creates an empty WebSocket router with no registered handlers.
*/
public WebSocketRouter() {
}
/**
* Registers a handler at the given path, creating any missing trie nodes. Segments wrapped
* in braces (e.g. {@code /chat/{room}}) are treated as path parameters.
*
* @param path the WebSocket path to mount the handler at
* @param handler the handler to invoke for connections on that path
* @return this router, for fluent chaining
*/
public WebSocketRouter on(String path, WebSocketHandler handler) { public WebSocketRouter on(String path, WebSocketHandler handler) {
Node node = root; Node node = root;
for (String segment : split(path)) { for (String segment : split(path)) {
@@ -27,6 +50,13 @@ public final class WebSocketRouter {
return this; return this;
} }
/**
* Resolves a path to its handler, capturing any path parameters along the way.
*
* @param path the request path
* @return a {@link Resolution} carrying the handler and captured parameters, or {@code null}
* if no handler is registered for the path
*/
public Resolution resolve(String path) { public Resolution resolve(String path) {
Map<String, String> params = new HashMap<>(4); Map<String, String> params = new HashMap<>(4);
Node node = root; Node node = root;
@@ -45,6 +75,13 @@ public final class WebSocketRouter {
return new Resolution(node.handler, params); return new Resolution(node.handler, params);
} }
/**
* Splits a path into its non-empty segments, ignoring leading and collapsing internal
* slashes.
*
* @param path the raw path
* @return the ordered list of path segments
*/
private static List<String> split(String path) { private static List<String> split(String path) {
List<String> out = new ArrayList<>(); List<String> out = new ArrayList<>();
int start = path.startsWith("/") ? 1 : 0; int start = path.startsWith("/") ? 1 : 0;
@@ -58,13 +95,27 @@ public final class WebSocketRouter {
return out; return out;
} }
/**
* A successful path resolution.
*
* @param handler the handler bound to the matched path
* @param pathParams the path parameters captured while matching, keyed by name
*/
public record Resolution(WebSocketHandler handler, Map<String, String> pathParams) { public record Resolution(WebSocketHandler handler, Map<String, String> pathParams) {
} }
/**
* A single node in the WebSocket routing trie. Holds static children keyed by segment, an
* optional path-parameter child, and the handler (if any) registered at this node.
*/
private static final class Node { private static final class Node {
/** Static child nodes keyed by their literal path segment. */
final Map<String, Node> children = new ConcurrentHashMap<>(); final Map<String, Node> children = new ConcurrentHashMap<>();
/** Child matching any single segment as a path parameter, or {@code null} if none. */
Node paramChild; Node paramChild;
/** Name under which {@link #paramChild} captures the matched segment. */
String paramName; String paramName;
/** Handler registered at this node, or {@code null} if the path is only a prefix. */
WebSocketHandler handler; WebSocketHandler handler;
} }
} }
@@ -20,17 +20,44 @@ import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
/**
* Represents a single, live WebSocket connection and is the primary object application handlers
* interact with.
*
* <p>It wraps the underlying Netty {@link Channel} and offers convenient methods to send text,
* JSON and binary payloads, to ping the peer, and to close the connection. It also carries
* read-only connection metadata (a generated id, the path, and captured path parameters) and a
* thread-safe bag of arbitrary {@link #attribute(String, Object) attributes} that handlers can
* use to associate state with the connection.</p>
*
* <p>Each connection's session is stored on its channel under {@link #SESSION_KEY} so the frame
* handler can retrieve it for every incoming frame.</p>
*/
public final class WebSocketSession { public final class WebSocketSession {
/** Channel attribute key under which the session is stored on its Netty channel. */
static final AttributeKey<WebSocketSession> SESSION_KEY = static final AttributeKey<WebSocketSession> SESSION_KEY =
AttributeKey.valueOf("nexusweb.ws.session"); AttributeKey.valueOf("nexusweb.ws.session");
/** The underlying Netty channel for this connection. */
private final Channel channel; private final Channel channel;
/** Unique identifier generated for this session. */
private final String id; private final String id;
/** The path the connection was established on. */
private final String path; private final String path;
/** Path parameters captured during routing, keyed by name. */
private final Map<String, String> pathParams; private final Map<String, String> pathParams;
/** Thread-safe bag of user-defined attributes attached to the session. */
private final Map<String, Object> attributes = new ConcurrentHashMap<>(); private final Map<String, Object> attributes = new ConcurrentHashMap<>();
/**
* Creates a session for a freshly upgraded channel. Package-private; created by the frame
* handler once the handshake completes.
*
* @param channel the underlying Netty channel
* @param path the connection path
* @param pathParams the path parameters captured during routing
*/
WebSocketSession(Channel channel, String path, Map<String, String> pathParams) { WebSocketSession(Channel channel, String path, Map<String, String> pathParams) {
this.channel = channel; this.channel = channel;
this.id = UUID.randomUUID().toString(); this.id = UUID.randomUUID().toString();
@@ -38,22 +65,49 @@ public final class WebSocketSession {
this.pathParams = pathParams; this.pathParams = pathParams;
} }
/**
* Returns the unique identifier generated for this session.
*
* @return the unique session id
*/
public String id() { public String id() {
return id; return id;
} }
/**
* Returns the path the connection was established on.
*
* @return the path the connection was established on
*/
public String path() { public String path() {
return path; return path;
} }
/**
* Returns the value of a path parameter captured during routing.
*
* @param name the parameter name (without braces)
* @return the captured value, or {@code null} if there is no such parameter
*/
public String pathParam(String name) { public String pathParam(String name) {
return pathParams.get(name); return pathParams.get(name);
} }
/**
* Indicates whether the connection is still open.
*
* @return {@code true} if the underlying channel is still active (open)
*/
public boolean isOpen() { public boolean isOpen() {
return channel.isActive(); return channel.isActive();
} }
/**
* Returns the peer's remote IP address.
*
* @return the remote host address, or a string form of the address if it is not an
* {@link InetSocketAddress}; {@code null} if unavailable
*/
public String remoteAddress() { public String remoteAddress() {
SocketAddress addr = channel.remoteAddress(); SocketAddress addr = channel.remoteAddress();
if (addr instanceof InetSocketAddress inet) { if (addr instanceof InetSocketAddress inet) {
@@ -62,26 +116,61 @@ public final class WebSocketSession {
return addr == null ? null : addr.toString(); return addr == null ? null : addr.toString();
} }
/**
* Returns the underlying Netty channel for advanced, low-level use.
*
* @return the underlying Netty channel, for advanced use
*/
public Channel channel() { public Channel channel() {
return channel; return channel;
} }
/**
* Associates a user-defined attribute with this session, or removes it when {@code value} is
* {@code null}.
*
* @param name the attribute name
* @param value the value to store, or {@code null} to remove the attribute
* @return this session, for fluent chaining
*/
public WebSocketSession attribute(String name, Object value) { public WebSocketSession attribute(String name, Object value) {
if (value == null) attributes.remove(name); if (value == null) attributes.remove(name);
else attributes.put(name, value); else attributes.put(name, value);
return this; return this;
} }
/**
* Retrieves a previously stored attribute, cast to the caller's expected type.
*
* @param name the attribute name
* @param <T> the expected attribute type
* @return the stored value, or {@code null} if absent
*/
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T attribute(String name) { public <T> T attribute(String name) {
return (T) attributes.get(name); return (T) attributes.get(name);
} }
/**
* Sends a text message to the peer.
*
* @param text the text to send
* @return a future completing when the write finishes; an already-succeeded future if the
* channel is no longer active
*/
public ChannelFuture send(String text) { public ChannelFuture send(String text) {
if (!channel.isActive()) return channel.newSucceededFuture(); if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new TextWebSocketFrame(text)); return channel.writeAndFlush(new TextWebSocketFrame(text));
} }
/**
* Serializes the given value to JSON and sends it as a text message.
*
* @param value the object to serialize and send
* @return a future completing when the write finishes; an already-succeeded future if the
* channel is no longer active
* @throws RuntimeException if JSON serialization fails
*/
public ChannelFuture sendJson(Object value) { public ChannelFuture sendJson(Object value) {
try { try {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value); byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value);
@@ -93,27 +182,63 @@ public final class WebSocketSession {
} }
} }
/**
* Sends a binary message to the peer.
*
* @param data the bytes to send
* @return a future completing when the write finishes; an already-succeeded future if the
* channel is no longer active
*/
public ChannelFuture sendBinary(byte[] data) { public ChannelFuture sendBinary(byte[] data) {
if (!channel.isActive()) return channel.newSucceededFuture(); if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(data); ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(data);
return channel.writeAndFlush(new BinaryWebSocketFrame(buf)); return channel.writeAndFlush(new BinaryWebSocketFrame(buf));
} }
/**
* Sends a WebSocket ping frame to the peer (e.g. as a keep-alive).
*
* @return a future completing when the write finishes; an already-succeeded future if the
* channel is no longer active
*/
public ChannelFuture ping() { public ChannelFuture ping() {
if (!channel.isActive()) return channel.newSucceededFuture(); if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new PingWebSocketFrame()); return channel.writeAndFlush(new PingWebSocketFrame());
} }
/**
* Closes the connection with the normal-closure status code {@code 1000} and no reason.
*
* @return a future completing when the close frame has been written
*/
public ChannelFuture close() { public ChannelFuture close() {
return close(1000, ""); return close(1000, "");
} }
/**
* Closes the connection with an explicit status code and reason, closing the channel once
* the close frame has been written.
*
* @param code the WebSocket close status code
* @param reason the human-readable close reason
* @return a future completing when the close frame has been written; an already-succeeded
* future if the channel is no longer active
*/
public ChannelFuture close(int code, String reason) { public ChannelFuture close(int code, String reason) {
if (!channel.isActive()) return channel.newSucceededFuture(); if (!channel.isActive()) return channel.newSucceededFuture();
return channel.writeAndFlush(new CloseWebSocketFrame(code, reason)) return channel.writeAndFlush(new CloseWebSocketFrame(code, reason))
.addListener(ChannelFutureListener.CLOSE); .addListener(ChannelFutureListener.CLOSE);
} }
/**
* Low-level helper that writes a text payload directly to a channel, allocating the buffer
* from the channel's allocator. Used by collaborators that hold a channel but not a session.
*
* @param channel the channel to write to
* @param text the text to send
* @return a future completing when the write finishes; an already-succeeded future if the
* channel is no longer active
*/
static ChannelFuture sendRaw(Channel channel, String text) { static ChannelFuture sendRaw(Channel channel, String text) {
if (!channel.isActive()) return channel.newSucceededFuture(); if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(); ByteBuf buf = channel.alloc().buffer();
@@ -121,6 +246,14 @@ public final class WebSocketSession {
return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf)); return channel.writeAndFlush(new TextWebSocketFrame(true, 0, buf));
} }
/**
* Low-level helper that writes a binary payload directly to a channel.
*
* @param channel the channel to write to
* @param data the bytes to send
* @return a future completing when the write finishes; an already-succeeded future if the
* channel is no longer active
*/
static ChannelFuture sendRawBinary(Channel channel, byte[] data) { static ChannelFuture sendRawBinary(Channel channel, byte[] data) {
if (!channel.isActive()) return channel.newSucceededFuture(); if (!channel.isActive()) return channel.newSucceededFuture();
ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(Unpooled.wrappedBuffer(data)); ByteBuf buf = channel.alloc().buffer(data.length).writeBytes(Unpooled.wrappedBuffer(data));
@@ -0,0 +1,31 @@
package dev.coph.nextusweb.server;
import dev.coph.nextusweb.server.cores.CorsConfig;
import dev.coph.nextusweb.server.cores.CorsHandler;
import dev.coph.nextusweb.server.ratelimit.RateLimitConfig;
import dev.coph.nextusweb.server.ratelimit.RateLimitGate;
import dev.coph.nextusweb.server.router.Router;
import dev.coph.nextusweb.server.websocket.WebSocketConfig;
import dev.coph.nextusweb.server.websocket.WebSocketRouter;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class HttpServerTest {
@Test
void builderReturnsConfiguredServer() {
Router router = new Router();
HttpServer server = HttpServer.builder(0, router);
assertNotNull(server);
assertSame(server, server.withCorsHandler(new CorsHandler(CorsConfig.permissive())));
RateLimitGate gate = new RateLimitGate(RateLimitConfig.builder().build());
try {
assertSame(server, server.withRateLimitGate(gate));
assertSame(server, server.withWebSockets(new WebSocketRouter()));
assertSame(server, server.withWebSockets(new WebSocketRouter(), WebSocketConfig.defaults()));
} finally {
gate.shutdown();
}
}
}
@@ -0,0 +1,154 @@
package dev.coph.nextusweb.server.annotation;
import dev.coph.nextusweb.server.router.Request;
import dev.coph.nextusweb.server.router.Response;
import dev.coph.nextusweb.server.router.Router;
import io.netty.handler.codec.http.HttpMethod;
import org.junit.jupiter.api.Test;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.jupiter.api.Assertions.*;
class AnnotationScannerTest {
@Controller("/api")
static class GoodController {
AtomicBoolean called = new AtomicBoolean(false);
@GET("/hello")
public void hello(Request req, Response res) {
called.set(true);
res.text("hi");
}
@POST("post")
public void post(Request req, Response res) {}
@Route(method = "PUT", path = "/put")
public void put(Request req, Response res) {}
@CUSTOM(method = "OPTIONS", value = "/opt")
public void opt(Request req, Response res) {}
public void notAnnotated(Request req, Response res) {}
}
static class NoControllerAnnotation {
@GET("/x")
public void x(Request req, Response res) {}
}
@Controller("nopref")
static class PrefixNoSlash {
@GET("/a")
public void a(Request req, Response res) {}
}
static class BadSignature {
@GET("/bad")
public String bad(String s) { return s; }
}
static class WrongReturnType {
@GET("/bad")
public String bad(Request req, Response res) { return "no"; }
}
@Test
void registersAllAnnotatedRoutesWithPrefix() {
Router router = new Router();
GoodController ctrl = new GoodController();
AnnotationScanner.register(router, ctrl);
assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.GET, "/api/hello"));
assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.POST, "/api/post"));
assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.PUT, "/api/put"));
assertInstanceOf(Router.Resolution.Match.class,
router.resolve(HttpMethod.valueOf("OPTIONS"), "/api/opt"));
}
@Test
void registrationWithoutControllerAnnotationUsesEmptyPrefix() {
Router router = new Router();
AnnotationScanner.register(router, new NoControllerAnnotation());
assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.GET, "/x"));
}
@Test
void prefixWithoutLeadingSlashIsNormalized() {
Router router = new Router();
AnnotationScanner.register(router, new PrefixNoSlash());
assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.GET, "/nopref/a"));
}
@Test
void invokingRegisteredHandlerCallsTheControllerMethod() throws Exception {
Router router = new Router();
GoodController ctrl = new GoodController();
AnnotationScanner.register(router, ctrl);
Router.Resolution res = router.resolve(HttpMethod.GET, "/api/hello");
Router.Resolution.Match m = assertInstanceOf(Router.Resolution.Match.class, res);
Response resp = new Response();
m.handler().handle(null, resp);
assertTrue(ctrl.called.get());
assertEquals(200, resp.status());
}
@Test
void badSignatureThrows() {
Router router = new Router();
assertThrows(IllegalArgumentException.class,
() -> AnnotationScanner.register(router, new BadSignature()));
}
@Test
void wrongReturnTypeThrows() {
Router router = new Router();
assertThrows(IllegalArgumentException.class,
() -> AnnotationScanner.register(router, new WrongReturnType()));
}
@Controller("/users")
static class CrudController {
java.util.concurrent.atomic.AtomicInteger getCalls = new java.util.concurrent.atomic.AtomicInteger();
java.util.concurrent.atomic.AtomicInteger putCalls = new java.util.concurrent.atomic.AtomicInteger();
java.util.concurrent.atomic.AtomicInteger deleteCalls = new java.util.concurrent.atomic.AtomicInteger();
@GET("/")
public void list(Request req, Response res) { getCalls.incrementAndGet(); }
@PUT("/")
public void replace(Request req, Response res) { putCalls.incrementAndGet(); }
@DELETE("/")
public void wipe(Request req, Response res) { deleteCalls.incrementAndGet(); }
}
@Test
void multipleMethodsOnSamePathRouteToDistinctHandlers() throws Exception {
Router router = new Router();
CrudController ctrl = new CrudController();
AnnotationScanner.register(router, ctrl);
var get = assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.GET, "/users/"));
var put = assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.PUT, "/users/"));
var del = assertInstanceOf(Router.Resolution.Match.class, router.resolve(HttpMethod.DELETE, "/users/"));
get.handler().handle(null, new Response());
put.handler().handle(null, new Response());
del.handler().handle(null, new Response());
assertEquals(1, ctrl.getCalls.get());
assertEquals(1, ctrl.putCalls.get());
assertEquals(1, ctrl.deleteCalls.get());
var post = router.resolve(HttpMethod.POST, "/users/");
var mna = assertInstanceOf(Router.Resolution.MethodNotAllowed.class, post);
assertTrue(mna.allowedMethods().contains(HttpMethod.GET));
assertTrue(mna.allowedMethods().contains(HttpMethod.PUT));
assertTrue(mna.allowedMethods().contains(HttpMethod.DELETE));
}
}
@@ -0,0 +1,81 @@
package dev.coph.nextusweb.server.annotation;
import org.junit.jupiter.api.Test;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import static org.junit.jupiter.api.Assertions.*;
class AnnotationsTest {
@Controller
static class CtrlDefault {}
@Controller("/api")
static class CtrlWithValue {}
static class Routes {
@GET("/g") public void g() {}
@POST("/p") public void p() {}
@PUT("/u") public void u() {}
@PATCH("/pa") public void pa() {}
@DELETE("/d") public void d() {}
@CUSTOM(method = "OPTIONS", value = "/o") public void o() {}
@Route(method = "TRACE", path = "/t") public void t() {}
}
@Test
void controllerDefaultValueIsEmpty() {
Controller c = CtrlDefault.class.getAnnotation(Controller.class);
assertNotNull(c);
assertEquals("", c.value());
}
@Test
void controllerValueIsCarried() {
Controller c = CtrlWithValue.class.getAnnotation(Controller.class);
assertNotNull(c);
assertEquals("/api", c.value());
}
@Test
void controllerHasRuntimeRetentionAndTypeTarget() {
Retention r = Controller.class.getAnnotation(Retention.class);
Target t = Controller.class.getAnnotation(Target.class);
assertEquals(RetentionPolicy.RUNTIME, r.value());
assertArrayEquals(new ElementType[]{ElementType.TYPE}, t.value());
}
@Test
void routeMethodAnnotationsCarryValues() throws Exception {
assertEquals("/g", Routes.class.getDeclaredMethod("g").getAnnotation(GET.class).value());
assertEquals("/p", Routes.class.getDeclaredMethod("p").getAnnotation(POST.class).value());
assertEquals("/u", Routes.class.getDeclaredMethod("u").getAnnotation(PUT.class).value());
assertEquals("/pa", Routes.class.getDeclaredMethod("pa").getAnnotation(PATCH.class).value());
assertEquals("/d", Routes.class.getDeclaredMethod("d").getAnnotation(DELETE.class).value());
CUSTOM custom = Routes.class.getDeclaredMethod("o").getAnnotation(CUSTOM.class);
assertEquals("OPTIONS", custom.method());
assertEquals("/o", custom.value());
Route route = Routes.class.getDeclaredMethod("t").getAnnotation(Route.class);
assertEquals("TRACE", route.method());
assertEquals("/t", route.path());
}
@Test
void allMethodAnnotationsTargetMethods() {
Class<?>[] anns = {GET.class, POST.class, PUT.class, PATCH.class, DELETE.class, CUSTOM.class, Route.class};
for (Class<?> a : anns) {
Target t = a.getAnnotation(Target.class);
Retention r = a.getAnnotation(Retention.class);
assertNotNull(t, a.getSimpleName() + " missing @Target");
assertNotNull(r, a.getSimpleName() + " missing @Retention");
assertEquals(RetentionPolicy.RUNTIME, r.value());
assertArrayEquals(new ElementType[]{ElementType.METHOD}, t.value());
}
}
}
@@ -0,0 +1,61 @@
package dev.coph.nextusweb.server.cores;
import io.netty.handler.codec.http.HttpMethod;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class CorsConfigTest {
@Test
void permissiveBuildsWithExpectedDefaults() {
CorsConfig c = CorsConfig.permissive();
assertTrue(c.allowAnyOrigin());
assertFalse(c.allowCredentials());
assertEquals(3600, c.maxAgeSeconds());
assertTrue(c.allowedMethods().contains(HttpMethod.GET));
assertTrue(c.allowedMethods().contains(HttpMethod.OPTIONS));
assertTrue(c.allowedHeaders().contains("Authorization"));
}
@Test
void isOriginAllowedHandlesNullAndWildcard() {
assertFalse(CorsConfig.permissive().isOriginAllowed(null));
assertTrue(CorsConfig.permissive().isOriginAllowed("https://anything"));
}
@Test
void exactOriginMatchOnly() {
CorsConfig c = CorsConfig.builder()
.allowedOrigins("https://a.com")
.build();
assertTrue(c.isOriginAllowed("https://a.com"));
assertFalse(c.isOriginAllowed("https://b.com"));
assertFalse(c.isOriginAllowed(null));
}
@Test
void wildcardWithCredentialsIsRejected() {
assertThrows(IllegalStateException.class, () -> CorsConfig.builder()
.anyOrigin()
.allowCredentials(true)
.build());
}
@Test
void exposedAndAllowedHeadersAreCopied() {
CorsConfig c = CorsConfig.builder()
.allowedHeaders("A", "B")
.exposedHeaders("X")
.allowedMethods(HttpMethod.GET)
.allowCredentials(true)
.maxAgeSeconds(60)
.allowedOrigins("http://a")
.build();
assertEquals(2, c.allowedHeaders().size());
assertTrue(c.exposedHeaders().contains("X"));
assertTrue(c.allowCredentials());
assertEquals(60, c.maxAgeSeconds());
assertFalse(c.allowAnyOrigin());
}
}
@@ -0,0 +1,95 @@
package dev.coph.nextusweb.server.cores;
import dev.coph.nextusweb.server.router.Response;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class CorsHandlerTest {
private CorsHandler permissiveHandler() {
return new CorsHandler(CorsConfig.permissive());
}
@Test
void applyHeadersDoesNothingForNullOrigin() {
Response res = new Response();
permissiveHandler().applyHeaders(null, res);
assertNull(res.headers().get("Access-Control-Allow-Origin"));
}
@Test
void applyHeadersDoesNothingForDisallowedOrigin() {
CorsHandler h = new CorsHandler(CorsConfig.builder()
.allowedOrigins("https://allow")
.allowedMethods(HttpMethod.GET).build());
Response res = new Response();
h.applyHeaders("https://other", res);
assertNull(res.headers().get("Access-Control-Allow-Origin"));
}
@Test
void applyHeadersWritesWildcardWhenAnyOriginAndNoCreds() {
Response res = new Response();
permissiveHandler().applyHeaders("https://x.com", res);
assertEquals("*", res.headers().get("Access-Control-Allow-Origin"));
}
@Test
void applyHeadersWritesOriginAndVaryWhenSpecific() {
CorsHandler h = new CorsHandler(CorsConfig.builder()
.allowedOrigins("https://a")
.allowedMethods(HttpMethod.GET)
.allowCredentials(true)
.exposedHeaders("X-Custom").build());
Response res = new Response();
h.applyHeaders("https://a", res);
assertEquals("https://a", res.headers().get("Access-Control-Allow-Origin"));
assertEquals("Origin", res.headers().get("Vary"));
assertEquals("true", res.headers().get("Access-Control-Allow-Credentials"));
assertEquals("X-Custom", res.headers().get("Access-Control-Expose-Headers"));
}
@Test
void isPreflightTrueOnlyForOptionsWithRequestMethod() {
HttpHeaders hs = new DefaultHttpHeaders();
hs.set("Access-Control-Request-Method", "GET");
assertTrue(permissiveHandler().isPreflight(HttpMethod.OPTIONS, hs));
assertFalse(permissiveHandler().isPreflight(HttpMethod.GET, hs));
HttpHeaders empty = new DefaultHttpHeaders();
assertFalse(permissiveHandler().isPreflight(HttpMethod.OPTIONS, empty));
}
@Test
void handlePreflightReturns403ForDisallowedOrigin() {
CorsHandler h = new CorsHandler(CorsConfig.builder()
.allowedOrigins("https://allow")
.allowedMethods(HttpMethod.GET).build());
Response res = h.handlePreflight("https://other", new DefaultHttpHeaders());
assertEquals(403, res.status());
}
@Test
void handlePreflightWritesAllowAndMaxAge() {
Response res = permissiveHandler().handlePreflight("https://x.com", new DefaultHttpHeaders());
assertEquals(204, res.status());
assertNotNull(res.headers().get("Access-Control-Allow-Methods"));
assertEquals("3600", res.headers().get("Access-Control-Max-Age"));
}
@Test
void handlePreflightEchoesRequestedHeadersIfNoneConfigured() {
CorsConfig cfg = CorsConfig.builder()
.anyOrigin()
.allowedMethods(HttpMethod.GET)
.build();
CorsHandler h = new CorsHandler(cfg);
HttpHeaders req = new DefaultHttpHeaders();
req.set("Access-Control-Request-Headers", "x-foo");
Response res = h.handlePreflight("https://x.com", req);
assertEquals("x-foo", res.headers().get("Access-Control-Allow-Headers"));
}
}
@@ -0,0 +1,36 @@
package dev.coph.nextusweb.server.json;
import org.junit.jupiter.api.Test;
import tools.jackson.databind.JsonNode;
import static org.junit.jupiter.api.Assertions.*;
class JsonMapperTest {
@Test
void mapperIsAvailable() {
assertNotNull(JsonMapper.MAPPER);
}
@Test
void mapperRoundTripsSimpleValues() {
var node = JsonMapper.MAPPER.valueToTree(java.util.Map.of("a", 1));
assertTrue(node.isObject());
assertEquals(1, node.get("a").asInt());
}
@Test
void mapperReadsTree() {
JsonNode n = JsonMapper.MAPPER.readTree("{\"k\":\"v\"}");
assertTrue(n.has("k"));
assertNotNull(n.get("k"));
}
@Test
void mapperSerializesToBytes() {
byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(java.util.Map.of("a", "b"));
String s = new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
assertTrue(s.contains("\"a\""));
assertTrue(s.contains("\"b\""));
}
}
@@ -0,0 +1,43 @@
package dev.coph.nextusweb.server.ratelimit;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class FixedWindowLimiterTest {
@Test
void allowsUpToLimitThenDenies() {
FixedWindowLimiter lim = new FixedWindowLimiter(3, 1000);
assertTrue(lim.tryAcquire("k", 0).allowed());
assertTrue(lim.tryAcquire("k", 0).allowed());
assertTrue(lim.tryAcquire("k", 0).allowed());
RateLimiter.Result r = lim.tryAcquire("k", 0);
assertFalse(r.allowed());
assertEquals(3, r.limit());
assertTrue(r.retryAfterMillis() > 0);
}
@Test
void newWindowResetsCount() {
FixedWindowLimiter lim = new FixedWindowLimiter(1, 100);
assertTrue(lim.tryAcquire("k", 0).allowed());
assertFalse(lim.tryAcquire("k", 0).allowed());
long windowNs = 100L * 1_000_000L;
assertTrue(lim.tryAcquire("k", windowNs).allowed());
}
@Test
void differentKeysAreIndependent() {
FixedWindowLimiter lim = new FixedWindowLimiter(1, 1000);
assertTrue(lim.tryAcquire("a", 0).allowed());
assertTrue(lim.tryAcquire("b", 0).allowed());
}
@Test
void cleanupDoesNotThrow() {
FixedWindowLimiter lim = new FixedWindowLimiter(1, 1000);
lim.tryAcquire("k", System.nanoTime());
assertDoesNotThrow(() -> lim.cleanup(0));
}
}
@@ -0,0 +1,53 @@
package dev.coph.nextusweb.server.ratelimit;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class KeyResolverTest {
private HttpRequest req(String header, String value) {
HttpRequest r = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (header != null) r.headers().set(header, value);
return r;
}
@Test
void clientIpUsesRemoteWhenNoForwardedHeader() {
assertEquals("10.0.0.1", KeyResolver.clientIp().resolve(req(null, null), "10.0.0.1"));
}
@Test
void clientIpUsesForwardedHeaderFirstValue() {
HttpRequest r = req("X-Forwarded-For", "1.1.1.1, 2.2.2.2");
assertEquals("1.1.1.1", KeyResolver.clientIp().resolve(r, "10.0.0.1"));
}
@Test
void clientIpHandlesSingleForwardedValue() {
HttpRequest r = req("X-Forwarded-For", "3.3.3.3");
assertEquals("3.3.3.3", KeyResolver.clientIp().resolve(r, "10.0.0.1"));
}
@Test
void userOrIpReturnsBearerToken() {
HttpRequest r = req("Authorization", "Bearer abc123");
assertEquals("u:abc123", KeyResolver.userOrIp().resolve(r, "10.0.0.1"));
}
@Test
void userOrIpFallsBackToClientIp() {
HttpRequest r = req(null, null);
assertEquals("ip:10.0.0.1", KeyResolver.userOrIp().resolve(r, "10.0.0.1"));
}
@Test
void userOrIpIgnoresNonBearerAuth() {
HttpRequest r = req("Authorization", "Basic xyz");
assertEquals("ip:10.0.0.1", KeyResolver.userOrIp().resolve(r, "10.0.0.1"));
}
}
@@ -0,0 +1,43 @@
package dev.coph.nextusweb.server.ratelimit;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class LeakyBucketLimiterTest {
@Test
void fillsUpToCapacityThenDenies() {
LeakyBucketLimiter lim = new LeakyBucketLimiter(1, 2);
long now = 0;
assertTrue(lim.tryAcquire("k", now).allowed());
assertTrue(lim.tryAcquire("k", now).allowed());
RateLimiter.Result r = lim.tryAcquire("k", now);
assertFalse(r.allowed());
assertEquals(2, r.limit());
assertTrue(r.retryAfterMillis() > 0);
}
@Test
void leakReducesLevelAndAllowsAgain() {
LeakyBucketLimiter lim = new LeakyBucketLimiter(10, 1);
assertTrue(lim.tryAcquire("k", 0).allowed());
assertFalse(lim.tryAcquire("k", 0).allowed());
long oneSec = 1_000_000_000L;
assertTrue(lim.tryAcquire("k", oneSec).allowed());
}
@Test
void differentKeysAreIndependent() {
LeakyBucketLimiter lim = new LeakyBucketLimiter(1, 1);
assertTrue(lim.tryAcquire("a", 0).allowed());
assertTrue(lim.tryAcquire("b", 0).allowed());
}
@Test
void cleanupDoesNotThrow() {
LeakyBucketLimiter lim = new LeakyBucketLimiter(1, 1);
lim.tryAcquire("k", System.nanoTime());
assertDoesNotThrow(() -> lim.cleanup(0));
}
}
@@ -0,0 +1,78 @@
package dev.coph.nextusweb.server.ratelimit;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
class RateLimitConfigTest {
private RateLimiter alwaysAllow() {
return (k, now) -> RateLimiter.Result.allow(1, 1);
}
private KeyResolver keyer() {
return (req, remote) -> "x";
}
@Test
void emptyConfigReturnsEmptyList() {
RateLimitConfig cfg = RateLimitConfig.builder().build();
assertTrue(cfg.rulesFor("/anything").isEmpty());
}
@Test
void globalOnlyReturnsOneRule() {
RateLimitConfig cfg = RateLimitConfig.builder()
.global(alwaysAllow(), keyer())
.build();
List<RateLimitConfig.Rule> rules = cfg.rulesFor("/x");
assertEquals(1, rules.size());
assertEquals("global", rules.getFirst().name());
}
@Test
void exactPathTrumpsPrefixRule() {
RateLimitConfig cfg = RateLimitConfig.builder()
.forPath("/a/b", alwaysAllow(), keyer())
.forPrefix("/a/", alwaysAllow(), keyer())
.build();
List<RateLimitConfig.Rule> rules = cfg.rulesFor("/a/b");
assertEquals(1, rules.size());
assertEquals("/a/b", rules.getFirst().name());
}
@Test
void prefixRuleMatchesWhenNoExact() {
RateLimitConfig cfg = RateLimitConfig.builder()
.forPrefix("/api/", alwaysAllow(), keyer())
.build();
List<RateLimitConfig.Rule> rules = cfg.rulesFor("/api/users");
assertEquals(1, rules.size());
assertEquals("/api/*", rules.getFirst().name());
}
@Test
void longerPrefixWinsOverShorter() {
RateLimitConfig cfg = RateLimitConfig.builder()
.forPrefix("/api/", alwaysAllow(), keyer())
.forPrefix("/api/v2/", alwaysAllow(), keyer())
.build();
List<RateLimitConfig.Rule> rules = cfg.rulesFor("/api/v2/users");
assertEquals(1, rules.size());
assertEquals("/api/v2/*", rules.getFirst().name());
}
@Test
void globalIsAlwaysIncludedAlongsideMatchedRule() {
RateLimitConfig cfg = RateLimitConfig.builder()
.global(alwaysAllow(), keyer())
.forPath("/x", alwaysAllow(), keyer())
.build();
List<RateLimitConfig.Rule> rules = cfg.rulesFor("/x");
assertEquals(2, rules.size());
assertEquals("global", rules.get(0).name());
assertEquals("/x", rules.get(1).name());
}
}
@@ -0,0 +1,77 @@
package dev.coph.nextusweb.server.ratelimit;
import dev.coph.nextusweb.server.router.Response;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class RateLimitGateTest {
private HttpRequest req() {
return new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
}
@Test
void checkReturnsNullWhenNoRulesMatch() {
RateLimitGate gate = new RateLimitGate(RateLimitConfig.builder().build());
assertNull(gate.check(req(), "/anything", "1.1.1.1"));
gate.shutdown();
}
@Test
void checkAllowsWhenWithinLimit() {
RateLimitGate gate = new RateLimitGate(RateLimitConfig.builder()
.global(new FixedWindowLimiter(2, 1000), KeyResolver.clientIp())
.build());
RateLimiter.Result r = gate.check(req(), "/x", "1.1.1.1");
assertNotNull(r);
assertTrue(r.allowed());
gate.shutdown();
}
@Test
void checkDeniesWhenAnyRuleDenies() {
RateLimitGate gate = new RateLimitGate(RateLimitConfig.builder()
.global(new FixedWindowLimiter(1, 1000), KeyResolver.clientIp())
.build());
gate.check(req(), "/x", "1.1.1.1");
RateLimiter.Result r = gate.check(req(), "/x", "1.1.1.1");
assertNotNull(r);
assertFalse(r.allowed());
gate.shutdown();
}
@Test
void applyHeadersIsNoOpForNull() {
Response res = new Response();
RateLimitGate.applyHeaders(null, res);
assertNull(res.headers().get("X-RateLimit-Limit"));
}
@Test
void applyHeadersWritesLimitAndRemaining() {
Response res = new Response();
RateLimitGate.applyHeaders(RateLimiter.Result.allow(5, 10), res);
assertEquals("10", res.headers().get("X-RateLimit-Limit"));
assertEquals("5", res.headers().get("X-RateLimit-Remaining"));
assertNull(res.headers().get("Retry-After"));
}
@Test
void applyHeadersWritesRetryAfterWhenDenied() {
Response res = new Response();
RateLimitGate.applyHeaders(RateLimiter.Result.deny(10, 2500), res);
assertEquals("3", res.headers().get("Retry-After"));
}
@Test
void applyHeadersClampsNegativeRemainingToZero() {
Response res = new Response();
RateLimitGate.applyHeaders(new RateLimiter.Result(true, -5, 10, 0), res);
assertEquals("0", res.headers().get("X-RateLimit-Remaining"));
}
}
@@ -0,0 +1,26 @@
package dev.coph.nextusweb.server.ratelimit;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class RateLimiterResultTest {
@Test
void allowFactoryProducesAllowed() {
RateLimiter.Result r = RateLimiter.Result.allow(5, 10);
assertTrue(r.allowed());
assertEquals(5, r.remaining());
assertEquals(10, r.limit());
assertEquals(0, r.retryAfterMillis());
}
@Test
void denyFactoryProducesDenied() {
RateLimiter.Result r = RateLimiter.Result.deny(10, 250);
assertFalse(r.allowed());
assertEquals(0, r.remaining());
assertEquals(10, r.limit());
assertEquals(250, r.retryAfterMillis());
}
}
@@ -0,0 +1,38 @@
package dev.coph.nextusweb.server.ratelimit;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class SlidingWindowLimiterTest {
@Test
void allowsUpToLimitThenDenies() {
SlidingWindowLimiter lim = new SlidingWindowLimiter(2, 1000);
assertTrue(lim.tryAcquire("k", 0).allowed());
assertTrue(lim.tryAcquire("k", 0).allowed());
assertFalse(lim.tryAcquire("k", 0).allowed());
}
@Test
void afterFullWindowAllowsAgain() {
SlidingWindowLimiter lim = new SlidingWindowLimiter(1, 100);
assertTrue(lim.tryAcquire("k", 0).allowed());
long twoWindows = 200L * 1_000_000L;
assertTrue(lim.tryAcquire("k", twoWindows).allowed());
}
@Test
void differentKeysAreIndependent() {
SlidingWindowLimiter lim = new SlidingWindowLimiter(1, 1000);
assertTrue(lim.tryAcquire("a", 0).allowed());
assertTrue(lim.tryAcquire("b", 0).allowed());
}
@Test
void cleanupDoesNotThrow() {
SlidingWindowLimiter lim = new SlidingWindowLimiter(1, 1000);
lim.tryAcquire("k", System.nanoTime());
assertDoesNotThrow(() -> lim.cleanup(0));
}
}
@@ -0,0 +1,50 @@
package dev.coph.nextusweb.server.ratelimit;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class TokenBucketLimiterTest {
@Test
void burstUpToCapacityIsAllowed() {
TokenBucketLimiter lim = new TokenBucketLimiter(1, 3);
long now = 0;
assertTrue(lim.tryAcquire("k", now).allowed());
assertTrue(lim.tryAcquire("k", now).allowed());
assertTrue(lim.tryAcquire("k", now).allowed());
}
@Test
void emptyBucketIsDeniedAndRetryAfterIsPositive() {
TokenBucketLimiter lim = new TokenBucketLimiter(1, 1);
assertTrue(lim.tryAcquire("k", 0).allowed());
RateLimiter.Result r = lim.tryAcquire("k", 0);
assertFalse(r.allowed());
assertTrue(r.retryAfterMillis() > 0);
assertEquals(1, r.limit());
}
@Test
void refillAllowsAcquireAfterTime() {
TokenBucketLimiter lim = new TokenBucketLimiter(10, 1);
assertTrue(lim.tryAcquire("k", 0).allowed());
assertFalse(lim.tryAcquire("k", 0).allowed());
long oneSecLater = 1_000_000_000L;
assertTrue(lim.tryAcquire("k", oneSecLater).allowed());
}
@Test
void differentKeysAreIndependent() {
TokenBucketLimiter lim = new TokenBucketLimiter(1, 1);
assertTrue(lim.tryAcquire("a", 0).allowed());
assertTrue(lim.tryAcquire("b", 0).allowed());
}
@Test
void cleanupDoesNotThrow() {
TokenBucketLimiter lim = new TokenBucketLimiter(1, 1);
lim.tryAcquire("k", System.nanoTime());
assertDoesNotThrow(() -> lim.cleanup(0));
}
}
@@ -0,0 +1,104 @@
package dev.coph.nextusweb.server.router;
import dev.coph.nextusweb.server.router.exception.BadRequestException;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class RequestTest {
private FullHttpRequest build(HttpMethod method, String uri, String body) {
var content = body == null
? Unpooled.EMPTY_BUFFER
: Unpooled.copiedBuffer(body, CharsetUtil.UTF_8);
return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, method, uri, content);
}
record Payload(String name, int age) {}
@Test
void pathParamReturnsFromMap() {
Request req = new Request(build(HttpMethod.GET, "/u/1", null), Map.of("id", "1"));
assertEquals("1", req.pathParam("id"));
assertNull(req.pathParam("missing"));
}
@Test
void queryParamReturnsFirstValue() {
Request req = new Request(build(HttpMethod.GET, "/?a=1&a=2&b=foo", null), Map.of());
assertEquals("1", req.queryParam("a"));
assertEquals("foo", req.queryParam("b"));
assertNull(req.queryParam("nope"));
}
@Test
void queryParamsReturnsAllValues() {
Request req = new Request(build(HttpMethod.GET, "/?a=1&a=2", null), Map.of());
assertEquals(java.util.List.of("1", "2"), req.queryParams("a"));
assertTrue(req.queryParams("missing").isEmpty());
}
@Test
void headerReturnsValue() {
FullHttpRequest raw = build(HttpMethod.GET, "/", null);
raw.headers().set("X-Foo", "bar");
Request req = new Request(raw, Map.of());
assertEquals("bar", req.header("X-Foo"));
}
@Test
void bodyReturnsContentAsString() {
Request req = new Request(build(HttpMethod.POST, "/", "hello"), Map.of());
assertEquals("hello", req.body());
}
@Test
void methodAndPathExpose() {
Request req = new Request(build(HttpMethod.POST, "/a/b?q=1", null), Map.of());
assertEquals(HttpMethod.POST, req.method());
assertEquals("/a/b", req.path());
}
@Test
void jsonReturnsNullNodeForEmptyBody() {
Request req = new Request(build(HttpMethod.POST, "/", null), Map.of());
var node = req.json();
assertNotNull(node);
assertTrue(node.isNull());
}
@Test
void jsonParsesObject() {
Request req = new Request(build(HttpMethod.POST, "/", "{\"a\":1}"), Map.of());
var node = req.json();
assertTrue(node.has("a"));
}
@Test
void jsonThrowsBadRequestOnInvalidJson() {
Request req = new Request(build(HttpMethod.POST, "/", "not-json"), Map.of());
assertThrows(BadRequestException.class, req::json);
}
@Test
void jsonAsDeserializes() {
Request req = new Request(build(HttpMethod.POST, "/", "{\"name\":\"x\",\"age\":42}"), Map.of());
Payload p = req.jsonAs(Payload.class);
assertEquals("x", p.name());
assertEquals(42, p.age());
}
@Test
void jsonAsThrowsBadRequestOnInvalid() {
Request req = new Request(build(HttpMethod.POST, "/", "not-json"), Map.of());
assertThrows(BadRequestException.class, () -> req.jsonAs(Payload.class));
}
}
@@ -0,0 +1,54 @@
package dev.coph.nextusweb.server.router;
import io.netty.handler.codec.http.HttpHeaderNames;
import org.junit.jupiter.api.Test;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class ResponseTest {
@Test
void defaultStatusIs200AndEmptyBody() {
Response res = new Response();
assertEquals(200, res.status());
assertArrayEquals(new byte[0], res.body());
}
@Test
void statusIsFluent() {
Response res = new Response().status(404);
assertEquals(404, res.status());
}
@Test
void headerSetsValue() {
Response res = new Response().header("X-Foo", "bar");
assertEquals("bar", res.headers().get("X-Foo"));
}
@Test
void textSetsBodyAndContentType() {
Response res = new Response().text("hello");
assertEquals("hello", new String(res.body(), StandardCharsets.UTF_8));
assertTrue(res.headers().get(HttpHeaderNames.CONTENT_TYPE).startsWith("text/plain"));
}
@Test
void jsonStringSetsBodyAndContentType() {
Response res = new Response().json("{\"a\":1}");
assertEquals("{\"a\":1}", new String(res.body(), StandardCharsets.UTF_8));
assertTrue(res.headers().get(HttpHeaderNames.CONTENT_TYPE).startsWith("application/json"));
}
@Test
void jsonObjectSerializesValue() {
Response res = new Response().json(Map.of("k", "v"));
String s = new String(res.body(), StandardCharsets.UTF_8);
assertTrue(s.contains("\"k\""));
assertTrue(s.contains("\"v\""));
assertTrue(res.headers().get(HttpHeaderNames.CONTENT_TYPE).startsWith("application/json"));
}
}
@@ -0,0 +1,163 @@
package dev.coph.nextusweb.server.router;
import io.netty.handler.codec.http.HttpMethod;
import org.junit.jupiter.api.Test;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.*;
class RouterTest {
private final Router.Handler noop = (req, res) -> {};
@Test
void getRegistersAndResolvesExactPath() {
Router r = new Router().get("/hello", noop);
assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.GET, "/hello"));
}
@Test
void postPutDeleteRegister() {
Router r = new Router()
.post("/p", noop)
.put("/u", noop)
.delete("/d", noop);
assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.POST, "/p"));
assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.PUT, "/u"));
assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.DELETE, "/d"));
}
@Test
void notFoundForUnknownPath() {
Router r = new Router().get("/a", noop);
assertInstanceOf(Router.Resolution.NotFound.class, r.resolve(HttpMethod.GET, "/x"));
}
@Test
void methodNotAllowedWhenPathMatchesDifferentMethod() {
Router r = new Router().get("/a", noop);
Router.Resolution res = r.resolve(HttpMethod.POST, "/a");
Router.Resolution.MethodNotAllowed mna = assertInstanceOf(Router.Resolution.MethodNotAllowed.class, res);
assertTrue(mna.allowedMethods().contains(HttpMethod.GET));
}
@Test
void pathParamsAreExtracted() {
Router r = new Router().get("/u/{id}", noop);
Router.Resolution res = r.resolve(HttpMethod.GET, "/u/42");
Router.Resolution.Match m = assertInstanceOf(Router.Resolution.Match.class, res);
assertEquals("42", m.pathParams().get("id"));
}
@Test
void wildcardMatches() {
Router r = new Router().get("/files/*", noop);
assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.GET, "/files/anything"));
}
@Test
void useAddsMiddlewareReturned() {
AtomicInteger count = new AtomicInteger();
Router r = new Router().use((req, res) -> count.incrementAndGet());
assertEquals(1, r.middlewares().size());
r.middlewares().getFirst().accept(null, null);
assertEquals(1, count.get());
}
@Test
void registerWorksWithCustomMethod() {
Router r = new Router().register(HttpMethod.valueOf("OPTIONS"), "/x", noop);
assertInstanceOf(Router.Resolution.Match.class,
r.resolve(HttpMethod.valueOf("OPTIONS"), "/x"));
}
@Test
void handlerInvocationWorks() throws Exception {
AtomicInteger called = new AtomicInteger();
Router r = new Router().get("/x", (req, res) -> called.incrementAndGet());
var match = (Router.Resolution.Match) r.resolve(HttpMethod.GET, "/x");
match.handler().handle(null, null);
assertEquals(1, called.get());
}
@Test
void samePathWithDifferentMethodsResolvesToDistinctHandlers() throws Exception {
AtomicInteger getCalls = new AtomicInteger();
AtomicInteger putCalls = new AtomicInteger();
AtomicInteger deleteCalls = new AtomicInteger();
Router r = new Router()
.get("/user", (req, res) -> getCalls.incrementAndGet())
.put("/user", (req, res) -> putCalls.incrementAndGet())
.delete("/user", (req, res) -> deleteCalls.incrementAndGet());
var get = assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.GET, "/user"));
var put = assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.PUT, "/user"));
var del = assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.DELETE, "/user"));
get.handler().handle(null, null);
put.handler().handle(null, null);
del.handler().handle(null, null);
assertEquals(1, getCalls.get());
assertEquals(1, putCalls.get());
assertEquals(1, deleteCalls.get());
assertNotSame(get.handler(), put.handler());
assertNotSame(put.handler(), del.handler());
}
@Test
void samePathUnregisteredMethodReturnsMethodNotAllowedWithAllAllowed() {
Router r = new Router()
.get("/user", noop)
.put("/user", noop)
.delete("/user", noop);
var res = r.resolve(HttpMethod.POST, "/user");
var mna = assertInstanceOf(Router.Resolution.MethodNotAllowed.class, res);
assertTrue(mna.allowedMethods().contains(HttpMethod.GET));
assertTrue(mna.allowedMethods().contains(HttpMethod.PUT));
assertTrue(mna.allowedMethods().contains(HttpMethod.DELETE));
assertFalse(mna.allowedMethods().contains(HttpMethod.POST));
assertEquals(3, mna.allowedMethods().size());
}
@Test
void registeringSameMethodAndPathTwiceOverwritesHandler() throws Exception {
AtomicInteger first = new AtomicInteger();
AtomicInteger second = new AtomicInteger();
Router r = new Router()
.get("/user", (req, res) -> first.incrementAndGet())
.get("/user", (req, res) -> second.incrementAndGet());
var match = (Router.Resolution.Match) r.resolve(HttpMethod.GET, "/user");
match.handler().handle(null, null);
assertEquals(0, first.get());
assertEquals(1, second.get());
}
@Test
void samePathWithParamAndMultipleMethodsKeepsParamsAndHandlers() throws Exception {
AtomicInteger getCalls = new AtomicInteger();
AtomicInteger putCalls = new AtomicInteger();
Router r = new Router()
.get("/user/{id}", (req, res) -> getCalls.incrementAndGet())
.put("/user/{id}", (req, res) -> putCalls.incrementAndGet());
var get = assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.GET, "/user/42"));
var put = assertInstanceOf(Router.Resolution.Match.class, r.resolve(HttpMethod.PUT, "/user/42"));
assertEquals("42", get.pathParams().get("id"));
assertEquals("42", put.pathParams().get("id"));
get.handler().handle(null, null);
put.handler().handle(null, null);
assertEquals(1, getCalls.get());
assertEquals(1, putCalls.get());
assertNotSame(get.handler(), put.handler());
}
}
@@ -0,0 +1,19 @@
package dev.coph.nextusweb.server.router.exception;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class BadRequestExceptionTest {
@Test
void messageIsCarried() {
BadRequestException e = new BadRequestException("oops");
assertEquals("oops", e.getMessage());
}
@Test
void isRuntimeException() {
assertInstanceOf(RuntimeException.class, new BadRequestException("x"));
}
}
@@ -0,0 +1,79 @@
package dev.coph.nextusweb.server.websocket;
import org.junit.jupiter.api.Test;
import java.time.Duration;
import static org.junit.jupiter.api.Assertions.*;
class WebSocketConfigTest {
@Test
void defaultsHasExpectedValues() {
WebSocketConfig c = WebSocketConfig.defaults();
assertEquals(65_536, c.maxFramePayloadLength());
assertEquals(1_048_576, c.maxAggregatedMessageSize());
assertEquals(Duration.ofSeconds(60), c.idleTimeout());
assertFalse(c.allowAnyOrigin());
assertTrue(c.allowedOrigins().isEmpty());
assertNull(c.subprotocolsCsv());
assertTrue(c.compression());
assertFalse(c.checkStartsWith());
}
@Test
void isOriginAllowedRespectsList() {
WebSocketConfig c = WebSocketConfig.builder()
.allowedOrigins("https://a", "https://b")
.build();
assertTrue(c.isOriginAllowed("https://a"));
assertTrue(c.isOriginAllowed("https://b"));
assertFalse(c.isOriginAllowed("https://c"));
assertFalse(c.isOriginAllowed(null));
}
@Test
void anyOriginAllowsEverythingExceptNullCheck() {
WebSocketConfig c = WebSocketConfig.builder().anyOrigin().build();
assertTrue(c.allowAnyOrigin());
assertTrue(c.isOriginAllowed("https://anything"));
assertTrue(c.isOriginAllowed(null));
}
@Test
void invalidFramePayloadLengthRejected() {
assertThrows(IllegalArgumentException.class,
() -> WebSocketConfig.builder().maxFramePayloadLength(0));
}
@Test
void invalidAggregatedMessageSizeRejected() {
assertThrows(IllegalArgumentException.class,
() -> WebSocketConfig.builder().maxAggregatedMessageSize(0));
}
@Test
void noIdleTimeoutSetsNull() {
WebSocketConfig c = WebSocketConfig.builder().noIdleTimeout().build();
assertNull(c.idleTimeout());
}
@Test
void subprotocolsCsvJoins() {
WebSocketConfig c = WebSocketConfig.builder().subprotocols("a", "b").build();
String csv = c.subprotocolsCsv();
assertNotNull(csv);
assertTrue(csv.contains("a"));
assertTrue(csv.contains("b"));
}
@Test
void compressionAndCheckStartsWithFlags() {
WebSocketConfig c = WebSocketConfig.builder()
.compression(false)
.checkStartsWith(true)
.build();
assertFalse(c.compression());
assertTrue(c.checkStartsWith());
}
}
@@ -0,0 +1,20 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.ChannelHandler;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class WebSocketFrameHandlerFactoryTest {
@Test
void createReturnsChannelHandler() {
ChannelHandler h = WebSocketFrameHandlerFactory.create(
new WebSocketHandler() {},
"/ws",
Map.of("a", "b"));
assertNotNull(h);
}
}
@@ -0,0 +1,134 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.DefaultChannelId;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class WebSocketGroupTest {
private EmbeddedChannel uniqueChannel() {
return new EmbeddedChannel(DefaultChannelId.newInstance());
}
private WebSocketSession session(EmbeddedChannel ch) {
return new WebSocketSession(ch, "/ws", Map.of());
}
@Test
void defaultConstructorHasAnonymousName() {
assertEquals("anonymous", new WebSocketGroup().name());
}
@Test
void namedConstructorRetainsName() {
assertEquals("chat", new WebSocketGroup("chat").name());
}
@Test
void addAndRemoveAdjustSize() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel ch = uniqueChannel();
WebSocketSession s = session(ch);
g.add(s);
assertEquals(1, g.size());
g.remove(s);
assertEquals(0, g.size());
}
@Test
void broadcastSendsTextToAll() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel a = uniqueChannel();
EmbeddedChannel b = uniqueChannel();
g.add(session(a)).add(session(b));
g.broadcast("hi");
a.runPendingTasks();
b.runPendingTasks();
Object fa = a.readOutbound();
Object fb = b.readOutbound();
TextWebSocketFrame ta = assertInstanceOf(TextWebSocketFrame.class, fa);
TextWebSocketFrame tb = assertInstanceOf(TextWebSocketFrame.class, fb);
assertEquals("hi", ta.text());
assertEquals("hi", tb.text());
ta.release();
tb.release();
}
@Test
void broadcastJsonSendsTextFrames() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel a = uniqueChannel();
g.add(session(a));
g.broadcastJson(Map.of("k", "v"));
a.runPendingTasks();
Object out = a.readOutbound();
TextWebSocketFrame frame = assertInstanceOf(TextWebSocketFrame.class, out);
assertTrue(frame.text().contains("\"k\""));
frame.release();
}
@Test
void broadcastBinarySendsBinaryFrames() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel a = uniqueChannel();
g.add(session(a));
g.broadcastBinary(new byte[]{1, 2, 3});
a.runPendingTasks();
Object out = a.readOutbound();
BinaryWebSocketFrame frame = assertInstanceOf(BinaryWebSocketFrame.class, out);
assertEquals(3, frame.content().readableBytes());
frame.release();
}
@Test
void broadcastExceptSkipsExcludedSession() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel a = uniqueChannel();
EmbeddedChannel b = uniqueChannel();
WebSocketSession sa = session(a);
WebSocketSession sb = session(b);
g.add(sa).add(sb);
g.broadcastExcept(sa, "hello");
a.runPendingTasks();
b.runPendingTasks();
assertNull(a.readOutbound());
Object out = b.readOutbound();
TextWebSocketFrame frame = assertInstanceOf(TextWebSocketFrame.class, out);
assertEquals("hello", frame.text());
frame.release();
}
@Test
void closeAllClosesUnderlyingChannels() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel a = uniqueChannel();
g.add(session(a));
g.closeAll();
a.runPendingTasks();
assertFalse(a.isActive());
}
@Test
void fluentMethodsReturnGroup() {
WebSocketGroup g = new WebSocketGroup("g");
EmbeddedChannel a = uniqueChannel();
WebSocketSession s = session(a);
assertSame(g, g.add(s));
assertSame(g, g.broadcast("x"));
assertSame(g, g.broadcastBinary(new byte[]{1}));
assertSame(g, g.broadcastJson(Map.of("a", 1)));
assertSame(g, g.broadcastExcept(null, "y"));
assertSame(g, g.remove(s));
assertSame(g, g.closeAll());
}
}
@@ -0,0 +1,23 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class WebSocketHandlerTest {
@Test
void defaultMethodsDoNotThrow() {
WebSocketHandler handler = new WebSocketHandler() {};
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession session = new WebSocketSession(ch, "/ws", Map.of());
assertDoesNotThrow(() -> handler.onOpen(session));
assertDoesNotThrow(() -> handler.onMessage(session, "msg"));
assertDoesNotThrow(() -> handler.onBinary(session, new byte[]{1}));
assertDoesNotThrow(() -> handler.onClose(session, 1000, "ok"));
assertDoesNotThrow(() -> handler.onError(session, new RuntimeException("e")));
}
}
@@ -0,0 +1,45 @@
package dev.coph.nextusweb.server.websocket;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class WebSocketRouterTest {
private final WebSocketHandler handler = new WebSocketHandler() {};
@Test
void resolvesExactPath() {
WebSocketRouter r = new WebSocketRouter().on("/ws", handler);
WebSocketRouter.Resolution res = r.resolve("/ws");
assertNotNull(res);
assertSame(handler, res.handler());
assertTrue(res.pathParams().isEmpty());
}
@Test
void returnsNullForUnknown() {
WebSocketRouter r = new WebSocketRouter().on("/ws", handler);
assertNull(r.resolve("/missing"));
}
@Test
void extractsPathParameters() {
WebSocketRouter r = new WebSocketRouter().on("/rooms/{id}", handler);
WebSocketRouter.Resolution res = r.resolve("/rooms/abc");
assertNotNull(res);
assertEquals("abc", res.pathParams().get("id"));
}
@Test
void onIsFluent() {
WebSocketRouter r = new WebSocketRouter();
assertSame(r, r.on("/x", handler));
}
@Test
void interiorNodeWithoutHandlerReturnsNull() {
WebSocketRouter r = new WebSocketRouter().on("/a/b", handler);
assertNull(r.resolve("/a"));
}
}
@@ -0,0 +1,138 @@
package dev.coph.nextusweb.server.websocket;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class WebSocketSessionTest {
private WebSocketSession session(EmbeddedChannel ch) {
return new WebSocketSession(ch, "/ws/{id}", Map.of("id", "42"));
}
@Test
void idIsAssignedAndNonNull() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
assertNotNull(s.id());
assertFalse(s.id().isEmpty());
}
@Test
void pathAndPathParamExpose() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
assertEquals("/ws/{id}", s.path());
assertEquals("42", s.pathParam("id"));
assertNull(s.pathParam("missing"));
}
@Test
void isOpenWhileChannelActive() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
assertTrue(s.isOpen());
ch.close();
assertFalse(s.isOpen());
}
@Test
void channelGetterReturnsChannel() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
assertSame(ch, s.channel());
}
@Test
void attributesSetAndRetrieve() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
s.attribute("k", "v");
assertEquals("v", s.<String>attribute("k"));
s.attribute("k", null);
assertNull(s.<String>attribute("k"));
}
@Test
void sendWritesTextFrame() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
s.send("hi");
Object out = ch.readOutbound();
TextWebSocketFrame frame = assertInstanceOf(TextWebSocketFrame.class, out);
assertEquals("hi", frame.text());
frame.release();
}
@Test
void sendJsonProducesTextFrame() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
s.sendJson(Map.of("a", "b"));
Object out = ch.readOutbound();
TextWebSocketFrame frame = assertInstanceOf(TextWebSocketFrame.class, out);
String payload = frame.content().toString(CharsetUtil.UTF_8);
assertTrue(payload.contains("\"a\""));
frame.release();
}
@Test
void sendBinaryProducesBinaryFrame() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
s.sendBinary(new byte[]{1, 2, 3});
Object out = ch.readOutbound();
BinaryWebSocketFrame frame = assertInstanceOf(BinaryWebSocketFrame.class, out);
assertEquals(3, frame.content().readableBytes());
frame.release();
}
@Test
void pingProducesPingFrame() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
s.ping();
Object out = ch.readOutbound();
assertInstanceOf(PingWebSocketFrame.class, out);
((PingWebSocketFrame) out).release();
}
@Test
void closeProducesCloseFrameAndClosesChannel() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
s.close(1001, "going-away");
Object out = ch.readOutbound();
CloseWebSocketFrame frame = assertInstanceOf(CloseWebSocketFrame.class, out);
assertEquals(1001, frame.statusCode());
assertEquals("going-away", frame.reasonText());
frame.release();
}
@Test
void sendOnInactiveChannelDoesNotThrow() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
ch.close();
assertDoesNotThrow(() -> s.send("ignored"));
assertDoesNotThrow(() -> s.sendBinary(new byte[]{1}));
assertDoesNotThrow(() -> s.ping());
assertDoesNotThrow(() -> s.sendJson(Map.of("a", 1)));
assertDoesNotThrow(() -> s.close());
}
@Test
void remoteAddressReturnsNullForUnconnectedEmbeddedChannel() {
EmbeddedChannel ch = new EmbeddedChannel();
WebSocketSession s = session(ch);
assertDoesNotThrow(s::remoteAddress);
}
}