From bcf5572aebb60098e68644badd7f1ca923c0db5c Mon Sep 17 00:00:00 2001 From: CodingPhoenixx Date: Fri, 29 May 2026 13:22:31 +0200 Subject: [PATCH] Introduce authentication framework with `AuthConfig`, `AuthGate`, and `Authenticator` classes, alongside comprehensive tests for rules, modes, and schemes. --- README.md | 148 ++++++++- build.gradle | 2 +- .../nextusweb/server/HttpRequestHandler.java | 306 +++++++++++++----- .../dev/coph/nextusweb/server/HttpServer.java | 173 ++++++---- .../nextusweb/server/auth/AuthConfig.java | 230 +++++++++++++ .../coph/nextusweb/server/auth/AuthGate.java | 69 ++++ .../nextusweb/server/auth/Authenticator.java | 128 ++++++++ .../coph/nextusweb/server/auth/Principal.java | 96 ++++++ .../coph/nextusweb/server/net/ClientIp.java | 53 +++ .../nextusweb/server/net/TrustedProxies.java | 162 ++++++++++ .../server/ratelimit/FixedWindowLimiter.java | 18 +- .../server/ratelimit/KeyResolver.java | 90 ++++-- .../server/ratelimit/LeakyBucketLimiter.java | 59 ++-- .../server/ratelimit/RateLimitConfig.java | 24 ++ .../server/ratelimit/RateLimitGate.java | 47 ++- .../server/ratelimit/RateLimiter.java | 22 ++ .../ratelimit/SlidingWindowLimiter.java | 1 + .../server/ratelimit/TokenBucketLimiter.java | 129 ++++---- .../coph/nextusweb/server/router/Request.java | 115 +++++++ .../coph/nextusweb/server/tls/TlsConfig.java | 124 +++++++ .../server/websocket/WebSocketConfig.java | 29 ++ .../websocket/WebSocketFrameHandler.java | 118 ++++++- .../WebSocketFrameHandlerFactory.java | 40 ++- .../server/websocket/WebSocketGroup.java | 6 +- .../server/websocket/WebSocketSession.java | 17 +- .../nextusweb/server/auth/AuthConfigTest.java | 126 ++++++++ .../nextusweb/server/auth/AuthGateTest.java | 147 +++++++++ .../server/auth/AuthenticatorTest.java | 119 +++++++ .../nextusweb/server/auth/PrincipalTest.java | 61 ++++ .../nextusweb/server/net/ClientIpTest.java | 56 ++++ .../server/net/TrustedProxiesTest.java | 64 ++++ .../server/ratelimit/KeyResolverTest.java | 63 ++-- .../server/ratelimit/RateLimitConfigTest.java | 26 ++ .../server/ratelimit/RateLimitGateTest.java | 14 +- .../nextusweb/server/router/RequestTest.java | 45 +++ .../nextusweb/server/tls/TlsConfigTest.java | 22 ++ .../server/websocket/WebSocketGroupTest.java | 2 +- .../websocket/WebSocketHandlerTest.java | 2 +- .../websocket/WebSocketSessionTest.java | 2 +- 39 files changed, 2629 insertions(+), 326 deletions(-) create mode 100644 src/main/java/dev/coph/nextusweb/server/auth/AuthConfig.java create mode 100644 src/main/java/dev/coph/nextusweb/server/auth/AuthGate.java create mode 100644 src/main/java/dev/coph/nextusweb/server/auth/Authenticator.java create mode 100644 src/main/java/dev/coph/nextusweb/server/auth/Principal.java create mode 100644 src/main/java/dev/coph/nextusweb/server/net/ClientIp.java create mode 100644 src/main/java/dev/coph/nextusweb/server/net/TrustedProxies.java create mode 100644 src/main/java/dev/coph/nextusweb/server/tls/TlsConfig.java create mode 100644 src/test/java/dev/coph/nextusweb/server/auth/AuthConfigTest.java create mode 100644 src/test/java/dev/coph/nextusweb/server/auth/AuthGateTest.java create mode 100644 src/test/java/dev/coph/nextusweb/server/auth/AuthenticatorTest.java create mode 100644 src/test/java/dev/coph/nextusweb/server/auth/PrincipalTest.java create mode 100644 src/test/java/dev/coph/nextusweb/server/net/ClientIpTest.java create mode 100644 src/test/java/dev/coph/nextusweb/server/net/TrustedProxiesTest.java create mode 100644 src/test/java/dev/coph/nextusweb/server/tls/TlsConfigTest.java diff --git a/README.md b/README.md index 91c6cba..6f3a835 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,16 @@ A lightweight, high-performance HTTP server library built on top of Netty. Nexus ## Features - **Netty-based** — uses epoll/kqueue/NIO automatically based on the platform -- **Virtual thread dispatch** — each request is handled on a Java 21 virtual thread +- **Virtual thread dispatch** — each request is handled on a Java virtual thread, with per-connection read backpressure and HTTP keep-alive +- **TLS / HTTPS** — enable encryption with a single `withTls(...)` call (PEM files or a custom `SslContext`) +- **Pluggable authentication** — insert an auth layer that protects selected paths; API key, cookie, HTTP Basic, bearer or any custom scheme (not tied to bearer tokens) - **Trie-based router** — supports static paths, path parameters (`{id}`), and wildcards (`*`) - **Annotation-based controllers** — define routes declaratively with `@Controller`, `@GET`, `@POST`, etc. - **Middleware chain** — attach cross-cutting logic to all routes - **CORS support** — configurable origins, methods, headers, credentials, and preflight caching -- **Rate limiting** — four algorithm implementations with per-IP, per-token, or custom key strategies -- **WebSockets** — path-routed handlers with origin validation, idle timeout, frame size limits and permessage-deflate +- **Rate limiting** — four algorithm implementations with per-IP, per-header, per-cookie, per-principal or custom key strategies, with automatic eviction of idle state +- **Spoofing-safe client IP** — `X-Forwarded-For` is honoured only behind configured trusted proxies +- **WebSockets** — path-routed handlers with origin validation, optional authentication, ordered per-connection delivery, backpressure, idle timeout, frame size limits and permessage-deflate - **JSON I/O** — built-in Jackson integration for request parsing and response serialization --- @@ -166,11 +169,17 @@ req.pathParam("id") // path parameter, e.g. from /users/{id} req.queryParam("search") // first value of ?search= req.queryParams("tag") // all values of ?tag= as List req.header("Authorization") // raw header value +req.cookie("sid") // value of a named cookie req.body() // raw body as UTF-8 String req.json() // body parsed as Jackson JsonNode req.jsonAs(MyDto.class) // body deserialized into a POJO req.method() // HttpMethod req.path() // decoded path without query string +req.clientIp() // resolved client IP (honours trusted proxies) +req.principal() // authenticated principal, or null (see Authentication) +req.isAuthenticated() // whether a principal is attached +req.attribute("k", value) // attach per-request state +req.attribute("k") // read it back ``` `json()` and `jsonAs()` throw `BadRequestException` (→ `400`) on malformed JSON. @@ -214,6 +223,115 @@ Use `CorsConfig.permissive()` for a development preset that allows any origin wi --- +## TLS / HTTPS + +Enable encryption by attaching a `TlsConfig`. The TLS handler becomes the first element of every connection's pipeline, so both HTTP and WebSocket traffic are served over TLS (HTTPS / WSS). + +```java +import dev.coph.nextusweb.server.tls.TlsConfig; + +HttpServer.builder(443, router) + .withTls(TlsConfig.fromPem( + new File("fullchain.pem"), // PEM certificate chain + new File("privkey.pem"))) // PKCS#8 private key + .start(); +``` + +| Factory | Use | +|---|---| +| `TlsConfig.fromPem(cert, key)` | PEM certificate chain + unencrypted PKCS#8 key | +| `TlsConfig.fromPem(cert, key, password)` | …with a password-protected key | +| `TlsConfig.fromPem(certStream, keyStream, password)` | Load PEM material from the classpath or another stream | +| `TlsConfig.fromSslContext(ctx)` | Full control — supply a Netty `SslContext` (custom ciphers, mutual TLS, …) | + +Any initialisation failure (missing/invalid material) is reported as an `IllegalStateException`. + +--- + +## Authentication + +The auth layer authenticates **selected paths** before they reach handlers and attaches a `Principal` to the request (visible to rate limiting, middleware and handlers). It is deliberately **not** tied to bearer tokens — choose any credential scheme. + +```java +import dev.coph.nextusweb.server.auth.*; + +// 1. An authenticator turns a credential into a Principal (or null if invalid). +Authenticator auth = Authenticator.apiKey("X-API-Key", key -> + key.equals(System.getenv("API_KEY")) ? Principal.of("service", Set.of("admin")) : null); + +// 2. Decide which paths it protects. +AuthConfig authConfig = AuthConfig.builder(auth) + .protectPrefix("/api/") // required: 401 if missing/invalid + .optional("/feed") // attach principal if present, never reject + .challenge("ApiKey realm=\"api\"") + .build(); + +HttpServer.builder(8080, router) + .withAuth(new AuthGate(authConfig)) + .start(); +``` + +In a handler: + +```java +router.get("/api/me", (req, res) -> { + Principal p = req.principal(); // never null on a protected path + if (!p.hasRole("admin")) { res.status(403); return; } + res.json(Map.of("id", p.id())); +}); +``` + +### Authenticators + +| Factory | Credential | +|---|---| +| `Authenticator.apiKey(header, validator)` | An API key in a request header (e.g. `X-API-Key`) | +| `Authenticator.cookie(name, validator)` | A session (or other) cookie | +| `Authenticator.basic(validator)` | HTTP Basic `username` / `password` | +| `Authenticator.bearer(validator)` | A bearer token (provided for completeness; never required) | +| `Authenticator.anyOf(a, b, …)` | Tries each in order, first match wins | +| Custom | Implement `Authenticator` — e.g. mutual-TLS cert, HMAC-signed request | + +`validator` returns the resolved `Principal`, or `null` for missing/invalid credentials (→ `401` on a `REQUIRED` path). A thrown exception is treated as an internal error (→ generic `500`); details are logged, never sent to the client. Rate limiting runs **before** authentication, so an unauthenticated flood is shed before reaching a (potentially expensive) authenticator. + +WebSocket upgrades on protected paths are authenticated the same way; the resolved principal is available via `session.principal()`. + +--- + +## Trusted proxies & client IP + +`req.clientIp()` and `KeyResolver.clientIp()` return a spoofing-safe client address. By default (`TrustedProxies.none()`) the transport peer address is used and `X-Forwarded-For` is ignored — a directly connected client cannot forge its IP. When running behind a reverse proxy, declare it trusted so the forwarded header is honoured: + +```java +import dev.coph.nextusweb.server.net.TrustedProxies; + +HttpServer.builder(8080, router) + .withTrustedProxies(TrustedProxies.of("10.0.0.0/8", "127.0.0.1", "::1")) + .start(); +``` + +The resolver walks `X-Forwarded-For` from right to left and returns the first hop that is **not** a trusted proxy, so forged left-most entries are ignored. Use `TrustedProxies.all()` only when the server can never be reached except through a trusted proxy. + +--- + +## Hardening & limits + +| Concern | How it's handled | +|---|---| +| Connection reuse | HTTP keep-alive is honoured; connections close only on `Connection: close` or error | +| Slow-client / Slowloris | A per-connection read timeout (`httpReadTimeout`, default 30s) closes stalled/idle connections | +| Request memory | Auto-read is disabled while a request is in flight (one buffered body per connection); `maxHttpContentLength` (default 1 MiB) caps the body, returning `413` | +| Error disclosure | Handler exceptions return a generic `500`; the detail is logged server-side, never sent to the client | + +```java +HttpServer.builder(8080, router) + .maxHttpContentLength(2 * 1024 * 1024) // 2 MiB body cap + .httpReadTimeout(Duration.ofSeconds(20)) // null/zero disables + .start(); +``` + +--- + ## Rate Limiting ### Algorithms @@ -233,8 +351,8 @@ RateLimitConfig config = RateLimitConfig.builder() .global(new TokenBucketLimiter(100, 200), KeyResolver.clientIp()) // Stricter rule for a specific path .forPath("/login", new FixedWindowLimiter(5, 60_000), KeyResolver.clientIp()) - // Rule for an entire path prefix - .forPrefix("/api/", new SlidingWindowLimiter(1000, 1000), KeyResolver.userOrIp()) + // Per-API-key rule for an entire path prefix (no bearer token required) + .forPrefix("/api/", new SlidingWindowLimiter(1000, 1000), KeyResolver.header("X-API-Key")) .build(); RateLimitGate gate = new RateLimitGate(config); @@ -246,6 +364,8 @@ HttpServer.builder(8080, router) When the limit is exceeded the server responds with `429 Too Many Requests` and a `Retry-After` header. +Per-key limiter state is evicted automatically by a background task (every 5 minutes, entries idle for >10 minutes), so a high-cardinality key (many distinct IPs/API keys) cannot grow the limiter maps without bound. Call `gate.shutdown()` when stopping the server. + ### Response headers Every response automatically includes: @@ -260,9 +380,13 @@ Every response automatically includes: | Factory | Behaviour | |---|---| -| `KeyResolver.clientIp()` | Uses `X-Forwarded-For` if present, otherwise the remote IP | -| `KeyResolver.userOrIp()` | Uses the Bearer token if present (`u:`), otherwise the client IP (`ip:`) | -| Custom lambda | `(req, remoteAddr) -> myKey(req)` | +| `KeyResolver.clientIp()` | The resolved client IP (honours trusted proxies — `X-Forwarded-For` is **not** trusted from a direct client) | +| `KeyResolver.header(name)` | Header value (e.g. an API key in `X-API-Key`); falls back to `ip:` when absent | +| `KeyResolver.cookie(name)` | Cookie value (e.g. a session id); falls back to `ip:` when absent | +| `KeyResolver.principal()` | The authenticated principal id (`p:`); falls back to `ip:` when anonymous (requires the auth layer to run for the path) | +| Custom lambda | `(req, clientIp) -> myKey(req)` — `req` is the framework `Request`, `clientIp` the resolved IP | + +> The old `KeyResolver.userOrIp()` (which trusted any client's `X-Forwarded-For` and keyed on a raw bearer token) has been removed. It allowed trivial rate-limit bypass and unbounded key growth; use `header(...)`, `cookie(...)` or `principal()` instead. --- @@ -307,7 +431,8 @@ WebSocketRouter wsRouter = new WebSocketRouter() WebSocketConfig wsConfig = WebSocketConfig.builder() .allowedOrigins("https://app.example.com") .maxFramePayloadLength(64 * 1024) // 64 KiB per frame - .maxAggregatedMessageSize(1024 * 1024) // 1 MiB upgrade body cap + .maxAggregatedMessageSize(1024 * 1024) // 1 MiB cap on a reassembled (fragmented) message + .maxQueuedMessages(1024) // per-connection backlog before backpressure .idleTimeout(Duration.ofSeconds(60)) // close idle peers .subprotocols("chat.v1") .compression(true) // permessage-deflate @@ -327,6 +452,7 @@ session.id(); // stable UUID for this connection session.path(); // matched path session.pathParam("room"); // path parameter, e.g. from /ws/rooms/{room} session.remoteAddress(); // client IP +session.principal(); // authenticated principal, or null session.attribute("userId", id); // attach state to the session session.attribute("userId"); // read it back @@ -357,7 +483,9 @@ group.broadcastExcept(sessionA, "everyone but A"); | Concern | How it's handled | |---|---| | Cross-origin upgrades | `Origin` header validated against `WebSocketConfig.allowedOrigins(...)`; mismatched origins are rejected with `403` | -| Memory exhaustion | `maxFramePayloadLength` caps a single frame; `maxAggregatedMessageSize` caps the upgrade request body | +| Authentication | When an `AuthGate` is configured, protected upgrade paths are authenticated and the principal is exposed via `session.principal()` | +| Memory exhaustion | `maxFramePayloadLength` caps a single frame; `maxQueuedMessages` (default 1024) bounds the per-connection callback backlog and pauses reads (backpressure) when exceeded | +| Message ordering | Callbacks for a single connection run **strictly in arrival order** on a per-connection serial drainer (still on virtual threads, so handlers may block) | | Idle / zombie connections | `idleTimeout` triggers a server-side close when no read **and** no write happen within the window | | User code isolation | All callbacks dispatch onto Java virtual threads, never the Netty event loop | | Subprotocol negotiation | Server advertises the configured `subprotocols(...)` list; clients that ask for an unsupported subprotocol fail the handshake | diff --git a/build.gradle b/build.gradle index 5bf5384..f63ead9 100644 --- a/build.gradle +++ b/build.gradle @@ -4,7 +4,7 @@ plugins { } group = 'dev.coph' -version = '0.0.4' +version = '0.0.5' repositories { mavenCentral() diff --git a/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java b/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java index 889b6aa..ca84182 100644 --- a/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java +++ b/src/main/java/dev/coph/nextusweb/server/HttpRequestHandler.java @@ -1,6 +1,10 @@ package dev.coph.nextusweb.server; +import dev.coph.nextusweb.server.auth.AuthGate; +import dev.coph.nextusweb.server.auth.Principal; import dev.coph.nextusweb.server.cores.CorsHandler; +import dev.coph.nextusweb.server.net.ClientIp; +import dev.coph.nextusweb.server.net.TrustedProxies; import dev.coph.nextusweb.server.ratelimit.RateLimitGate; import dev.coph.nextusweb.server.ratelimit.RateLimiter; import dev.coph.nextusweb.server.router.Request; @@ -16,12 +20,17 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.*; +import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler; import io.netty.handler.timeout.IdleStateHandler; +import java.lang.System.Logger; +import java.lang.System.Logger.Level; import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -31,16 +40,22 @@ import java.util.stream.Collectors; * The core inbound channel handler that processes every aggregated HTTP request. * *

For each request it, in order: detects and performs WebSocket upgrades (when a WebSocket - * router is configured), answers CORS preflight requests, enforces rate limits, resolves the - * route via the {@link Router}, runs middlewares and the matched handler, and finally writes the - * response with CORS and rate-limit headers applied.

+ * router is configured), answers CORS preflight requests, enforces rate limits, runs the + * authentication layer, resolves the route via the {@link Router}, runs middlewares and the + * matched handler, and finally writes the response with CORS and rate-limit headers applied.

* *

Blocking handler logic runs on a virtual-thread executor rather than on the Netty event - * loop, so handlers may perform blocking work without stalling I/O. WebSocket upgrades, by - * contrast, mutate the pipeline and are handled inline on the event loop.

+ * loop, so handlers may perform blocking work without stalling I/O. To keep memory bounded the + * connection's auto-read is disabled while a request is in flight and re-enabled once the response + * has been flushed, so at most one request per connection is buffered at a time. Persistent + * (keep-alive) connections are honoured; the connection is closed only when the client requested + * {@code Connection: close} or an unrecoverable error occurred.

*/ public final class HttpRequestHandler extends SimpleChannelInboundHandler { + /** Logger used for server-side error diagnostics (never leaked to clients). */ + private static final Logger LOG = System.getLogger(HttpRequestHandler.class.getName()); + /** Executor running one virtual thread per task, used to offload blocking handler work. */ private static final Executor VT_EXECUTOR = Executors.newVirtualThreadPerTaskExecutor(); @@ -51,45 +66,57 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler { try { @@ -134,15 +165,15 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandlerResolves the path against the WebSocket router; if no handler matches the upgrade is - * declined. Otherwise the origin is validated, the WebSocket protocol/compression/idle - * handlers and the application frame handler are inserted into the pipeline, and the request - * is re-fired so Netty performs the handshake.

+ * declined. Otherwise the origin is validated, the auth layer (if configured) authenticates + * the upgrade and may reject it, and then the WebSocket protocol/compression/idle handlers and + * the application frame handler are inserted into the pipeline and the request is re-fired so + * Netty performs the handshake. The resolved principal (if any) is handed to the session.

* * @param ctx the channel context * @param req the upgrade request * @return {@code true} if the request was consumed (handshake started or rejected), - * {@code false} if no WebSocket route matched and normal HTTP handling should - * continue + * {@code false} if no WebSocket route matched and normal HTTP handling should continue */ private boolean handleWebSocketUpgrade(ChannelHandlerContext ctx, FullHttpRequest req) { String path = new QueryStringDecoder(req.uri()).path(); @@ -151,13 +182,24 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandlerExceptions from the handler are mapped to responses: a {@link BadRequestException} - * becomes a {@code 400}, any other exception a {@code 500}. Routing misses become - * {@code 404}, and method mismatches a {@code 405} with an {@code Allow} header. CORS and - * rate-limit headers are applied to the final response in all cases.

+ *

The whole method is guarded so that some response is always produced — even an + * unexpected failure in the pre-handler stages yields a generic {@code 500} rather than a + * leaked, hung connection. Handler exceptions are mapped to responses: a + * {@link BadRequestException} becomes a {@code 400}, any other exception a {@code 500} whose + * details are logged server-side but never sent to the client. Routing misses become + * {@code 404}, method mismatches a {@code 405} with an {@code Allow} header.

* * @param ctx the channel context * @param raw the aggregated request being handled */ private void handle(ChannelHandlerContext ctx, FullHttpRequest raw) { - String origin = raw.headers().get("Origin"); + boolean keepAlive = HttpUtil.isKeepAlive(raw); + try { + String origin = raw.headers().get(HttpHeaderNames.ORIGIN); - if (cors != null && cors.isPreflight(raw.method(), raw.headers())) { - send(ctx, cors.handlePreflight(origin, raw.headers())); - return; - } - - String path = new QueryStringDecoder(raw.uri()).path(); - - RateLimiter.Result rlResult = null; - if (rateLimit != null) { - String remote = ((InetSocketAddress) ctx.channel().remoteAddress()).getAddress().getHostAddress(); - rlResult = rateLimit.check(raw, path, remote); - if (rlResult != null && !rlResult.allowed()) { - Response res = new Response().status(429).json("{\"error\":\"Too Many Requests\"}"); - RateLimitGate.applyHeaders(rlResult, res); - if (cors != null) cors.applyHeaders(origin, res); - send(ctx, res); + if (cors != null && cors.isPreflight(raw.method(), raw.headers())) { + send(ctx, cors.handlePreflight(origin, raw.headers()), keepAlive); return; } - } + String path = new QueryStringDecoder(raw.uri()).path(); + String clientIp = resolveClientIp(ctx, raw); - Router.Resolution resolution = router.resolve(raw.method(), path); + Router.Resolution resolution = router.resolve(raw.method(), path); + Map params = resolution instanceof Router.Resolution.Match m + ? m.pathParams() : Map.of(); + Request request = new Request(raw, params); + request.clientIp(clientIp); - Response res = new Response(); - - switch (resolution) { - case Router.Resolution.Match m -> { - Request request = new Request(raw, m.pathParams()); - try { - for (var mw : router.middlewares()) mw.accept(request, res); - m.handler().handle(request, res); - } catch (BadRequestException e) { - res.status(400).json("{\"error\":\"" + e.getMessage() + "\"}"); - } catch (Exception e) { - res.status(500).text("Internal Server Error: " + e.getMessage()); + // Rate limiting runs before authentication so an unauthenticated flood is shed before + // reaching the (potentially expensive) authenticator. + RateLimiter.Result rlResult = null; + if (rateLimit != null) { + rlResult = rateLimit.check(request, path, clientIp); + if (rlResult != null && !rlResult.allowed()) { + Response res = new Response().status(429).json("{\"error\":\"Too Many Requests\"}"); + RateLimitGate.applyHeaders(rlResult, res); + if (cors != null) cors.applyHeaders(origin, res); + send(ctx, res, keepAlive); + return; } } - case Router.Resolution.MethodNotAllowed mna -> { - String allow = mna.allowedMethods().stream() - .map(HttpMethod::name) - .sorted() - .collect(Collectors.joining(", ")); - res.status(405) - .header(HttpHeaderNames.ALLOW.toString(), allow) - .json("{\"error\":\"Method Not Allowed\",\"allowed\":\"" + allow + "\"}"); - } - case Router.Resolution.NotFound nf -> res.status(404).json("{\"error\":\"Not Found\"}"); - } - RateLimitGate.applyHeaders(rlResult, res); - if (cors != null) cors.applyHeaders(origin, res); - send(ctx, res); + // Authentication layer: attaches the principal on success, or short-circuits with a + // rejection response (401/500) for protected paths. + if (authGate != null) { + Response rejection = authGate.authenticate(request, path); + if (rejection != null) { + RateLimitGate.applyHeaders(rlResult, rejection); + if (cors != null) cors.applyHeaders(origin, rejection); + send(ctx, rejection, keepAlive); + return; + } + } + + Response res = new Response(); + switch (resolution) { + case Router.Resolution.Match m -> { + try { + for (var mw : router.middlewares()) mw.accept(request, res); + m.handler().handle(request, res); + } catch (BadRequestException e) { + res.status(400).json(Map.of("error", + e.getMessage() == null ? "Bad Request" : e.getMessage())); + } catch (Exception e) { + LOG.log(Level.ERROR, "Handler failed for " + raw.method() + " " + path, e); + res.status(500).json("{\"error\":\"Internal Server Error\"}"); + } + } + case Router.Resolution.MethodNotAllowed mna -> { + String allow = mna.allowedMethods().stream() + .map(HttpMethod::name) + .sorted() + .collect(Collectors.joining(", ")); + res.status(405) + .header(HttpHeaderNames.ALLOW.toString(), allow) + .json(Map.of("error", "Method Not Allowed", "allowed", allow)); + } + case Router.Resolution.NotFound nf -> res.status(404).json("{\"error\":\"Not Found\"}"); + } + + RateLimitGate.applyHeaders(rlResult, res); + if (cors != null) cors.applyHeaders(origin, res); + send(ctx, res, keepAlive); + } catch (Throwable t) { + // Last-resort guard: anything escaping the stages above must still produce a response, + // otherwise the connection would hang with auto-read disabled. Close it to be safe. + LOG.log(Level.ERROR, "Unexpected failure while handling request", t); + try { + send(ctx, new Response().status(500).json("{\"error\":\"Internal Server Error\"}"), false); + } catch (Throwable ignored) { + ctx.close(); + } + } + } + + /** + * Resolves the effective client IP for a request, honouring the configured trusted proxies. + * + * @param ctx the channel context + * @param raw the request (for the forwarded-for header) + * @return the resolved client IP, or {@code "unknown"} if the peer address is unavailable + */ + private String resolveClientIp(ChannelHandlerContext ctx, FullHttpRequest raw) { + SocketAddress addr = ctx.channel().remoteAddress(); + String socketIp = (addr instanceof InetSocketAddress isa && isa.getAddress() != null) + ? isa.getAddress().getHostAddress() : "unknown"; + String forwarded = raw.headers().get(ClientIp.FORWARDED_FOR_HEADER); + return ClientIp.resolve(socketIp, forwarded, trustedProxies); } /** * Converts the framework {@link Response} into a Netty {@link FullHttpResponse}, sets the - * {@code Content-Length}, writes it and closes the connection afterwards. + * {@code Content-Length} and {@code Connection} headers and writes it. For keep-alive + * connections the connection is kept open and reading resumes for the next request; otherwise + * the connection is closed after the write. * - * @param ctx the channel context - * @param res the response to send + * @param ctx the channel context + * @param res the response to send + * @param keepAlive whether the client requested a persistent connection */ - private void send(ChannelHandlerContext ctx, Response res) { + private void send(ChannelHandlerContext ctx, Response res, boolean keepAlive) { var nettyRes = new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(res.status()), @@ -273,11 +376,44 @@ public final class HttpRequestHandler extends SimpleChannelInboundHandler resumeReading(ctx)); + } else { + nettyRes.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); + ctx.writeAndFlush(nettyRes).addListener(ChannelFutureListener.CLOSE); + } } /** - * Closes the channel on any unhandled pipeline exception. + * Re-enables auto-read and requests the next message, resuming intake on a persistent + * connection after its in-flight request has been answered. + * + * @param ctx the channel context + */ + private static void resumeReading(ChannelHandlerContext ctx) { + if (ctx.channel().isActive()) { + ctx.channel().config().setAutoRead(true); + ctx.read(); + } + } + + /** + * Writes an empty-bodied response with the given status and closes the connection, used for + * rejected WebSocket upgrades. + * + * @param ctx the channel context + * @param status the HTTP status to send + */ + private static void sendStatusAndClose(ChannelHandlerContext ctx, HttpResponseStatus status) { + FullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status); + res.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE); + } + + /** + * Closes the channel on any unhandled pipeline exception (including read timeouts). * * @param ctx the channel context * @param cause the exception that propagated up the pipeline diff --git a/src/main/java/dev/coph/nextusweb/server/HttpServer.java b/src/main/java/dev/coph/nextusweb/server/HttpServer.java index 869ec65..1b131e9 100644 --- a/src/main/java/dev/coph/nextusweb/server/HttpServer.java +++ b/src/main/java/dev/coph/nextusweb/server/HttpServer.java @@ -1,8 +1,11 @@ package dev.coph.nextusweb.server; +import dev.coph.nextusweb.server.auth.AuthGate; import dev.coph.nextusweb.server.cores.CorsHandler; +import dev.coph.nextusweb.server.net.TrustedProxies; import dev.coph.nextusweb.server.ratelimit.RateLimitGate; import dev.coph.nextusweb.server.router.Router; +import dev.coph.nextusweb.server.tls.TlsConfig; import dev.coph.nextusweb.server.websocket.WebSocketConfig; import dev.coph.nextusweb.server.websocket.WebSocketRouter; import io.netty.bootstrap.ServerBootstrap; @@ -18,44 +21,59 @@ import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.timeout.ReadTimeoutHandler; + +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.TimeUnit; /** * Bootstraps and runs the Netty-based HTTP (and optionally WebSocket) server. * *

The class doubles as a small fluent builder: {@link #builder(int, Router)} creates an - * instance bound to a port and {@link Router}, and the {@code withXxx} methods attach optional - * features (CORS, rate limiting, WebSockets) before {@link #start()} launches the server.

+ * instance bound to a port and {@link Router}, and the {@code withXxx}/setter methods attach + * optional features (TLS, CORS, rate limiting, authentication, WebSockets, trusted proxies, + * limits and timeouts) before {@link #start()} launches the server.

* *

At start-up it selects the most efficient available transport — {@code epoll} on * Linux, {@code kqueue} on macOS/BSD, or the portable NIO transport otherwise — and wires - * up the Netty channel pipeline (codec, aggregator and the {@link HttpRequestHandler}). The - * {@link #start()} call blocks until the server channel is closed.

+ * up the Netty channel pipeline (optional TLS, a read timeout that defeats slow-client attacks, + * the HTTP codec, aggregator and the {@link HttpRequestHandler}). The {@link #start()} call + * blocks until the server channel is closed.

*/ public final class HttpServer { + /** Default cap on aggregated HTTP request bodies: 1 MiB. */ + private static final int DEFAULT_MAX_HTTP_CONTENT_LENGTH = 1_048_576; + /** Default per-connection read timeout that reaps slow/idle clients. */ + private static final Duration DEFAULT_HTTP_READ_TIMEOUT = Duration.ofSeconds(30); + /** TCP port the server binds to. */ private final int port; /** Router resolving requests to handlers. */ private final Router router; + /** Optional TLS configuration; {@code null} serves plain HTTP. */ + private TlsConfig tls; /** Optional CORS handler; {@code null} disables CORS handling. */ private CorsHandler cors; /** Optional rate-limit gate; {@code null} disables rate limiting. */ private RateLimitGate gate; + /** Optional authentication gate; {@code null} disables the auth layer. */ + private AuthGate authGate; /** Optional WebSocket router; {@code null} disables WebSocket support. */ private WebSocketRouter wsRouter; /** WebSocket configuration; only used when {@link #wsRouter} is set. */ private WebSocketConfig wsConfig; + /** Trusted-proxy policy for resolving the client IP; never {@code null}. */ + private TrustedProxies trustedProxies = TrustedProxies.none(); + /** Maximum aggregated HTTP request body size in bytes. */ + private int maxHttpContentLength = DEFAULT_MAX_HTTP_CONTENT_LENGTH; + /** Per-connection HTTP read timeout; {@code null} or non-positive disables it. */ + private Duration httpReadTimeout = DEFAULT_HTTP_READ_TIMEOUT; - /** - * Creates a server bound to a port and router. Use {@link #builder(int, Router)} instead of - * calling this directly. - * - * @param port the TCP port to bind - * @param router the router resolving requests - */ private HttpServer(int port, Router router) { this.port = port; - this.router = router; + this.router = Objects.requireNonNull(router, "router"); } /** @@ -69,6 +87,18 @@ public final class HttpServer { return new HttpServer(port, router); } + /** + * Enables TLS (HTTPS / WSS). The TLS handler becomes the first element of every connection's + * pipeline. + * + * @param tls the TLS configuration + * @return this instance, for fluent chaining + */ + public HttpServer withTls(TlsConfig tls) { + this.tls = tls; + return this; + } + /** * Attaches a CORS handler that decorates responses and answers preflight requests. * @@ -91,6 +121,57 @@ public final class HttpServer { return this; } + /** + * Attaches an authentication gate that authenticates configured requests before they reach + * handlers and attaches the resolved principal to the request. + * + * @param authGate the auth gate to use + * @return this instance, for fluent chaining + */ + public HttpServer withAuth(AuthGate authGate) { + this.authGate = authGate; + return this; + } + + /** + * Configures which transport peers are trusted reverse proxies, controlling whether + * {@code X-Forwarded-For} is honoured when resolving the client IP. Defaults to + * {@link TrustedProxies#none()} (forwarded headers ignored). + * + * @param trustedProxies the trusted-proxy policy + * @return this instance, for fluent chaining + */ + public HttpServer withTrustedProxies(TrustedProxies trustedProxies) { + this.trustedProxies = Objects.requireNonNull(trustedProxies, "trustedProxies"); + return this; + } + + /** + * Sets the maximum aggregated HTTP request body size in bytes. Requests exceeding it are + * rejected by the aggregator with {@code 413 Request Entity Too Large}. + * + * @param bytes the limit in bytes; must be positive + * @return this instance, for fluent chaining + */ + public HttpServer maxHttpContentLength(int bytes) { + if (bytes <= 0) throw new IllegalArgumentException("maxHttpContentLength must be > 0"); + this.maxHttpContentLength = bytes; + return this; + } + + /** + * Sets the per-connection HTTP read timeout. A connection that sends no data for this long is + * closed, which both reaps idle keep-alive connections and defeats slow-client (Slowloris) + * attacks. Pass {@code null} or a non-positive duration to disable it. + * + * @param timeout the read timeout, or {@code null}/non-positive to disable + * @return this instance, for fluent chaining + */ + public HttpServer httpReadTimeout(Duration timeout) { + this.httpReadTimeout = timeout; + return this; + } + /** * Enables WebSocket support with default configuration. * @@ -116,46 +197,12 @@ public final class HttpServer { } /** - * Starts the server using the configuration accumulated on this instance and blocks until - * the server channel closes. + * Starts the server using the configuration accumulated on this instance and blocks until the + * server channel closes. * * @throws InterruptedException if the binding or close-future wait is interrupted */ public void start() throws InterruptedException { - start(port, router, cors, gate, wsRouter, wsConfig); - } - - /** - * Starts a server without WebSocket support. Convenience overload of - * {@link #start(int, Router, CorsHandler, RateLimitGate, WebSocketRouter, WebSocketConfig)}. - * - * @param port the TCP port to bind - * @param router the router resolving requests - * @param cors the CORS handler, or {@code null} to disable CORS - * @param gate the rate-limit gate, or {@code null} to disable rate limiting - * @throws InterruptedException if the binding or close-future wait is interrupted - */ - public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate) - throws InterruptedException { - start(port, router, cors, gate, null, null); - } - - /** - * Starts the server, selecting the best transport for the platform, configuring the Netty - * channel pipeline and binding the port. The call blocks until the server channel is closed, - * after which the event-loop groups are shut down gracefully. - * - * @param port the TCP port to bind - * @param router the router resolving requests - * @param cors the CORS handler, or {@code null} to disable CORS - * @param gate the rate-limit gate, or {@code null} to disable rate limiting - * @param wsRouter the WebSocket router, or {@code null} to disable WebSocket support - * @param wsConfig the WebSocket configuration, used only when {@code wsRouter} is non-null - * @throws InterruptedException if the binding or close-future wait is interrupted - */ - public static void start(int port, Router router, CorsHandler cors, RateLimitGate gate, - WebSocketRouter wsRouter, WebSocketConfig wsConfig) - throws InterruptedException { EventLoopGroup boss, worker; Class channelClass; @@ -173,9 +220,17 @@ public final class HttpServer { channelClass = NioServerSocketChannel.class; } - int maxAggregated = wsConfig != null - ? Math.max(1024 * 1024, wsConfig.maxAggregatedMessageSize()) - : 1024 * 1024; + // Capture configuration into effectively-final locals for the channel initializer. + final TlsConfig tlsCfg = this.tls; + final CorsHandler corsHandler = this.cors; + final RateLimitGate rateLimitGate = this.gate; + final AuthGate auth = this.authGate; + final WebSocketRouter websocketRouter = this.wsRouter; + final WebSocketConfig websocketConfig = this.wsConfig; + final TrustedProxies proxies = this.trustedProxies; + final int maxContent = this.maxHttpContentLength; + final long readTimeoutSeconds = (httpReadTimeout != null && !httpReadTimeout.isZero() + && !httpReadTimeout.isNegative()) ? Math.max(1, httpReadTimeout.toSeconds()) : 0; try { new ServerBootstrap() @@ -187,10 +242,18 @@ public final class HttpServer { .childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel ch) { - ch.pipeline() - .addLast(new HttpServerCodec()) - .addLast(new HttpObjectAggregator(maxAggregated)) - .addLast(new HttpRequestHandler(router, cors, gate, wsRouter, wsConfig)); + ChannelPipeline pipeline = ch.pipeline(); + if (tlsCfg != null) { + pipeline.addLast("ssl", tlsCfg.newHandler(ch.alloc())); + } + if (readTimeoutSeconds > 0) { + pipeline.addLast("read-timeout", + new ReadTimeoutHandler(readTimeoutSeconds, TimeUnit.SECONDS)); + } + pipeline.addLast(new HttpServerCodec()) + .addLast(new HttpObjectAggregator(maxContent)) + .addLast(new HttpRequestHandler(router, corsHandler, rateLimitGate, + auth, proxies, websocketRouter, websocketConfig)); } }) .bind(port).sync().channel().closeFuture().sync(); diff --git a/src/main/java/dev/coph/nextusweb/server/auth/AuthConfig.java b/src/main/java/dev/coph/nextusweb/server/auth/AuthConfig.java new file mode 100644 index 0000000..d19f8ea --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/auth/AuthConfig.java @@ -0,0 +1,230 @@ +package dev.coph.nextusweb.server.auth; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Immutable mapping from request paths to the authentication requirement that applies to them, + * consumed by {@link AuthGate}. + * + *

The model mirrors {@code RateLimitConfig}: a path resolves to at most one rule, chosen with + * the precedence

+ *
    + *
  1. exact-path rules (matched by exact equality);
  2. + *
  3. prefix rules (longest matching prefix wins);
  4. + *
  5. an optional global rule applied to every other path.
  6. + *
+ * + *

Each rule pairs an {@link Authenticator} with a {@link Mode}: {@link Mode#REQUIRED} rejects + * the request with {@code 401} when authentication fails, while {@link Mode#OPTIONAL} attaches the + * principal when present but never rejects (useful for endpoints that behave differently for + * signed-in callers). Paths with no matching rule are not authenticated at all.

+ */ +public final class AuthConfig { + + /** Whether a matched rule rejects unauthenticated requests or merely annotates them. */ + public enum Mode { + /** Authentication is mandatory; failure yields {@code 401 Unauthorized}. */ + REQUIRED, + /** Authentication is best-effort; the principal is attached if present, never rejected. */ + OPTIONAL + } + + /** Rule applied to every path with no more specific match, or {@code null} if none. */ + private final Rule globalRule; + /** Rules matched by exact path equality. */ + private final Map exactPathRules; + /** Prefix rules, pre-sorted longest-prefix-first. */ + private final List prefixRules; + /** Optional {@code WWW-Authenticate} challenge sent with {@code 401} responses. */ + private final String challenge; + + private AuthConfig(Builder b) { + this.globalRule = b.globalRule; + this.exactPathRules = Map.copyOf(b.exactPathRules); + this.prefixRules = b.prefixRules.stream() + .sorted((a, c) -> Integer.compare(c.prefix.length(), a.prefix.length())) + .toList(); + this.challenge = b.challenge; + } + + /** + * Creates a builder using the given authenticator as the default for rules that do not specify + * their own. + * + * @param defaultAuthenticator the authenticator used by rules added without an explicit one + * @return a fresh builder + */ + public static Builder builder(Authenticator defaultAuthenticator) { + return new Builder(defaultAuthenticator); + } + + /** + * Returns the rule that applies to the given path, or {@code null} if the path requires no + * authentication. + * + * @param path the request path + * @return the applicable rule, or {@code null} + */ + public Rule ruleFor(String path) { + Rule exact = exactPathRules.get(path); + if (exact != null) return exact; + for (PrefixRule pr : prefixRules) { + if (path.startsWith(pr.prefix)) return pr.rule; + } + return globalRule; + } + + /** + * Returns the configured {@code WWW-Authenticate} challenge sent with {@code 401} responses. + * + * @return the challenge string, or {@code null} if none is configured + */ + public String challenge() { + return challenge; + } + + /** + * An authentication rule: which authenticator to use and whether it is mandatory. + * + * @param authenticator the authenticator to apply + * @param mode whether authentication is required or optional + */ + public record Rule(Authenticator authenticator, Mode mode) { + } + + /** Internal pairing of a path prefix with its rule. */ + private record PrefixRule(String prefix, Rule rule) { + } + + /** + * Fluent builder for {@link AuthConfig}. + */ + public static final class Builder { + private final Authenticator defaultAuthenticator; + private final Map exactPathRules = new HashMap<>(); + private final List prefixRules = new ArrayList<>(); + private Rule globalRule; + private String challenge; + + private Builder(Authenticator defaultAuthenticator) { + this.defaultAuthenticator = Objects.requireNonNull(defaultAuthenticator, + "defaultAuthenticator"); + } + + /** + * Requires authentication for an exact path, using the default authenticator. + * + * @param path the exact path to protect + * @return this builder + */ + public Builder protect(String path) { + return protect(path, defaultAuthenticator); + } + + /** + * Requires authentication for an exact path, using a specific authenticator. + * + * @param path the exact path to protect + * @param authenticator the authenticator to apply on that path + * @return this builder + */ + public Builder protect(String path, Authenticator authenticator) { + exactPathRules.put(path, new Rule(authenticator, Mode.REQUIRED)); + return this; + } + + /** + * Requires authentication for every path starting with the given prefix, using the default + * authenticator. + * + * @param prefix the path prefix to protect + * @return this builder + */ + public Builder protectPrefix(String prefix) { + return protectPrefix(prefix, defaultAuthenticator); + } + + /** + * Requires authentication for every path starting with the given prefix, using a specific + * authenticator. + * + * @param prefix the path prefix to protect + * @param authenticator the authenticator to apply on that prefix + * @return this builder + */ + public Builder protectPrefix(String prefix, Authenticator authenticator) { + prefixRules.add(new PrefixRule(prefix, new Rule(authenticator, Mode.REQUIRED))); + return this; + } + + /** + * Optionally authenticates an exact path: the principal is attached when credentials are + * valid, but the request is never rejected. + * + * @param path the exact path + * @return this builder + */ + public Builder optional(String path) { + exactPathRules.put(path, new Rule(defaultAuthenticator, Mode.OPTIONAL)); + return this; + } + + /** + * Optionally authenticates a path prefix (attach-if-present, never reject). + * + * @param prefix the path prefix + * @return this builder + */ + public Builder optionalPrefix(String prefix) { + prefixRules.add(new PrefixRule(prefix, new Rule(defaultAuthenticator, Mode.OPTIONAL))); + return this; + } + + /** + * Optionally authenticates every path that has no more specific rule (attach-if-present, + * never reject). Handy for attaching a principal everywhere while protecting only selected + * routes with {@link #protect(String)}. + * + * @return this builder + */ + public Builder optionalEverywhere() { + this.globalRule = new Rule(defaultAuthenticator, Mode.OPTIONAL); + return this; + } + + /** + * Requires authentication for every path that has no more specific rule. + * + * @return this builder + */ + public Builder requireEverywhere() { + this.globalRule = new Rule(defaultAuthenticator, Mode.REQUIRED); + return this; + } + + /** + * Sets the {@code WWW-Authenticate} challenge header value sent with {@code 401} + * responses (for example {@code "Basic realm=\"api\""}). + * + * @param challenge the challenge value + * @return this builder + */ + public Builder challenge(String challenge) { + this.challenge = challenge; + return this; + } + + /** + * Builds the immutable configuration. + * + * @return the configured instance + */ + public AuthConfig build() { + return new AuthConfig(this); + } + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/auth/AuthGate.java b/src/main/java/dev/coph/nextusweb/server/auth/AuthGate.java new file mode 100644 index 0000000..6d6130e --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/auth/AuthGate.java @@ -0,0 +1,69 @@ +package dev.coph.nextusweb.server.auth; + +import dev.coph.nextusweb.server.router.Request; +import dev.coph.nextusweb.server.router.Response; + +/** + * Request-pipeline entry point for the authentication layer. Given a request and its path it + * consults the {@link AuthConfig}, runs the applicable {@link Authenticator}, and decides whether + * the request may proceed. + * + *

On success the resolved {@link Principal} is attached to the {@link Request} (visible to + * downstream rate limiting, middlewares and handlers via {@link Request#principal()}). The gate is + * stateless and may be shared across all connections.

+ */ +public final class AuthGate { + + /** The policy this gate enforces. */ + private final AuthConfig config; + + /** + * Creates a gate for the given configuration. + * + * @param config the authentication policy to enforce + */ + public AuthGate(AuthConfig config) { + this.config = config; + } + + /** + * Applies the authentication policy to a request. + * + *

Returns {@code null} when the request may proceed — either because the path requires no + * authentication, because a principal was resolved (and has been attached to {@code req}), or + * because the path is only optionally authenticated. Returns a populated rejection response + * ({@code 401} when a required credential is missing/invalid, {@code 500} when the + * authenticator itself errors) that the caller should send instead of invoking the handler. + * Error details are never leaked to the client.

+ * + * @param req the incoming request + * @param path the resolved request path + * @return {@code null} to proceed, or the rejection response to send + */ + public Response authenticate(Request req, String path) { + AuthConfig.Rule rule = config.ruleFor(path); + if (rule == null) return null; + + Principal principal; + try { + principal = rule.authenticator().authenticate(req); + } catch (Exception e) { + return new Response().status(500).json("{\"error\":\"Authentication error\"}"); + } + + if (principal != null) { + req.principal(principal); + return null; + } + + if (rule.mode() == AuthConfig.Mode.OPTIONAL) { + return null; + } + + Response res = new Response().status(401).json("{\"error\":\"Unauthorized\"}"); + if (config.challenge() != null) { + res.header("WWW-Authenticate", config.challenge()); + } + return res; + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/auth/Authenticator.java b/src/main/java/dev/coph/nextusweb/server/auth/Authenticator.java new file mode 100644 index 0000000..9b93f8e --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/auth/Authenticator.java @@ -0,0 +1,128 @@ +package dev.coph.nextusweb.server.auth; + +import dev.coph.nextusweb.server.router.Request; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * Establishes the {@link Principal} behind a request from whatever credential the application + * uses. The framework is deliberately scheme-agnostic: ready-made factories cover API keys, + * session cookies, HTTP Basic and bearer tokens, and {@code Authenticator} is a functional + * interface so any custom scheme (mutual-TLS client certs, HMAC-signed requests, opaque session + * stores, ...) can be plugged in directly. + * + *

An authenticator returns the resolved {@link Principal} on success, or {@code null} when the + * request carries no usable or no valid credential. It must not throw for the ordinary + * "no/invalid credential" case (return {@code null} instead); a thrown exception is treated by + * {@link AuthGate} as an internal error, not an authentication failure.

+ */ +@FunctionalInterface +public interface Authenticator { + + /** + * Attempts to authenticate a request. + * + * @param request the incoming request + * @return the authenticated principal, or {@code null} if the request is unauthenticated + * @throws Exception if an unexpected error occurs while validating the credential (treated as + * an internal error, not an authentication failure) + */ + Principal authenticate(Request request) throws Exception; + + /** + * Authenticates via an API key carried in a request header (for example {@code X-API-Key}). + * The validator maps a presented key to a {@link Principal}, or to {@code null} if the key is + * unknown/revoked. Prefer a constant-time lookup in the validator to avoid timing oracles. + * + * @param headerName the header carrying the API key + * @param validator maps a presented key to a principal, or {@code null} if invalid + * @return an API-key authenticator + */ + static Authenticator apiKey(String headerName, Function validator) { + return request -> { + String key = request.header(headerName); + if (key == null || key.isEmpty()) return null; + return validator.apply(key); + }; + } + + /** + * Authenticates via a session (or other) cookie. The validator maps a presented cookie value + * to a {@link Principal}, or to {@code null} if the session is unknown/expired. + * + * @param cookieName the cookie carrying the credential + * @param validator maps a presented cookie value to a principal, or {@code null} if invalid + * @return a cookie authenticator + */ + static Authenticator cookie(String cookieName, Function validator) { + return request -> { + String value = request.cookie(cookieName); + if (value == null || value.isEmpty()) return null; + return validator.apply(value); + }; + } + + /** + * Authenticates via HTTP Basic credentials from the {@code Authorization} header. The + * validator receives the decoded username and password and returns a {@link Principal} or + * {@code null}. Malformed headers resolve to {@code null} rather than an error. + * + * @param validator maps {@code (username, password)} to a principal, or {@code null} if invalid + * @return a Basic-auth authenticator + */ + static Authenticator basic(BiFunction validator) { + return request -> { + String header = request.header("Authorization"); + if (header == null || !header.regionMatches(true, 0, "Basic ", 0, 6)) return null; + String encoded = header.substring(6).trim(); + byte[] decoded; + try { + decoded = Base64.getDecoder().decode(encoded); + } catch (IllegalArgumentException e) { + return null; + } + String creds = new String(decoded, StandardCharsets.UTF_8); + int colon = creds.indexOf(':'); + if (colon < 0) return null; + return validator.apply(creds.substring(0, colon), creds.substring(colon + 1)); + }; + } + + /** + * Authenticates via a bearer token from the {@code Authorization} header. Provided for + * completeness; the rest of the framework never requires bearer tokens. + * + * @param validator maps a presented token to a principal, or {@code null} if invalid + * @return a bearer-token authenticator + */ + static Authenticator bearer(Function validator) { + return request -> { + String header = request.header("Authorization"); + if (header == null || !header.regionMatches(true, 0, "Bearer ", 0, 7)) return null; + String token = header.substring(7).trim(); + if (token.isEmpty()) return null; + return validator.apply(token); + }; + } + + /** + * Combines several authenticators, trying each in order and returning the first principal one + * of them produces. Lets an endpoint accept, say, either an API key or a session cookie. + * + * @param authenticators the authenticators to try, in order + * @return a composite authenticator + */ + static Authenticator anyOf(Authenticator... authenticators) { + Authenticator[] copy = authenticators.clone(); + return request -> { + for (Authenticator a : copy) { + Principal p = a.authenticate(request); + if (p != null) return p; + } + return null; + }; + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/auth/Principal.java b/src/main/java/dev/coph/nextusweb/server/auth/Principal.java new file mode 100644 index 0000000..5f3ec51 --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/auth/Principal.java @@ -0,0 +1,96 @@ +package dev.coph.nextusweb.server.auth; + +import java.util.Map; +import java.util.Set; + +/** + * The authenticated identity attached to a request by the {@link AuthGate auth layer}. + * + *

The framework deliberately does not prescribe how an identity is established — + * it may come from an API key, a session cookie, HTTP Basic credentials, a mutual-TLS client + * certificate or any custom scheme implemented by an {@link Authenticator}. All the rest of the + * pipeline needs is a stable {@link #id() identifier} (used, for example, as a rate-limit key + * via {@link dev.coph.nextusweb.server.ratelimit.KeyResolver#principal()}) plus optional + * {@link #roles() roles} for authorization and free-form {@link #claims() claims}.

+ * + *

A ready-made immutable implementation is available via {@link #of(String)} and + * {@link #of(String, Set)}; applications may also implement this interface to carry richer + * domain objects.

+ */ +public interface Principal { + + /** + * Returns the stable, unique identifier of this principal (for example a user id, an account + * name or an API-key id). Used wherever the identity must be reduced to a single string, such + * as principal-based rate limiting. + * + * @return the principal identifier; never {@code null} + */ + String id(); + + /** + * Returns the roles granted to this principal, for coarse-grained authorization checks. + * + * @return the (possibly empty) set of roles; never {@code null} + */ + default Set roles() { + return Set.of(); + } + + /** + * Indicates whether this principal holds the given role. + * + * @param role the role to test for + * @return {@code true} if {@link #roles()} contains {@code role} + */ + default boolean hasRole(String role) { + return roles().contains(role); + } + + /** + * Returns arbitrary additional attributes describing this principal (for example token + * scopes, an email address or tenant information). + * + * @return the (possibly empty) claim map; never {@code null} + */ + default Map claims() { + return Map.of(); + } + + /** + * Creates a simple immutable principal with no roles. + * + * @param id the principal identifier + * @return a principal carrying only the given id + */ + static Principal of(String id) { + return of(id, Set.of()); + } + + /** + * Creates a simple immutable principal with the given id and roles. + * + * @param id the principal identifier + * @param roles the roles granted to the principal + * @return an immutable principal + */ + static Principal of(String id, Set roles) { + Set copy = Set.copyOf(roles); + return new Principal() { + @Override + public String id() { + return id; + } + + @Override + public Set roles() { + return copy; + } + + @Override + public String toString() { + return "Principal[id=" + id + ", roles=" + copy + "]"; + } + }; + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/net/ClientIp.java b/src/main/java/dev/coph/nextusweb/server/net/ClientIp.java new file mode 100644 index 0000000..17d3bea --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/net/ClientIp.java @@ -0,0 +1,53 @@ +package dev.coph.nextusweb.server.net; + +/** + * Resolves the effective client IP of a request from the transport-level peer address and an + * optional {@code X-Forwarded-For} header, according to a {@link TrustedProxies} policy. + * + *

This is the single place where forwarded headers are interpreted, so the spoofing-resistant + * logic lives in one spot and is reused by rate limiting, the auth layer and logging.

+ */ +public final class ClientIp { + + /** The de-facto standard header proxies use to record the originating client chain. */ + public static final String FORWARDED_FOR_HEADER = "X-Forwarded-For"; + + private ClientIp() { + } + + /** + * Resolves the client IP. + * + *

If the immediate peer is not a trusted proxy (or no forwarded header is present), the + * transport-level {@code socketIp} is returned unchanged — a directly connected client cannot + * influence its own apparent address. Otherwise the comma-separated forwarded chain is walked + * from right to left and the first address that is not itself a trusted proxy is + * returned: that is the closest hop the trusted infrastructure actually observed and which the + * real client cannot forge. If every listed hop is trusted, the left-most entry is used.

+ * + * @param socketIp the transport-level peer IP (never {@code null} in practice) + * @param forwardedForHeader the {@code X-Forwarded-For} header value, may be {@code null} + * @param trusted the trusted-proxy policy + * @return the resolved client IP + */ + public static String resolve(String socketIp, String forwardedForHeader, TrustedProxies trusted) { + if (forwardedForHeader == null || forwardedForHeader.isBlank() + || !trusted.isTrusted(socketIp)) { + return socketIp; + } + + String[] hops = forwardedForHeader.split(","); + for (int i = hops.length - 1; i >= 0; i--) { + String hop = hops[i].trim(); + if (hop.isEmpty()) continue; + if (!trusted.isTrusted(hop)) { + return hop; + } + } + + // Every hop in the chain is a trusted proxy; fall back to the originating (left-most) + // entry, or the socket address if the header was effectively empty. + String first = hops[0].trim(); + return first.isEmpty() ? socketIp : first; + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/net/TrustedProxies.java b/src/main/java/dev/coph/nextusweb/server/net/TrustedProxies.java new file mode 100644 index 0000000..25bb16c --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/net/TrustedProxies.java @@ -0,0 +1,162 @@ +package dev.coph.nextusweb.server.net; + +import io.netty.util.NetUtil; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Describes which transport-level peers are trusted reverse proxies, and is the basis for safely + * honouring {@code X-Forwarded-For} (and similar) headers. + * + *

Forwarded-for headers are client-supplied and therefore trivially spoofable: a client that + * talks to the server directly can claim any IP it likes. Trusting them unconditionally lets an + * attacker forge a fresh client IP on every request, which defeats IP-based rate limiting and + * pollutes logs. To avoid that, {@link ClientIp} only consults the forwarded header when the + * immediate peer is a trusted proxy, and then walks the header from right to left skipping + * further trusted hops — so the value returned is the address of the first untrusted + * hop, which cannot be spoofed by the real client.

+ * + *

Both IPv4 and IPv6 CIDR ranges are supported. Instances are immutable and thread-safe.

+ */ +public final class TrustedProxies { + + /** Shared instance that trusts no peer; forwarded headers are always ignored. */ + private static final TrustedProxies NONE = new TrustedProxies(List.of(), false); + /** Shared instance that trusts every peer; forwarded headers are always honoured. */ + private static final TrustedProxies ALL = new TrustedProxies(List.of(), true); + + /** Parsed CIDR ranges of trusted proxies. */ + private final List cidrs; + /** When {@code true}, every peer is trusted regardless of {@link #cidrs}. */ + private final boolean trustAll; + + private TrustedProxies(List cidrs, boolean trustAll) { + this.cidrs = cidrs; + this.trustAll = trustAll; + } + + /** + * Returns a policy that trusts no peer. Forwarded headers are ignored and the transport-level + * peer address is always used. This is the safe default for servers exposed directly to + * clients. + * + * @return a never-trust policy + */ + public static TrustedProxies none() { + return NONE; + } + + /** + * Returns a policy that trusts every peer and therefore always honours the forwarded header. + * + *

Only use this when the server can never be reached except through a trusted + * proxy that overwrites the forwarded header (for example a private network behind a load + * balancer), otherwise it reintroduces the spoofing problem.

+ * + * @return a trust-all policy + */ + public static TrustedProxies all() { + return ALL; + } + + /** + * Returns a policy that trusts peers whose address falls inside any of the given CIDR ranges. + * A bare address (without {@code /prefix}) is treated as a single host ({@code /32} for IPv4, + * {@code /128} for IPv6). + * + * @param cidrs the trusted ranges, e.g. {@code "10.0.0.0/8"}, {@code "127.0.0.1"}, + * {@code "::1"}, {@code "fd00::/8"} + * @return a policy trusting the given ranges + * @throws IllegalArgumentException if a range cannot be parsed + */ + public static TrustedProxies of(String... cidrs) { + List parsed = new ArrayList<>(cidrs.length); + for (String c : cidrs) parsed.add(Cidr.parse(c)); + return new TrustedProxies(List.copyOf(parsed), false); + } + + /** + * Tests whether the given IP address belongs to a trusted proxy. + * + * @param ip the literal IP address to test (no DNS resolution is performed) + * @return {@code true} if the address is trusted + */ + public boolean isTrusted(String ip) { + if (trustAll) return true; + if (ip == null || cidrs.isEmpty()) return false; + byte[] addr = NetUtil.createByteArrayFromIpAddressString(ip); + if (addr == null) return false; + for (Cidr c : cidrs) { + if (c.contains(addr)) return true; + } + return false; + } + + /** + * An IP range expressed as a base address and a prefix length, matching either IPv4 or IPv6. + * + * @param base the network base address bytes (4 for IPv4, 16 for IPv6) + * @param prefixBits the number of leading bits that must match + */ + private record Cidr(byte[] base, int prefixBits) { + + static Cidr parse(String spec) { + if (spec == null || spec.isBlank()) { + throw new IllegalArgumentException("Empty CIDR specification"); + } + String trimmed = spec.trim(); + int slash = trimmed.indexOf('/'); + String ipPart = slash >= 0 ? trimmed.substring(0, slash) : trimmed; + byte[] base = NetUtil.createByteArrayFromIpAddressString(ipPart); + if (base == null) { + throw new IllegalArgumentException("Not a valid IP address: " + ipPart); + } + int maxBits = base.length * 8; + int prefixBits = maxBits; + if (slash >= 0) { + try { + prefixBits = Integer.parseInt(trimmed.substring(slash + 1).trim()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid CIDR prefix in: " + spec, e); + } + if (prefixBits < 0 || prefixBits > maxBits) { + throw new IllegalArgumentException("CIDR prefix out of range in: " + spec); + } + } + return new Cidr(base, prefixBits); + } + + boolean contains(byte[] addr) { + if (addr.length != base.length) return false; // different address family + int fullBytes = prefixBits / 8; + for (int i = 0; i < fullBytes; i++) { + if (addr[i] != base[i]) return false; + } + int remainingBits = prefixBits % 8; + if (remainingBits != 0) { + int mask = 0xFF << (8 - remainingBits) & 0xFF; + if ((addr[fullBytes] & mask) != (base[fullBytes] & mask)) return false; + } + return true; + } + + // Records with array components get identity-based equals/hashCode by default; provide + // value semantics so deduplication and tests behave intuitively. + @Override + public boolean equals(Object o) { + return o instanceof Cidr c && prefixBits == c.prefixBits && Arrays.equals(base, c.base); + } + + @Override + public int hashCode() { + return 31 * Arrays.hashCode(base) + prefixBits; + } + + @Override + public String toString() { + return "Cidr[base=" + Arrays.toString(base) + ", prefixBits=" + prefixBits + "]"; + } + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/FixedWindowLimiter.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/FixedWindowLimiter.java index d1a5c03..53a46c1 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/FixedWindowLimiter.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/FixedWindowLimiter.java @@ -51,6 +51,7 @@ public final class FixedWindowLimiter implements RateLimiter { * * @param olderThanNanos maximum age in nanoseconds before a window is removed */ + @Override public void cleanup(long olderThanNanos) { long now = System.nanoTime(); windows.entrySet().removeIf(e -> now - e.getValue().windowStart.get() > olderThanNanos); @@ -80,23 +81,30 @@ public final class FixedWindowLimiter implements RateLimiter { * Rolls the window over if it has expired, then counts this request and decides whether * it stays within the limit. * + *

The roll-over (resetting {@code windowStart} and {@code count}) and the subsequent + * increment must happen atomically together: the previous lock-free version reset the + * count in one thread while another was incrementing it, so increments were silently lost + * and the window admitted more than {@code limit} requests around a boundary. Guarding the + * whole operation with the window's monitor keeps the count exact; contention is per key + * only, so throughput is unaffected in practice.

+ * * @param now the current time in nanoseconds * @param limit the per-window request limit * @param windowNanos the window length in nanoseconds * @return an allow result with the remaining quota, or a deny result with the time until * the window resets */ - Result tryAcquire(long now, long limit, long windowNanos) { + synchronized Result tryAcquire(long now, long limit, long windowNanos) { long start = windowStart.get(); if (now - start >= windowNanos) { - if (windowStart.compareAndSet(start, now)) { - count.set(0); - } + windowStart.set(now); + count.set(0); + start = now; } long current = count.incrementAndGet(); if (current > limit) { - long retryMs = (windowNanos - (now - windowStart.get())) / 1_000_000L; + long retryMs = (windowNanos - (now - start)) / 1_000_000L; return Result.deny(limit, Math.max(1, retryMs)); } return Result.allow(limit - current, limit); diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/KeyResolver.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/KeyResolver.java index 5ca24cf..24e022a 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/KeyResolver.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/KeyResolver.java @@ -1,14 +1,22 @@ package dev.coph.nextusweb.server.ratelimit; -import io.netty.handler.codec.http.HttpRequest; +import dev.coph.nextusweb.server.auth.Principal; +import dev.coph.nextusweb.server.router.Request; /** * Strategy for deriving the logical key under which a request is rate limited. The key - * determines which bucket a request counts against — for example one bucket per client IP, or - * one per authenticated user. + * determines which bucket a request counts against — for example one bucket per client IP, one + * per API key, one per session cookie, or one per authenticated user. * - *

Two ready-made resolvers are provided as factory methods: {@link #clientIp()} and - * {@link #userOrIp()}.

+ *

Resolvers receive the framework {@link Request} together with the already-resolved client + * IP (the pipeline computes it once, honouring the configured trusted proxies — see + * {@link dev.coph.nextusweb.server.net.ClientIp}). They are therefore not tied + * to bearer tokens: pick whichever request facet identifies the caller for your API.

+ * + *

Built-in resolvers: {@link #clientIp()}, {@link #header(String)}, {@link #cookie(String)} + * and {@link #principal()}. The header/cookie/principal resolvers fall back to the client IP when + * their facet is absent, so an anonymous caller is still bucketed rather than sharing one global + * bucket. Each key is additionally namespaced by the rule, so different rules never collide.

*/ @FunctionalInterface public interface KeyResolver { @@ -16,45 +24,65 @@ public interface KeyResolver { /** * Resolves the rate-limit key for a request. * - * @param req the incoming HTTP request, used to inspect headers - * @param remoteAddress the transport-level remote address, used as a fallback - * @return the key the request should be counted against + * @param request the incoming request (headers, cookies, attached principal, ...) + * @param clientIp the resolved client IP, honouring trusted proxies + * @return the key the request should be counted against; never {@code null} */ - String resolve(HttpRequest req, String remoteAddress); + String resolve(Request request, String clientIp); /** - * Returns a resolver that keys on the client IP address. It prefers the first entry of the - * {@code X-Forwarded-For} header (so it works behind a reverse proxy) and falls back to the - * transport-level remote address when that header is absent. + * Returns a resolver that keys purely on the resolved client IP. This is the spoofing-safe + * replacement for the old header-trusting behaviour: the IP has already been derived through + * the trusted-proxy policy, so a directly connected client cannot forge it. * * @return a client-IP key resolver */ static KeyResolver clientIp() { - return (req, remote) -> { - String forwarded = req.headers().get("X-Forwarded-For"); - if (forwarded != null && !forwarded.isEmpty()) { - int comma = forwarded.indexOf(','); - return comma > 0 ? forwarded.substring(0, comma).trim() : forwarded.trim(); - } - return remote; + return (request, clientIp) -> clientIp; + } + + /** + * Returns a resolver that keys on the value of a request header (for example an API key in + * {@code X-API-Key}), falling back to the client IP when the header is absent. + * + * @param headerName the header to key on + * @return a header-value key resolver + */ + static KeyResolver header(String headerName) { + return (request, clientIp) -> { + String value = request.header(headerName); + return (value != null && !value.isEmpty()) ? "h:" + value : "ip:" + clientIp; }; } /** - * Returns a resolver that keys on the authenticated user when possible, falling back to the - * client IP otherwise. A {@code Bearer} token from the {@code Authorization} header yields a - * {@code "u:"} key; otherwise the {@code "ip:
"} key from {@link #clientIp()} - * is used. + * Returns a resolver that keys on the value of a request cookie (for example a session id), + * falling back to the client IP when the cookie is absent. * - * @return a user-or-IP key resolver + * @param cookieName the cookie to key on + * @return a cookie-value key resolver */ - static KeyResolver userOrIp() { - return (req, remote) -> { - String auth = req.headers().get("Authorization"); - if (auth != null && auth.startsWith("Bearer ")) { - return "u:" + auth.substring(7); - } - return "ip:" + clientIp().resolve(req, remote); + static KeyResolver cookie(String cookieName) { + return (request, clientIp) -> { + String value = request.cookie(cookieName); + return (value != null && !value.isEmpty()) ? "c:" + value : "ip:" + clientIp; + }; + } + + /** + * Returns a resolver that keys on the authenticated {@link Principal} attached to the request, + * falling back to the client IP for unauthenticated requests. + * + *

For this to key on the principal, the {@link dev.coph.nextusweb.server.auth.AuthGate auth + * layer} must have run before rate limiting (configure it to authenticate the relevant paths). + * When no principal is present the resolver degrades gracefully to per-IP limiting.

+ * + * @return a principal-or-IP key resolver + */ + static KeyResolver principal() { + return (request, clientIp) -> { + Principal p = request.principal(); + return p != null ? "p:" + p.id() : "ip:" + clientIp; }; } } diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/LeakyBucketLimiter.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/LeakyBucketLimiter.java index a0cb2f0..7abd9d1 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/LeakyBucketLimiter.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/LeakyBucketLimiter.java @@ -1,7 +1,7 @@ package dev.coph.nextusweb.server.ratelimit; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; /** * A {@link RateLimiter} implementing the leaky bucket algorithm. @@ -12,8 +12,10 @@ import java.util.concurrent.atomic.AtomicLong; * enough has leaked away. Compared to the token bucket this smooths bursts into a steady * outflow rather than allowing them through up front.

* - *

State is held in {@link AtomicLong}s and updated with a lock-free compare-and-set loop, so - * the limiter is safe for concurrent use.

+ *

Each bucket's water level and last-leak timestamp are held together in a single immutable + * {@link LeakyBucket.State} behind one {@link AtomicReference} and advanced with a lock-free + * compare-and-set loop, so the level and the timestamp it was leaked to are always published + * together and the limiter is safe for concurrent use.

*/ public final class LeakyBucketLimiter implements RateLimiter { @@ -51,20 +53,19 @@ public final class LeakyBucketLimiter implements RateLimiter { * * @param olderThanNanos maximum idle age in nanoseconds before a bucket is removed */ + @Override public void cleanup(long olderThanNanos) { long now = System.nanoTime(); - buckets.entrySet().removeIf(e -> now - e.getValue().lastLeakNanos.get() > olderThanNanos); + buckets.entrySet().removeIf(e -> now - e.getValue().lastLeak() > olderThanNanos); } /** * A single client's leaky bucket, tracking the current water level and the timestamp up to - * which leakage has been accounted for. + * which leakage has been accounted for as one atomic unit. */ private static final class LeakyBucket { - /** Current water level (number of units in the bucket). */ - final AtomicLong waterLevel; - /** Timestamp, in nanoseconds, up to which leakage has been applied. */ - final AtomicLong lastLeakNanos; + /** Holds the current {@code (waterLevel, lastLeakNanos)} pair as one atomic unit. */ + private final AtomicReference state; /** * Creates an empty bucket. @@ -72,12 +73,23 @@ public final class LeakyBucketLimiter implements RateLimiter { * @param now the creation timestamp in nanoseconds */ LeakyBucket(long now) { - this.waterLevel = new AtomicLong(0); - this.lastLeakNanos = new AtomicLong(now); + this.state = new AtomicReference<>(new State(0, now)); } /** - * Applies elapsed leakage and, if there is room, adds one unit of water. + * Returns the timestamp leakage was last accounted to, used by {@link #cleanup(long)}. + * + * @return the last-leak timestamp in nanoseconds + */ + long lastLeak() { + return state.get().lastLeakNanos(); + } + + /** + * Applies elapsed leakage and, if there is room, adds one unit of water. The new level and + * the timestamp it was leaked to are swapped in together, so the previous race where the + * level advanced but the timestamp update was lost (drifting the leak accounting) can no + * longer occur. * * @param now the current time in nanoseconds * @param capacity the bucket capacity @@ -87,24 +99,33 @@ public final class LeakyBucketLimiter implements RateLimiter { */ Result tryAcquire(long now, long capacity, long leakIntervalNanos) { while (true) { - long lastLeak = lastLeakNanos.get(); - long current = waterLevel.get(); + State current = state.get(); - long leaked = (now - lastLeak) / leakIntervalNanos; - long newLevel = Math.max(0, current - leaked); + long leaked = (now - current.lastLeakNanos()) / leakIntervalNanos; + long newLevel = Math.max(0, current.waterLevel() - leaked); if (newLevel >= capacity) { long retryMs = leakIntervalNanos / 1_000_000L; return Result.deny(capacity, retryMs); } - long newLastLeak = leaked > 0 ? lastLeak + leaked * leakIntervalNanos : lastLeak; + long newLastLeak = leaked > 0 + ? current.lastLeakNanos() + leaked * leakIntervalNanos + : current.lastLeakNanos(); - if (waterLevel.compareAndSet(current, newLevel + 1)) { - lastLeakNanos.compareAndSet(lastLeak, newLastLeak); + if (state.compareAndSet(current, new State(newLevel + 1, newLastLeak))) { return Result.allow(capacity - newLevel - 1, capacity); } } } + + /** + * Immutable snapshot of a bucket's mutable state. + * + * @param waterLevel current water level (number of units in the bucket) + * @param lastLeakNanos timestamp leakage has been applied up to, in nanoseconds + */ + private record State(long waterLevel, long lastLeakNanos) { + } } } diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfig.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfig.java index 7e3d9cf..696e272 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfig.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfig.java @@ -1,9 +1,12 @@ package dev.coph.nextusweb.server.ratelimit; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.Set; /** * Immutable mapping from request paths to the {@link Rule rate-limit rules} that apply to them. @@ -27,6 +30,8 @@ public final class RateLimitConfig { private final Map exactPathRules; /** Prefix rules, pre-sorted longest-prefix-first so the most specific match wins. */ private final List prefixRules; + /** Every distinct limiter referenced by any rule, by identity; used for periodic cleanup. */ + private final Set allLimiters; /** * Builds an immutable configuration from a {@link Builder}, copying the exact-path rules @@ -40,6 +45,15 @@ public final class RateLimitConfig { this.prefixRules = b.prefixRules.stream() .sorted((a, c) -> Integer.compare(c.prefix.length(), a.prefix.length())) .toList(); + + // Collect the distinct limiter instances once so the gate's periodic cleanup can iterate + // them. Identity-based de-duplication keeps a limiter shared across several rules from + // being cleaned multiple times per pass. + Set limiters = Collections.newSetFromMap(new IdentityHashMap<>()); + if (globalRule != null) limiters.add(globalRule.limiter()); + for (Rule r : exactPathRules.values()) limiters.add(r.limiter()); + for (PrefixRule pr : prefixRules) limiters.add(pr.rule.limiter()); + this.allLimiters = Collections.unmodifiableSet(limiters); } /** @@ -79,6 +93,16 @@ public final class RateLimitConfig { return rules; } + /** + * Returns every distinct limiter referenced by this configuration, for periodic state + * eviction by {@link RateLimitGate}. + * + * @return the immutable set of distinct limiters (de-duplicated by identity) + */ + public Set allLimiters() { + return allLimiters; + } + /** * A single rate-limit rule: a limiter, the key resolver feeding it, and a name used to * namespace keys and aid diagnostics. diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitGate.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitGate.java index 445d5b7..2f61b80 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitGate.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimitGate.java @@ -1,7 +1,7 @@ package dev.coph.nextusweb.server.ratelimit; +import dev.coph.nextusweb.server.router.Request; import dev.coph.nextusweb.server.router.Response; -import io.netty.handler.codec.http.HttpRequest; import java.util.List; import java.util.concurrent.Executors; @@ -22,19 +22,38 @@ import java.util.concurrent.TimeUnit; */ public final class RateLimitGate { + /** Default idle age after which per-key limiter state is eligible for eviction. */ + private static final long DEFAULT_STALE_AFTER_NANOS = 10L * 60 * 1_000_000_000L; + /** The rule set this gate enforces. */ private final RateLimitConfig config; + /** Idle age (nanoseconds) after which a limiter's per-key state may be evicted. */ + private final long staleAfterNanos; /** Single-threaded scheduler driving periodic cleanup of stale buckets. */ private final ScheduledExecutorService cleanup; /** * Creates a gate for the given configuration and starts a background cleanup task that runs - * every five minutes on a daemon thread. + * every five minutes on a daemon thread, evicting per-key state idle for more than ten + * minutes. * * @param config the rate-limit rules to enforce */ public RateLimitGate(RateLimitConfig config) { + this(config, DEFAULT_STALE_AFTER_NANOS); + } + + /** + * Creates a gate with an explicit idle age before per-key limiter state is evicted. + * + * @param config the rate-limit rules to enforce + * @param staleAfterNanos idle age in nanoseconds after which per-key state is evicted; must + * be positive + */ + public RateLimitGate(RateLimitConfig config, long staleAfterNanos) { + if (staleAfterNanos <= 0) throw new IllegalArgumentException("staleAfterNanos must be > 0"); this.config = config; + this.staleAfterNanos = staleAfterNanos; this.cleanup = Executors.newSingleThreadScheduledExecutor(r -> { Thread t = new Thread(r, "ratelimit-cleanup"); t.setDaemon(true); @@ -52,12 +71,13 @@ public final class RateLimitGate { * independent. The first denial short-circuits and is returned immediately; if every rule * allows the request, the result with the least remaining quota is returned.

* - * @param req the incoming request, used by key resolvers - * @param path the request path used to select rules - * @param remoteAddress the client's remote address, used as a key-resolver fallback + * @param req the incoming request, used by key resolvers + * @param path the request path used to select rules + * @param clientIp the resolved client IP (honouring trusted proxies), used as a key-resolver + * fallback * @return the limiting result, or {@code null} if no rule applies to the path */ - public RateLimiter.Result check(HttpRequest req, String path, String remoteAddress) { + public RateLimiter.Result check(Request req, String path, String clientIp) { List rules = config.rulesFor(path); if (rules.isEmpty()) return null; @@ -65,7 +85,7 @@ public final class RateLimitGate { RateLimiter.Result strictest = null; for (var rule : rules) { - String key = rule.name() + ":" + rule.keyResolver().resolve(req, remoteAddress); + String key = rule.name() + ":" + rule.keyResolver().resolve(req, clientIp); RateLimiter.Result result = rule.limiter().tryAcquire(key, now); if (!result.allowed()) return result; @@ -97,11 +117,18 @@ public final class RateLimitGate { } /** - * Periodic cleanup hook invoked by the background scheduler to evict limiter state that has - * not been touched recently (older than roughly ten minutes). + * Periodic cleanup hook invoked by the background scheduler. Asks every configured limiter to + * evict per-key state idle for longer than {@link #staleAfterNanos}. A failure cleaning one + * limiter must not abort the others or kill the scheduler, so each call is guarded. */ private void doCleanup() { - long threshold = 10L * 60 * 1_000_000_000L; + for (RateLimiter limiter : config.allLimiters()) { + try { + limiter.cleanup(staleAfterNanos); + } catch (RuntimeException ignored) { + // Best-effort eviction; never let one limiter break the cleanup cycle. + } + } } /** diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimiter.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimiter.java index 246ae93..6004510 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimiter.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/RateLimiter.java @@ -8,6 +8,10 @@ package dev.coph.nextusweb.server.ratelimit; * {@link LeakyBucketLimiter}, {@link FixedWindowLimiter} and {@link SlidingWindowLimiter}. * Implementations are expected to be thread-safe, since the same limiter is shared across all * request-handling threads.

+ * + *

The interface remains effectively functional ({@link #tryAcquire} is its single abstract + * method), so simple stateless limiters can still be written as a lambda; stateful limiters that + * keep one entry per key should additionally override {@link #cleanup(long)}.

*/ public interface RateLimiter { @@ -21,6 +25,24 @@ public interface RateLimiter { */ Result tryAcquire(String key, long nowNanos); + /** + * Evicts per-key state that has not been accessed within the given age, bounding the memory + * a limiter consumes when it has seen many distinct keys. + * + *

Implementations keep one entry per key seen ({@code clientIp}, API key, ...). Without + * periodic eviction those maps grow without bound, which is both a memory leak and a denial + * of service vector (an attacker that varies the key on every request can exhaust the heap). + * {@link RateLimitGate} calls this periodically for every configured limiter.

+ * + *

The default implementation does nothing, which is correct for stateless limiters; any + * limiter that retains per-key state must override it to evict stale + * entries.

+ * + * @param olderThanNanos maximum idle age in nanoseconds before an entry is removed + */ + default void cleanup(long olderThanNanos) { + } + /** * Immutable outcome of a {@link #tryAcquire(String, long)} call. * diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/SlidingWindowLimiter.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/SlidingWindowLimiter.java index c13ed14..3ba55b7 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/SlidingWindowLimiter.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/SlidingWindowLimiter.java @@ -53,6 +53,7 @@ public final class SlidingWindowLimiter implements RateLimiter { * * @param olderThanNanos maximum age in nanoseconds before a window is removed */ + @Override public void cleanup(long olderThanNanos) { long now = System.nanoTime(); windows.entrySet().removeIf(e -> now - e.getValue().windowStart.get() > olderThanNanos); diff --git a/src/main/java/dev/coph/nextusweb/server/ratelimit/TokenBucketLimiter.java b/src/main/java/dev/coph/nextusweb/server/ratelimit/TokenBucketLimiter.java index 8bce1aa..2ff7e82 100644 --- a/src/main/java/dev/coph/nextusweb/server/ratelimit/TokenBucketLimiter.java +++ b/src/main/java/dev/coph/nextusweb/server/ratelimit/TokenBucketLimiter.java @@ -1,7 +1,7 @@ package dev.coph.nextusweb.server.ratelimit; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; /** * A {@link RateLimiter} implementing the token bucket algorithm. @@ -12,8 +12,11 @@ import java.util.concurrent.atomic.AtomicLong; * with a retry hint computed from the refill rate. This permits short bursts (up to the bucket * capacity) while bounding the sustained rate.

* - *

Token counts are stored in fixed-point form (scaled by 1e9) inside {@link AtomicLong}s and - * updated with a lock-free compare-and-set loop, so the limiter is safe for concurrent use.

+ *

Token counts are stored in fixed-point form (scaled by 1e9). Each bucket's token count and + * last-refill timestamp are held together in a single immutable {@link Bucket.State} behind one + * {@link AtomicReference} and advanced with a lock-free compare-and-set loop, so a refill and the + * timestamp it is based on are always published as one atomic unit and the limiter is safe for + * concurrent use.

*/ public final class TokenBucketLimiter implements RateLimiter { @@ -55,6 +58,7 @@ public final class TokenBucketLimiter implements RateLimiter { * * @param olderThanNanos maximum idle age in nanoseconds before a bucket is removed */ + @Override public void cleanup(long olderThanNanos) { long now = System.nanoTime(); buckets.entrySet().removeIf(e -> now - e.getValue().lastAccess() > olderThanNanos); @@ -63,68 +67,77 @@ public final class TokenBucketLimiter implements RateLimiter { /** * A single client's token bucket. Tokens are stored in fixed-point form (multiplied by - * 1e9) to retain sub-token precision while using integer atomics. - * - * @param tokensFixed current token count in fixed-point (tokens × 1e9) - * @param lastRefillNanos timestamp of the last refill/consume, in nanoseconds + * 1e9) to retain sub-token precision; the mutable pair {@code (tokens, timestamp)} lives in a + * single {@link AtomicReference} so updates are atomic as a unit. */ - private record Bucket(AtomicLong tokensFixed, AtomicLong lastRefillNanos) { - /** - * Creates a full bucket. - * - * @param tokensFixed initial token count (in whole tokens, scaled internally) - * @param lastRefillNanos the creation timestamp in nanoseconds - */ - private Bucket(long tokensFixed, long lastRefillNanos) { - this(new AtomicLong(tokensFixed * 1_000_000_000L), new AtomicLong(lastRefillNanos)); - } + private static final class Bucket { + /** Holds the current {@code (tokensFixed, lastRefillNanos)} pair as one atomic unit. */ + private final AtomicReference state; - /** - * Returns the timestamp of the last access, used by {@link #cleanup(long)}. - * - * @return the last-refill timestamp in nanoseconds - */ - long lastAccess() { - return lastRefillNanos.get(); - } + /** + * Creates a full bucket. + * + * @param tokens initial token count in whole tokens (scaled internally) + * @param lastRefillNanos the creation timestamp in nanoseconds + */ + Bucket(long tokens, long lastRefillNanos) { + this.state = new AtomicReference<>(new State(tokens * 1_000_000_000L, lastRefillNanos)); + } - /** - * Refills the bucket according to elapsed time and attempts to consume one token, - * retrying via compare-and-set on contention. - * - * @param now the current time in nanoseconds - * @param capacity the bucket capacity in whole tokens - * @param tokensPerNano the refill rate in tokens per nanosecond - * @param refillIntervalNs the nominal nanoseconds per token (unused in the hot path - * but kept for symmetry/retry computation) - * @return an allow result with the remaining tokens, or a deny result with a retry - * hint when fewer than one token is available - */ - Result tryAcquire(long now, long capacity, double tokensPerNano, long refillIntervalNs) { - while (true) { - long lastRefill = lastRefillNanos.get(); - long currentTokens = tokensFixed.get(); + /** + * Returns the timestamp of the last access, used by {@link #cleanup(long)}. + * + * @return the last-refill timestamp in nanoseconds + */ + long lastAccess() { + return state.get().lastRefillNanos(); + } - long elapsed = now - lastRefill; - long refilled = currentTokens; - if (elapsed > 0) { - long addedFixed = (long) (elapsed * tokensPerNano * 1_000_000_000.0); - refilled = Math.min(currentTokens + addedFixed, capacity * 1_000_000_000L); - } + /** + * Refills the bucket according to elapsed time and attempts to consume one token, + * retrying via compare-and-set on contention. The token count and the timestamp it was + * computed from are swapped in together, so no thread can ever observe refilled tokens + * paired with a stale timestamp (or vice versa). + * + * @param now the current time in nanoseconds + * @param capacity the bucket capacity in whole tokens + * @param tokensPerNano the refill rate in tokens per nanosecond + * @param refillIntervalNs the nominal nanoseconds per token (kept for retry computation) + * @return an allow result with the remaining tokens, or a deny result with a retry + * hint when fewer than one token is available + */ + Result tryAcquire(long now, long capacity, double tokensPerNano, long refillIntervalNs) { + long oneTokenFixed = 1_000_000_000L; + while (true) { + State current = state.get(); - long oneTokenFixed = 1_000_000_000L; - if (refilled < oneTokenFixed) { - long deficitFixed = oneTokenFixed - refilled; - long retryNs = (long) (deficitFixed / (tokensPerNano * 1_000_000_000.0)); - return Result.deny(capacity, Math.max(1, retryNs / 1_000_000)); - } + long elapsed = now - current.lastRefillNanos(); + long refilled = current.tokensFixed(); + if (elapsed > 0) { + long addedFixed = (long) (elapsed * tokensPerNano * 1_000_000_000.0); + refilled = Math.min(current.tokensFixed() + addedFixed, capacity * 1_000_000_000L); + } - long newTokens = refilled - oneTokenFixed; - if (tokensFixed.compareAndSet(currentTokens, newTokens)) { - lastRefillNanos.set(now); - return Result.allow(newTokens / 1_000_000_000L, capacity); - } + if (refilled < oneTokenFixed) { + long deficitFixed = oneTokenFixed - refilled; + long retryNs = (long) (deficitFixed / (tokensPerNano * 1_000_000_000.0)); + return Result.deny(capacity, Math.max(1, retryNs / 1_000_000)); + } + + long newTokens = refilled - oneTokenFixed; + if (state.compareAndSet(current, new State(newTokens, now))) { + return Result.allow(newTokens / 1_000_000_000L, capacity); } } } + + /** + * Immutable snapshot of a bucket's mutable state. + * + * @param tokensFixed current token count in fixed-point (tokens × 1e9) + * @param lastRefillNanos timestamp the token count was last advanced to, in nanoseconds + */ + private record State(long tokensFixed, long lastRefillNanos) { + } + } } diff --git a/src/main/java/dev/coph/nextusweb/server/router/Request.java b/src/main/java/dev/coph/nextusweb/server/router/Request.java index 7f8fe10..c51065b 100644 --- a/src/main/java/dev/coph/nextusweb/server/router/Request.java +++ b/src/main/java/dev/coph/nextusweb/server/router/Request.java @@ -1,8 +1,11 @@ package dev.coph.nextusweb.server.router; +import dev.coph.nextusweb.server.auth.Principal; import dev.coph.nextusweb.server.json.JsonMapper; import dev.coph.nextusweb.server.router.exception.BadRequestException; import io.netty.handler.codec.http.*; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; import io.netty.util.CharsetUtil; import tools.jackson.core.JacksonException; import tools.jackson.databind.JsonNode; @@ -32,6 +35,18 @@ public final class Request { /** Lazily parsed JSON body; {@code null} until {@link #json()} is first called. */ private JsonNode jsonCache; + /** Lazily decoded request cookies, keyed by name; {@code null} until first accessed. */ + private Map cookies; + + /** Lazily created bag of per-request attributes set by middlewares/handlers. */ + private Map attributes; + + /** Resolved client IP (honouring trusted proxies); {@code null} until set by the pipeline. */ + private String clientIp; + + /** Authenticated principal attached by the auth layer, or {@code null} if unauthenticated. */ + private Principal principal; + /** * Creates a request wrapper. * @@ -147,6 +162,106 @@ public final class Request { } } + /** + * Returns the value of a request cookie, decoding the {@code Cookie} header on first access. + * + * @param name the cookie name + * @return the cookie value, or {@code null} if no such cookie is present + */ + public String cookie(String name) { + if (cookies == null) { + String header = raw.headers().get(HttpHeaderNames.COOKIE); + if (header == null || header.isEmpty()) { + cookies = Map.of(); + } else { + Map parsed = new HashMap<>(); + for (Cookie c : ServerCookieDecoder.STRICT.decode(header)) { + parsed.putIfAbsent(c.name(), c.value()); + } + cookies = parsed; + } + } + return cookies.get(name); + } + + /** + * Stores a per-request attribute, or removes it when {@code value} is {@code null}. Useful for + * passing state from middlewares or the auth layer to handlers. + * + * @param name the attribute name + * @param value the value to store, or {@code null} to remove it + * @return this request, for fluent chaining + */ + public Request attribute(String name, Object value) { + if (value == null) { + if (attributes != null) attributes.remove(name); + } else { + if (attributes == null) attributes = new HashMap<>(8); + attributes.put(name, value); + } + return this; + } + + /** + * Retrieves a previously stored attribute, cast to the caller's expected type. + * + * @param name the attribute name + * @param the expected attribute type + * @return the stored value, or {@code null} if absent + */ + @SuppressWarnings("unchecked") + public T attribute(String name) { + return attributes == null ? null : (T) attributes.get(name); + } + + /** + * Returns the resolved client IP address for this request. Unlike the raw socket address this + * honours the server's trusted-proxy configuration, so it reflects the originating client + * when the server sits behind a trusted reverse proxy and the socket peer otherwise. + * + * @return the resolved client IP, or {@code null} if the pipeline has not set one + */ + public String clientIp() { + return clientIp; + } + + /** + * Sets the resolved client IP. Called by the request pipeline; not intended for handler use. + * + * @param clientIp the resolved client IP + */ + public void clientIp(String clientIp) { + this.clientIp = clientIp; + } + + /** + * Returns the authenticated principal attached to this request by the auth layer. + * + * @return the principal, or {@code null} if the request was not authenticated + */ + public Principal principal() { + return principal; + } + + /** + * Attaches an authenticated principal to this request. Called by the auth layer; not intended + * for handler use. + * + * @param principal the authenticated principal + */ + public void principal(Principal principal) { + this.principal = principal; + } + + /** + * Returns whether this request carries an authenticated principal. + * + * @return {@code true} if a principal is attached + */ + public boolean isAuthenticated() { + return principal != null; + } + /** * Returns the request's HTTP method. * diff --git a/src/main/java/dev/coph/nextusweb/server/tls/TlsConfig.java b/src/main/java/dev/coph/nextusweb/server/tls/TlsConfig.java new file mode 100644 index 0000000..62c5fff --- /dev/null +++ b/src/main/java/dev/coph/nextusweb/server/tls/TlsConfig.java @@ -0,0 +1,124 @@ +package dev.coph.nextusweb.server.tls; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandler; + +import javax.net.ssl.SSLException; +import java.io.File; +import java.io.InputStream; +import java.util.Objects; + +/** + * Configuration that enables TLS (HTTPS / WSS) on the server. Holds a ready-built Netty + * {@link SslContext} and produces the per-connection {@link SslHandler} the pipeline installs as + * its first handler. + * + *

Enabling TLS is meant to be a one-liner — point the factory at a PEM certificate chain and + * private key and pass the result to {@code HttpServer.withTls(...)}:

+ *
{@code
+ * HttpServer.builder(443, router)
+ *           .withTls(TlsConfig.fromPem(new File("fullchain.pem"), new File("privkey.pem")))
+ *           .start();
+ * }
+ * + *

For full control (custom cipher suites, client-auth / mutual TLS, a non-default provider) + * build a Netty {@link SslContext} yourself and wrap it with {@link #fromSslContext(SslContext)}. + * The {@code SslContext} is built once and reused for every connection, so it is cheap per + * request.

+ */ +public final class TlsConfig { + + /** The pre-built, shareable server SSL context. */ + private final SslContext sslContext; + + private TlsConfig(SslContext sslContext) { + this.sslContext = Objects.requireNonNull(sslContext, "sslContext"); + } + + /** + * Builds a TLS configuration from a PEM-encoded certificate chain and an unencrypted + * PKCS#8 private key. + * + * @param certificateChain the PEM file containing the certificate (chain) + * @param privateKey the PEM file containing the PKCS#8 private key + * @return a TLS configuration + * @throws IllegalStateException if the certificate/key cannot be loaded + */ + public static TlsConfig fromPem(File certificateChain, File privateKey) { + return fromPem(certificateChain, privateKey, null); + } + + /** + * Builds a TLS configuration from a PEM-encoded certificate chain and a (optionally + * password-protected) PKCS#8 private key. + * + * @param certificateChain the PEM file containing the certificate (chain) + * @param privateKey the PEM file containing the PKCS#8 private key + * @param keyPassword the password protecting {@code privateKey}, or {@code null} if none + * @return a TLS configuration + * @throws IllegalStateException if the certificate/key cannot be loaded + */ + public static TlsConfig fromPem(File certificateChain, File privateKey, String keyPassword) { + Objects.requireNonNull(certificateChain, "certificateChain"); + Objects.requireNonNull(privateKey, "privateKey"); + try { + return new TlsConfig(SslContextBuilder.forServer(certificateChain, privateKey, keyPassword).build()); + } catch (SSLException | RuntimeException e) { + // Netty surfaces missing/invalid PEM material as IllegalArgumentException; normalise + // every initialisation failure to a single, predictable exception type. + throw new IllegalStateException("Failed to initialise TLS from PEM files", e); + } + } + + /** + * Builds a TLS configuration from PEM-encoded streams, for certificates/keys loaded from the + * classpath or another non-file source. The caller retains ownership of the streams. + * + * @param certificateChain a stream of the PEM certificate (chain) + * @param privateKey a stream of the PEM PKCS#8 private key + * @param keyPassword the password protecting {@code privateKey}, or {@code null} if none + * @return a TLS configuration + * @throws IllegalStateException if the certificate/key cannot be loaded + */ + public static TlsConfig fromPem(InputStream certificateChain, InputStream privateKey, String keyPassword) { + Objects.requireNonNull(certificateChain, "certificateChain"); + Objects.requireNonNull(privateKey, "privateKey"); + try { + return new TlsConfig(SslContextBuilder.forServer(certificateChain, privateKey, keyPassword).build()); + } catch (SSLException | RuntimeException e) { + throw new IllegalStateException("Failed to initialise TLS from PEM streams", e); + } + } + + /** + * Wraps a fully configured Netty {@link SslContext}, for advanced setups such as custom cipher + * suites or mutual TLS. + * + * @param sslContext a server-mode SSL context + * @return a TLS configuration backed by the given context + */ + public static TlsConfig fromSslContext(SslContext sslContext) { + return new TlsConfig(sslContext); + } + + /** + * Creates a new per-connection {@link SslHandler} from the shared context. + * + * @param alloc the channel's buffer allocator + * @return a fresh TLS handler for one connection + */ + public SslHandler newHandler(ByteBufAllocator alloc) { + return sslContext.newHandler(alloc); + } + + /** + * Returns the underlying Netty SSL context. + * + * @return the shared server SSL context + */ + public SslContext sslContext() { + return sslContext; + } +} diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java index ba1c1d8..af2792d 100644 --- a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java +++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketConfig.java @@ -31,6 +31,8 @@ public final class WebSocketConfig { private final boolean compression; /** Whether the protocol handler matches the path by prefix rather than exact equality. */ private final boolean checkStartsWith; + /** Max in-flight callbacks queued per connection before read backpressure kicks in. */ + private final int maxQueuedMessages; /** * Builds an immutable configuration from a {@link Builder}, defensively copying its sets. @@ -46,6 +48,7 @@ public final class WebSocketConfig { this.subprotocols = Set.copyOf(b.subprotocols); this.compression = b.compression; this.checkStartsWith = b.checkStartsWith; + this.maxQueuedMessages = b.maxQueuedMessages; } /** @@ -153,6 +156,17 @@ public final class WebSocketConfig { return checkStartsWith; } + /** + * Returns the maximum number of in-flight callbacks that may be queued per connection before + * the framework stops reading further frames (backpressure), protecting the server from a + * client that floods messages faster than the handler consumes them. + * + * @return the per-connection queued-message high-watermark + */ + public int maxQueuedMessages() { + return maxQueuedMessages; + } + /** * Fluent builder for {@link WebSocketConfig}, pre-populated with sensible defaults: 64 KiB * frames, 1 MiB aggregated messages, a 60-second idle timeout, no origin restriction @@ -175,6 +189,8 @@ public final class WebSocketConfig { private boolean compression = true; /** Whether path matching uses a prefix check; defaults to {@code false}. */ private boolean checkStartsWith = false; + /** Per-connection queued-message high-watermark; defaults to 1024. */ + private int maxQueuedMessages = 1024; /** * Creates a builder pre-populated with the default configuration values described @@ -284,6 +300,19 @@ public final class WebSocketConfig { return this; } + /** + * Sets the per-connection queued-message high-watermark for backpressure. + * + * @param messages the maximum queued callbacks before reads are paused; must be positive + * @return this builder, for fluent chaining + * @throws IllegalArgumentException if {@code messages <= 0} + */ + public Builder maxQueuedMessages(int messages) { + if (messages <= 0) throw new IllegalArgumentException("maxQueuedMessages must be > 0"); + this.maxQueuedMessages = messages; + return this; + } + /** * Builds the immutable {@link WebSocketConfig}. * diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java index 8aa5eec..54d5dfc 100644 --- a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java +++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandler.java @@ -1,18 +1,22 @@ package dev.coph.nextusweb.server.websocket; +import dev.coph.nextusweb.server.auth.Principal; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.timeout.IdleStateEvent; -import io.netty.util.CharsetUtil; import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; /** * Netty channel handler that bridges low-level WebSocket frames to the high-level @@ -20,16 +24,23 @@ import java.util.concurrent.Executors; * *

It creates a {@link WebSocketSession} when the handshake completes, then translates each * incoming frame into the matching callback ({@code onMessage}, {@code onBinary}, - * {@code onClose}). Callbacks are dispatched on a virtual-thread executor so application code - * may block without stalling the Netty event loop, and any exception they throw is funneled to - * {@link WebSocketHandler#onError}. Idle-timeout events close the channel.

+ * {@code onClose}). Callbacks run on virtual threads so application code may block without + * stalling the Netty event loop, but, crucially, all callbacks for a single connection are + * executed strictly in arrival order by a per-connection serial drainer rather + * than each on its own thread — message ordering within a connection is therefore guaranteed.

+ * + *

The drainer's pending queue is bounded: once it reaches the configured high-watermark the + * handler disables the channel's auto-read (backpressure), so a client that floods messages faster + * than the handler consumes them cannot exhaust memory or spawn unbounded work; reading resumes + * once the backlog drains. Each handler instance serves exactly one connection, so this state is + * per-connection.

* *

This class is package-private; instances are created via * {@link WebSocketFrameHandlerFactory}.

*/ final class WebSocketFrameHandler extends SimpleChannelInboundHandler { - /** Executor running one virtual thread per task, used to dispatch handler callbacks. */ + /** Executor running one virtual thread per drain task. */ private static final Executor VT_EXECUTOR = Executors.newVirtualThreadPerTaskExecutor(); /** The application handler receiving lifecycle callbacks. */ @@ -38,6 +49,21 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler pathParams; + /** Authenticated principal for the connection, or {@code null} if anonymous. */ + private final Principal principal; + /** Queued-callback high-watermark at which reads are paused. */ + private final int maxQueued; + /** Watermark at which reads resume after having been paused. */ + private final int resumeQueued; + + /** FIFO of pending callbacks for this connection; drained by a single virtual thread. */ + private final Queue tasks = new ConcurrentLinkedQueue<>(); + /** Number of callbacks currently queued (drives the backpressure watermarks). */ + private final AtomicInteger queued = new AtomicInteger(); + /** Guards that at most one drainer runs at a time, preserving ordering. */ + private final AtomicBoolean draining = new AtomicBoolean(false); + /** Whether reads are currently paused for backpressure. */ + private volatile boolean readsPaused = false; /** * Creates a frame handler bound to an application handler and connection metadata. @@ -45,11 +71,17 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler pathParams) { + WebSocketFrameHandler(WebSocketHandler handler, String path, Map pathParams, + Principal principal, int maxQueued) { this.handler = handler; this.path = path; this.pathParams = pathParams; + this.principal = principal; + this.maxQueued = Math.max(1, maxQueued); + this.resumeQueued = Math.max(1, this.maxQueued / 4); } /** @@ -64,9 +96,9 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler { + submit(ctx, () -> { try { handler.onOpen(session); } catch (Throwable t) { @@ -85,8 +117,9 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler { + submit(ctx, () -> { try { handler.onMessage(session, content); } catch (Throwable t) { @@ -109,7 +142,7 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler { + submit(ctx, () -> { try { handler.onBinary(session, data); } catch (Throwable t) { @@ -119,7 +152,7 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler { + submit(ctx, () -> { try { handler.onClose(session, code, reason); } catch (Throwable t) { @@ -129,10 +162,61 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler= maxQueued && !readsPaused) { + readsPaused = true; + ctx.channel().config().setAutoRead(false); + } + if (draining.compareAndSet(false, true)) { + VT_EXECUTOR.execute(() -> drain(ctx)); + } + } + + /** + * Drains and runs queued callbacks one at a time (preserving order) until the queue is empty, + * resuming reads once the backlog falls back below the low-watermark. The single-drainer + * invariant is upheld by {@link #draining}; a final re-check avoids a lost wake-up if a task + * was enqueued just as the drainer was finishing. + * + * @param ctx the channel context + */ + private void drain(ChannelHandlerContext ctx) { + try { + Runnable task; + while ((task = tasks.poll()) != null) { + try { + task.run(); + } finally { + int remaining = queued.decrementAndGet(); + if (readsPaused && remaining <= resumeQueued) { + readsPaused = false; + if (ctx.channel().isActive()) { + ctx.channel().config().setAutoRead(true); + ctx.read(); + } + } + } + } + } finally { + draining.set(false); + if (!tasks.isEmpty() && draining.compareAndSet(false, true)) { + VT_EXECUTOR.execute(() -> drain(ctx)); + } + } + } + /** * Invoked when the channel goes inactive (the connection dropped without a clean close * handshake). Clears the stored session and dispatches {@link WebSocketHandler#onClose} with - * the abnormal-closure code {@code 1006}. + * the abnormal-closure code {@code 1006}, ordered after any still-queued callbacks. * * @param ctx the channel context */ @@ -140,7 +224,7 @@ final class WebSocketFrameHandler extends SimpleChannelInboundHandler { + submit(ctx, () -> { try { handler.onClose(session, 1006, "Connection closed"); } catch (Throwable t) { diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java index 342b16b..3065c82 100644 --- a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java +++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketFrameHandlerFactory.java @@ -1,5 +1,6 @@ package dev.coph.nextusweb.server.websocket; +import dev.coph.nextusweb.server.auth.Principal; import io.netty.channel.ChannelHandler; import java.util.Map; @@ -13,6 +14,9 @@ import java.util.Map; */ public final class WebSocketFrameHandlerFactory { + /** Default per-connection queued-message high-watermark when none is supplied. */ + private static final int DEFAULT_MAX_QUEUED = 1024; + /** * Private constructor preventing instantiation of this stateless utility class. */ @@ -21,7 +25,7 @@ public final class WebSocketFrameHandlerFactory { /** * Creates a channel handler that bridges Netty WebSocket frames to the given application - * {@link WebSocketHandler}. + * {@link WebSocketHandler}, using the default backpressure watermark and no principal. * * @param handler the application handler to dispatch lifecycle events to * @param path the path the connection was established on @@ -30,6 +34,38 @@ public final class WebSocketFrameHandlerFactory { */ public static ChannelHandler create(WebSocketHandler handler, String path, Map pathParams) { - return new WebSocketFrameHandler(handler, path, pathParams); + return create(handler, path, pathParams, null, DEFAULT_MAX_QUEUED); + } + + /** + * Creates a channel handler with an authenticated principal and the default backpressure + * watermark. + * + * @param handler the application handler to dispatch lifecycle events to + * @param path the path the connection was established on + * @param pathParams the path parameters captured during routing + * @param principal the authenticated principal, or {@code null} if the connection is anonymous + * @return a new channel handler ready to be inserted into the pipeline + */ + public static ChannelHandler create(WebSocketHandler handler, String path, + Map pathParams, Principal principal) { + return create(handler, path, pathParams, principal, DEFAULT_MAX_QUEUED); + } + + /** + * Creates a channel handler with an authenticated principal and an explicit backpressure + * watermark. + * + * @param handler the application handler to dispatch lifecycle events to + * @param path the path the connection was established on + * @param pathParams the path parameters captured during routing + * @param principal the authenticated principal, or {@code null} if the connection is anonymous + * @param maxQueued the per-connection queued-message high-watermark before reads are paused + * @return a new channel handler ready to be inserted into the pipeline + */ + public static ChannelHandler create(WebSocketHandler handler, String path, + Map pathParams, Principal principal, + int maxQueued) { + return new WebSocketFrameHandler(handler, path, pathParams, principal, maxQueued); } } diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java index 7680a25..0a287a3 100644 --- a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java +++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketGroup.java @@ -1,6 +1,7 @@ package dev.coph.nextusweb.server.websocket; import dev.coph.nextusweb.server.json.JsonMapper; +import io.netty.buffer.Unpooled; import io.netty.channel.group.ChannelGroup; import io.netty.channel.group.DefaultChannelGroup; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; @@ -102,8 +103,9 @@ public final class WebSocketGroup { public WebSocketGroup broadcastJson(Object value) { try { byte[] bytes = JsonMapper.MAPPER.writeValueAsBytes(value); - String text = new String(bytes, java.nio.charset.StandardCharsets.UTF_8); - channels.writeAndFlush(new TextWebSocketFrame(text)); + // Build the text frame straight from the serialized UTF-8 bytes; the channel group + // duplicates the payload per recipient, so no String round-trip re-encode is needed. + channels.writeAndFlush(new TextWebSocketFrame(Unpooled.wrappedBuffer(bytes))); } catch (JacksonException e) { throw new RuntimeException("JSON serialization failed", e); } diff --git a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java index 1cc6492..159ca1f 100644 --- a/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java +++ b/src/main/java/dev/coph/nextusweb/server/websocket/WebSocketSession.java @@ -1,5 +1,6 @@ package dev.coph.nextusweb.server.websocket; +import dev.coph.nextusweb.server.auth.Principal; import dev.coph.nextusweb.server.json.JsonMapper; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -47,6 +48,8 @@ public final class WebSocketSession { private final String path; /** Path parameters captured during routing, keyed by name. */ private final Map pathParams; + /** Authenticated principal for this connection, or {@code null} if anonymous. */ + private final Principal principal; /** Thread-safe bag of user-defined attributes attached to the session. */ private final Map attributes = new ConcurrentHashMap<>(); @@ -57,12 +60,14 @@ public final class WebSocketSession { * @param channel the underlying Netty channel * @param path the connection path * @param pathParams the path parameters captured during routing + * @param principal the authenticated principal, or {@code null} if the connection is anonymous */ - WebSocketSession(Channel channel, String path, Map pathParams) { + WebSocketSession(Channel channel, String path, Map pathParams, Principal principal) { this.channel = channel; this.id = UUID.randomUUID().toString(); this.path = path; this.pathParams = pathParams; + this.principal = principal; } /** @@ -93,6 +98,16 @@ public final class WebSocketSession { return pathParams.get(name); } + /** + * Returns the authenticated principal associated with this connection, established by the auth + * layer during the upgrade handshake. + * + * @return the principal, or {@code null} if the connection is anonymous + */ + public Principal principal() { + return principal; + } + /** * Indicates whether the connection is still open. * diff --git a/src/test/java/dev/coph/nextusweb/server/auth/AuthConfigTest.java b/src/test/java/dev/coph/nextusweb/server/auth/AuthConfigTest.java new file mode 100644 index 0000000..686048a --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/auth/AuthConfigTest.java @@ -0,0 +1,126 @@ +package dev.coph.nextusweb.server.auth; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class AuthConfigTest { + + /** A distinct authenticator instance, identifiable by reference, that authenticates nobody. */ + private Authenticator marker() { + return req -> null; + } + + @Test + void ruleForReturnsNullWhenUnprotected() { + AuthConfig cfg = AuthConfig.builder(marker()) + .protect("/admin") + .build(); + assertNull(cfg.ruleFor("/public")); + } + + @Test + void exactPathBeatsPrefix() { + Authenticator exactAuth = marker(); + Authenticator prefixAuth = marker(); + AuthConfig cfg = AuthConfig.builder(marker()) + .protect("/api/health", exactAuth) + .protectPrefix("/api/", prefixAuth) + .build(); + + AuthConfig.Rule rule = cfg.ruleFor("/api/health"); + assertNotNull(rule); + assertSame(exactAuth, rule.authenticator()); + } + + @Test + void longerPrefixWins() { + Authenticator shortAuth = marker(); + Authenticator longAuth = marker(); + AuthConfig cfg = AuthConfig.builder(marker()) + .protectPrefix("/api/", shortAuth) + .protectPrefix("/api/v2/", longAuth) + .build(); + + assertSame(longAuth, cfg.ruleFor("/api/v2/users").authenticator()); + assertSame(shortAuth, cfg.ruleFor("/api/v1/users").authenticator()); + } + + @Test + void protectUsesRequiredMode() { + AuthConfig cfg = AuthConfig.builder(marker()).protect("/admin").build(); + assertEquals(AuthConfig.Mode.REQUIRED, cfg.ruleFor("/admin").mode()); + } + + @Test + void optionalUsesOptionalMode() { + AuthConfig cfg = AuthConfig.builder(marker()).optional("/feed").build(); + assertEquals(AuthConfig.Mode.OPTIONAL, cfg.ruleFor("/feed").mode()); + } + + @Test + void requireEverywhereAppliesGlobalRequiredRule() { + AuthConfig cfg = AuthConfig.builder(marker()) + .requireEverywhere() + .build(); + AuthConfig.Rule rule = cfg.ruleFor("/anything"); + assertNotNull(rule); + assertEquals(AuthConfig.Mode.REQUIRED, rule.mode()); + } + + @Test + void optionalEverywhereAppliesGlobalOptionalRule() { + AuthConfig cfg = AuthConfig.builder(marker()) + .optionalEverywhere() + .build(); + AuthConfig.Rule rule = cfg.ruleFor("/anything"); + assertNotNull(rule); + assertEquals(AuthConfig.Mode.OPTIONAL, rule.mode()); + } + + @Test + void specificRuleBeatsGlobal() { + Authenticator specific = marker(); + AuthConfig cfg = AuthConfig.builder(marker()) + .optionalEverywhere() + .protect("/admin", specific) + .build(); + + AuthConfig.Rule adminRule = cfg.ruleFor("/admin"); + assertEquals(AuthConfig.Mode.REQUIRED, adminRule.mode()); + assertSame(specific, adminRule.authenticator()); + // Everything else still falls through to the optional global rule. + assertEquals(AuthConfig.Mode.OPTIONAL, cfg.ruleFor("/other").mode()); + } + + @Test + void defaultAuthenticatorUsedWhenNotOverridden() { + Authenticator def = marker(); + AuthConfig cfg = AuthConfig.builder(def) + .protect("/admin") + .protectPrefix("/api/") + .build(); + assertSame(def, cfg.ruleFor("/admin").authenticator()); + assertSame(def, cfg.ruleFor("/api/x").authenticator()); + } + + @Test + void challengeIsStored() { + AuthConfig cfg = AuthConfig.builder(marker()) + .protect("/admin") + .challenge("Basic realm=\"api\"") + .build(); + assertEquals("Basic realm=\"api\"", cfg.challenge()); + } + + @Test + void challengeDefaultsToNull() { + AuthConfig cfg = AuthConfig.builder(marker()).build(); + assertNull(cfg.challenge()); + } + + @Test + void builderRejectsNullDefaultAuthenticator() { + assertThrows(NullPointerException.class, () -> AuthConfig.builder(null)); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/auth/AuthGateTest.java b/src/test/java/dev/coph/nextusweb/server/auth/AuthGateTest.java new file mode 100644 index 0000000..7092c3e --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/auth/AuthGateTest.java @@ -0,0 +1,147 @@ +package dev.coph.nextusweb.server.auth; + +import dev.coph.nextusweb.server.router.Request; +import dev.coph.nextusweb.server.router.Response; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class AuthGateTest { + + private Request request(String apiKey) { + FullHttpRequest raw = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); + if (apiKey != null) raw.headers().set("X-API-Key", apiKey); + return new Request(raw, Map.of()); + } + + private Authenticator apiKeyAuth() { + return Authenticator.apiKey("X-API-Key", + key -> key.equals("valid") ? Principal.of("user-1") : null); + } + + @Test + void unprotectedPathProceedsWithoutPrincipal() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .protectPrefix("/admin") + .build()); + Request req = request(null); + assertNull(gate.authenticate(req, "/public")); + assertNull(req.principal()); + } + + @Test + void protectedPathRejectsMissingCredentials() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .protectPrefix("/admin") + .build()); + Response rejection = gate.authenticate(request(null), "/admin/users"); + assertNotNull(rejection); + assertEquals(401, rejection.status()); + } + + @Test + void protectedPathRejectsInvalidCredentials() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .protect("/admin") + .build()); + Response rejection = gate.authenticate(request("wrong"), "/admin"); + assertNotNull(rejection); + assertEquals(401, rejection.status()); + } + + @Test + void protectedPathAttachesPrincipalOnSuccess() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .protect("/admin") + .build()); + Request req = request("valid"); + assertNull(gate.authenticate(req, "/admin")); + assertNotNull(req.principal()); + assertEquals("user-1", req.principal().id()); + } + + @Test + void optionalPathProceedsAnonymouslyButAttachesWhenPresent() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .optionalPrefix("/feed") + .build()); + + Request anon = request(null); + assertNull(gate.authenticate(anon, "/feed")); + assertNull(anon.principal()); + + Request authed = request("valid"); + assertNull(gate.authenticate(authed, "/feed")); + assertEquals("user-1", authed.principal().id()); + } + + @Test + void challengeHeaderAddedToUnauthorized() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .protect("/admin") + .challenge("ApiKey realm=\"api\"") + .build()); + Response rejection = gate.authenticate(request(null), "/admin"); + assertEquals("ApiKey realm=\"api\"", rejection.headers().get("WWW-Authenticate")); + } + + @Test + void authenticatorErrorYields500() { + Authenticator boom = req -> { + throw new IllegalStateException("db down"); + }; + AuthGate gate = new AuthGate(AuthConfig.builder(boom).protect("/admin").build()); + Response rejection = gate.authenticate(request(null), "/admin"); + assertNotNull(rejection); + assertEquals(500, rejection.status()); + } + + @Test + void exactPathAuthenticatorIsUsedOverPrefix() { + // The prefix authenticator never authenticates; the exact-path one accepts the "valid" + // key. The exact rule must win so the request on the exact path can succeed. + Authenticator prefixDeny = req -> null; + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .protect("/api/health", apiKeyAuth()) + .protectPrefix("/api/", prefixDeny) + .build()); + + Request exact = request("valid"); + assertNull(gate.authenticate(exact, "/api/health")); + assertEquals("user-1", exact.principal().id()); + + // A sibling path under the prefix uses the (always-denying) prefix authenticator. + Response rejection = gate.authenticate(request("valid"), "/api/other"); + assertNotNull(rejection); + assertEquals(401, rejection.status()); + } + + @Test + void requireEverywhereRejectsAnyPathWithoutCredentials() { + AuthGate gate = new AuthGate(AuthConfig.builder(apiKeyAuth()) + .requireEverywhere() + .build()); + assertEquals(401, gate.authenticate(request(null), "/whatever").status()); + assertNull(gate.authenticate(request("valid"), "/whatever")); + } + + @Test + void anyOfAuthenticatorAcceptsEitherCredential() { + Authenticator combined = Authenticator.anyOf( + apiKeyAuth(), + Authenticator.cookie("sid", s -> s.equals("sess") ? Principal.of("cookie-user") : null)); + AuthGate gate = new AuthGate(AuthConfig.builder(combined).protect("/admin").build()); + + Request viaKey = request("valid"); + assertNull(gate.authenticate(viaKey, "/admin")); + assertEquals("user-1", viaKey.principal().id()); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/auth/AuthenticatorTest.java b/src/test/java/dev/coph/nextusweb/server/auth/AuthenticatorTest.java new file mode 100644 index 0000000..cd6babc --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/auth/AuthenticatorTest.java @@ -0,0 +1,119 @@ +package dev.coph.nextusweb.server.auth; + +import dev.coph.nextusweb.server.router.Request; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class AuthenticatorTest { + + private Request request(String header, String value) { + FullHttpRequest raw = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); + if (header != null) raw.headers().set(header, value); + return new Request(raw, Map.of()); + } + + @Test + void apiKeyResolvesViaValidator() throws Exception { + Authenticator auth = Authenticator.apiKey("X-API-Key", + key -> key.equals("good") ? Principal.of("svc") : null); + + assertEquals("svc", auth.authenticate(request("X-API-Key", "good")).id()); + assertNull(auth.authenticate(request("X-API-Key", "bad"))); + assertNull(auth.authenticate(request(null, null))); + } + + @Test + void cookieResolvesViaValidator() throws Exception { + Authenticator auth = Authenticator.cookie("sid", + sid -> sid.equals("abc") ? Principal.of("u1") : null); + + assertEquals("u1", auth.authenticate(request("Cookie", "sid=abc")).id()); + assertNull(auth.authenticate(request("Cookie", "sid=zzz"))); + assertNull(auth.authenticate(request(null, null))); + } + + @Test + void basicDecodesCredentials() throws Exception { + Authenticator auth = Authenticator.basic( + (user, pass) -> user.equals("alice") && pass.equals("s3cret") ? Principal.of(user) : null); + + String header = "Basic " + Base64.getEncoder() + .encodeToString("alice:s3cret".getBytes(StandardCharsets.UTF_8)); + assertEquals("alice", auth.authenticate(request("Authorization", header)).id()); + + String wrong = "Basic " + Base64.getEncoder() + .encodeToString("alice:nope".getBytes(StandardCharsets.UTF_8)); + assertNull(auth.authenticate(request("Authorization", wrong))); + assertNull(auth.authenticate(request("Authorization", "Basic not-base64!!"))); + assertNull(auth.authenticate(request(null, null))); + } + + @Test + void anyOfReturnsFirstMatch() throws Exception { + Authenticator key = Authenticator.apiKey("X-API-Key", + k -> k.equals("k") ? Principal.of("byKey") : null); + Authenticator cookie = Authenticator.cookie("sid", + s -> Principal.of("byCookie")); + Authenticator combined = Authenticator.anyOf(key, cookie); + + assertEquals("byKey", combined.authenticate(request("X-API-Key", "k")).id()); + assertEquals("byCookie", combined.authenticate(request("Cookie", "sid=x")).id()); + assertNull(combined.authenticate(request(null, null))); + } + + @Test + void bearerDecodesToken() throws Exception { + Authenticator auth = Authenticator.bearer( + token -> token.equals("tok123") ? Principal.of("u") : null); + + assertEquals("u", auth.authenticate(request("Authorization", "Bearer tok123")).id()); + assertNull(auth.authenticate(request("Authorization", "Bearer wrong"))); + assertNull(auth.authenticate(request("Authorization", "Bearer "))); + assertNull(auth.authenticate(request("Authorization", "Basic abc"))); + assertNull(auth.authenticate(request(null, null))); + } + + @Test + void authSchemesAreCaseInsensitive() throws Exception { + Authenticator bearer = Authenticator.bearer(t -> Principal.of("b")); + assertEquals("b", bearer.authenticate(request("Authorization", "bearer tok")).id()); + + Authenticator basic = Authenticator.basic((u, p) -> Principal.of(u)); + String header = "basic " + Base64.getEncoder() + .encodeToString("alice:pw".getBytes(StandardCharsets.UTF_8)); + assertEquals("alice", basic.authenticate(request("Authorization", header)).id()); + } + + @Test + void basicReturnsNullWhenNoColon() throws Exception { + Authenticator auth = Authenticator.basic((u, p) -> Principal.of(u)); + String header = "Basic " + Base64.getEncoder() + .encodeToString("nocolon".getBytes(StandardCharsets.UTF_8)); + assertNull(auth.authenticate(request("Authorization", header))); + } + + @Test + void basicAllowsEmptyPassword() throws Exception { + Authenticator auth = Authenticator.basic((u, p) -> Principal.of(u + ":" + p)); + String header = "Basic " + Base64.getEncoder() + .encodeToString("user:".getBytes(StandardCharsets.UTF_8)); + assertEquals("user:", auth.authenticate(request("Authorization", header)).id()); + } + + @Test + void apiKeyEmptyHeaderTreatedAsAbsent() throws Exception { + Authenticator auth = Authenticator.apiKey("X-API-Key", k -> Principal.of("never")); + assertNull(auth.authenticate(request("X-API-Key", ""))); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/auth/PrincipalTest.java b/src/test/java/dev/coph/nextusweb/server/auth/PrincipalTest.java new file mode 100644 index 0000000..d978d30 --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/auth/PrincipalTest.java @@ -0,0 +1,61 @@ +package dev.coph.nextusweb.server.auth; + +import org.junit.jupiter.api.Test; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.*; + +class PrincipalTest { + + @Test + void ofIdOnlyHasNoRolesOrClaims() { + Principal p = Principal.of("user-1"); + assertEquals("user-1", p.id()); + assertTrue(p.roles().isEmpty()); + assertTrue(p.claims().isEmpty()); + assertFalse(p.hasRole("admin")); + } + + @Test + void ofWithRolesExposesRolesAndHasRole() { + Principal p = Principal.of("user-2", Set.of("admin", "ops")); + assertEquals("user-2", p.id()); + assertEquals(Set.of("admin", "ops"), p.roles()); + assertTrue(p.hasRole("admin")); + assertTrue(p.hasRole("ops")); + assertFalse(p.hasRole("guest")); + } + + @Test + void rolesAreDefensivelyCopiedAndImmutable() { + Set source = new HashSet<>(Set.of("admin")); + Principal p = Principal.of("user-3", source); + + // Mutating the source after construction must not affect the principal. + source.add("sneaky"); + assertEquals(Set.of("admin"), p.roles()); + + // The exposed set must be unmodifiable. + assertThrows(UnsupportedOperationException.class, () -> p.roles().add("x")); + } + + @Test + void customImplementationIsSupported() { + Principal custom = new Principal() { + @Override + public String id() { + return "svc"; + } + + @Override + public Set roles() { + return Set.of("service"); + } + }; + assertEquals("svc", custom.id()); + assertTrue(custom.hasRole("service")); + assertTrue(custom.claims().isEmpty()); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/net/ClientIpTest.java b/src/test/java/dev/coph/nextusweb/server/net/ClientIpTest.java new file mode 100644 index 0000000..beeb03d --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/net/ClientIpTest.java @@ -0,0 +1,56 @@ +package dev.coph.nextusweb.server.net; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class ClientIpTest { + + @Test + void usesSocketIpWhenNoForwardedHeader() { + assertEquals("203.0.113.5", + ClientIp.resolve("203.0.113.5", null, TrustedProxies.all())); + } + + @Test + void ignoresForwardedHeaderWhenPeerNotTrusted() { + // A direct (untrusted) client cannot spoof its IP via X-Forwarded-For. + assertEquals("203.0.113.5", + ClientIp.resolve("203.0.113.5", "1.2.3.4", TrustedProxies.none())); + } + + @Test + void usesForwardedHeaderWhenPeerTrusted() { + TrustedProxies trusted = TrustedProxies.of("10.0.0.0/8"); + assertEquals("1.2.3.4", + ClientIp.resolve("10.0.0.1", "1.2.3.4", trusted)); + } + + @Test + void returnsFirstUntrustedHopFromTheRight() { + // Chain: realClient, edgeProxy, internalProxy(=peer). Both proxies are trusted, so the + // resolved client is the first untrusted entry walking from the right. + TrustedProxies trusted = TrustedProxies.of("10.0.0.0/8"); + String xff = "9.9.9.9, 10.0.0.9, 10.0.0.8"; + assertEquals("9.9.9.9", + ClientIp.resolve("10.0.0.8", xff, trusted)); + } + + @Test + void spoofedLeadingEntriesAreIgnored() { + // Attacker prepends a fake hop; since the genuine client hop (8.8.8.8) is the first + // untrusted from the right, the forged "1.1.1.1" is never returned. + TrustedProxies trusted = TrustedProxies.of("10.0.0.0/8"); + String xff = "1.1.1.1, 8.8.8.8, 10.0.0.8"; + assertEquals("8.8.8.8", + ClientIp.resolve("10.0.0.8", xff, trusted)); + } + + @Test + void allHopsTrustedFallsBackToLeftmost() { + TrustedProxies trusted = TrustedProxies.of("10.0.0.0/8"); + String xff = "10.0.0.7, 10.0.0.8"; + assertEquals("10.0.0.7", + ClientIp.resolve("10.0.0.8", xff, trusted)); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/net/TrustedProxiesTest.java b/src/test/java/dev/coph/nextusweb/server/net/TrustedProxiesTest.java new file mode 100644 index 0000000..e0dbbbf --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/net/TrustedProxiesTest.java @@ -0,0 +1,64 @@ +package dev.coph.nextusweb.server.net; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class TrustedProxiesTest { + + @Test + void noneTrustsNothing() { + TrustedProxies tp = TrustedProxies.none(); + assertFalse(tp.isTrusted("127.0.0.1")); + assertFalse(tp.isTrusted("10.0.0.1")); + } + + @Test + void allTrustsEverything() { + TrustedProxies tp = TrustedProxies.all(); + assertTrue(tp.isTrusted("8.8.8.8")); + assertTrue(tp.isTrusted("::1")); + } + + @Test + void matchesIpv4Cidr() { + TrustedProxies tp = TrustedProxies.of("10.0.0.0/8"); + assertTrue(tp.isTrusted("10.1.2.3")); + assertTrue(tp.isTrusted("10.255.255.255")); + assertFalse(tp.isTrusted("11.0.0.1")); + assertFalse(tp.isTrusted("192.168.0.1")); + } + + @Test + void matchesBareHostAsSingleAddress() { + TrustedProxies tp = TrustedProxies.of("127.0.0.1"); + assertTrue(tp.isTrusted("127.0.0.1")); + assertFalse(tp.isTrusted("127.0.0.2")); + } + + @Test + void matchesIpv6Cidr() { + TrustedProxies tp = TrustedProxies.of("fd00::/8"); + assertTrue(tp.isTrusted("fd12:3456::1")); + assertFalse(tp.isTrusted("fe80::1")); + } + + @Test + void differentFamilyDoesNotMatch() { + TrustedProxies tp = TrustedProxies.of("10.0.0.0/8"); + assertFalse(tp.isTrusted("::1")); + } + + @Test + void invalidAddressIsNotTrusted() { + TrustedProxies tp = TrustedProxies.of("10.0.0.0/8"); + assertFalse(tp.isTrusted("not-an-ip")); + assertFalse(tp.isTrusted(null)); + } + + @Test + void rejectsInvalidCidr() { + assertThrows(IllegalArgumentException.class, () -> TrustedProxies.of("10.0.0.0/40")); + assertThrows(IllegalArgumentException.class, () -> TrustedProxies.of("garbage")); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/ratelimit/KeyResolverTest.java b/src/test/java/dev/coph/nextusweb/server/ratelimit/KeyResolverTest.java index df02435..ad46188 100644 --- a/src/test/java/dev/coph/nextusweb/server/ratelimit/KeyResolverTest.java +++ b/src/test/java/dev/coph/nextusweb/server/ratelimit/KeyResolverTest.java @@ -1,53 +1,70 @@ package dev.coph.nextusweb.server.ratelimit; -import io.netty.handler.codec.http.DefaultHttpRequest; +import dev.coph.nextusweb.server.auth.Principal; +import dev.coph.nextusweb.server.router.Request; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpVersion; import org.junit.jupiter.api.Test; +import java.util.Map; +import java.util.Set; + import static org.junit.jupiter.api.Assertions.*; class KeyResolverTest { - private HttpRequest req(String header, String value) { - HttpRequest r = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (header != null) r.headers().set(header, value); - return r; + private Request request() { + FullHttpRequest raw = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); + return new Request(raw, Map.of()); + } + + private Request requestWith(String header, String value) { + FullHttpRequest raw = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); + raw.headers().set(header, value); + return new Request(raw, Map.of()); } @Test - void clientIpUsesRemoteWhenNoForwardedHeader() { - assertEquals("10.0.0.1", KeyResolver.clientIp().resolve(req(null, null), "10.0.0.1")); + void clientIpReturnsResolvedIpVerbatim() { + assertEquals("10.0.0.1", KeyResolver.clientIp().resolve(request(), "10.0.0.1")); } @Test - void clientIpUsesForwardedHeaderFirstValue() { - HttpRequest r = req("X-Forwarded-For", "1.1.1.1, 2.2.2.2"); - assertEquals("1.1.1.1", KeyResolver.clientIp().resolve(r, "10.0.0.1")); + void headerResolverUsesHeaderValue() { + Request r = requestWith("X-API-Key", "secret123"); + assertEquals("h:secret123", KeyResolver.header("X-API-Key").resolve(r, "10.0.0.1")); } @Test - void clientIpHandlesSingleForwardedValue() { - HttpRequest r = req("X-Forwarded-For", "3.3.3.3"); - assertEquals("3.3.3.3", KeyResolver.clientIp().resolve(r, "10.0.0.1")); + void headerResolverFallsBackToClientIp() { + assertEquals("ip:10.0.0.1", KeyResolver.header("X-API-Key").resolve(request(), "10.0.0.1")); } @Test - void userOrIpReturnsBearerToken() { - HttpRequest r = req("Authorization", "Bearer abc123"); - assertEquals("u:abc123", KeyResolver.userOrIp().resolve(r, "10.0.0.1")); + void cookieResolverUsesCookieValue() { + Request r = requestWith("Cookie", "sid=abc; other=x"); + assertEquals("c:abc", KeyResolver.cookie("sid").resolve(r, "10.0.0.1")); } @Test - void userOrIpFallsBackToClientIp() { - HttpRequest r = req(null, null); - assertEquals("ip:10.0.0.1", KeyResolver.userOrIp().resolve(r, "10.0.0.1")); + void cookieResolverFallsBackToClientIp() { + assertEquals("ip:10.0.0.1", KeyResolver.cookie("sid").resolve(request(), "10.0.0.1")); } @Test - void userOrIpIgnoresNonBearerAuth() { - HttpRequest r = req("Authorization", "Basic xyz"); - assertEquals("ip:10.0.0.1", KeyResolver.userOrIp().resolve(r, "10.0.0.1")); + void principalResolverUsesPrincipalId() { + Request r = request(); + r.principal(Principal.of("user-42", Set.of("admin"))); + assertEquals("p:user-42", KeyResolver.principal().resolve(r, "10.0.0.1")); + } + + @Test + void principalResolverFallsBackToClientIpWhenAnonymous() { + assertEquals("ip:10.0.0.1", KeyResolver.principal().resolve(request(), "10.0.0.1")); } } diff --git a/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfigTest.java b/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfigTest.java index f232767..ef825fb 100644 --- a/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfigTest.java +++ b/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitConfigTest.java @@ -75,4 +75,30 @@ class RateLimitConfigTest { assertEquals("global", rules.get(0).name()); assertEquals("/x", rules.get(1).name()); } + + @Test + void allLimitersCollectsEveryDistinctLimiter() { + // Distinct concrete instances (non-capturing lambdas would be the same JVM singleton). + RateLimiter a = new FixedWindowLimiter(1, 1000); + RateLimiter b = new FixedWindowLimiter(1, 1000); + RateLimiter c = new FixedWindowLimiter(1, 1000); + RateLimitConfig cfg = RateLimitConfig.builder() + .global(a, keyer()) + .forPath("/x", b, keyer()) + .forPrefix("/api/", c, keyer()) + .build(); + assertEquals(3, cfg.allLimiters().size()); + assertTrue(cfg.allLimiters().containsAll(List.of(a, b, c))); + } + + @Test + void allLimitersDeduplicatesSharedInstance() { + RateLimiter shared = alwaysAllow(); + RateLimitConfig cfg = RateLimitConfig.builder() + .global(shared, keyer()) + .forPath("/x", shared, keyer()) + .forPrefix("/api/", shared, keyer()) + .build(); + assertEquals(1, cfg.allLimiters().size()); + } } diff --git a/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitGateTest.java b/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitGateTest.java index 28d665e..21b4b48 100644 --- a/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitGateTest.java +++ b/src/test/java/dev/coph/nextusweb/server/ratelimit/RateLimitGateTest.java @@ -1,18 +1,24 @@ package dev.coph.nextusweb.server.ratelimit; +import dev.coph.nextusweb.server.router.Request; import dev.coph.nextusweb.server.router.Response; -import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpVersion; import org.junit.jupiter.api.Test; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.*; class RateLimitGateTest { - private HttpRequest req() { - return new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + private Request req() { + FullHttpRequest raw = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); + return new Request(raw, Map.of()); } @Test diff --git a/src/test/java/dev/coph/nextusweb/server/router/RequestTest.java b/src/test/java/dev/coph/nextusweb/server/router/RequestTest.java index 7883e6a..8c1eced 100644 --- a/src/test/java/dev/coph/nextusweb/server/router/RequestTest.java +++ b/src/test/java/dev/coph/nextusweb/server/router/RequestTest.java @@ -1,5 +1,6 @@ package dev.coph.nextusweb.server.router; +import dev.coph.nextusweb.server.auth.Principal; import dev.coph.nextusweb.server.router.exception.BadRequestException; import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.DefaultFullHttpRequest; @@ -101,4 +102,48 @@ class RequestTest { Request req = new Request(build(HttpMethod.POST, "/", "not-json"), Map.of()); assertThrows(BadRequestException.class, () -> req.jsonAs(Payload.class)); } + + @Test + void cookieParsesNamedCookie() { + FullHttpRequest raw = build(HttpMethod.GET, "/", null); + raw.headers().set("Cookie", "sid=abc123; theme=dark"); + Request req = new Request(raw, Map.of()); + assertEquals("abc123", req.cookie("sid")); + assertEquals("dark", req.cookie("theme")); + assertNull(req.cookie("missing")); + } + + @Test + void cookieReturnsNullWhenNoCookieHeader() { + Request req = new Request(build(HttpMethod.GET, "/", null), Map.of()); + assertNull(req.cookie("sid")); + } + + @Test + void attributesSetGetAndRemove() { + Request req = new Request(build(HttpMethod.GET, "/", null), Map.of()); + assertNull(req.attribute("k")); + req.attribute("k", "v"); + assertEquals("v", req.attribute("k")); + req.attribute("k", null); + assertNull(req.attribute("k")); + } + + @Test + void clientIpRoundTrips() { + Request req = new Request(build(HttpMethod.GET, "/", null), Map.of()); + assertNull(req.clientIp()); + req.clientIp("203.0.113.9"); + assertEquals("203.0.113.9", req.clientIp()); + } + + @Test + void principalRoundTripsAndDrivesIsAuthenticated() { + Request req = new Request(build(HttpMethod.GET, "/", null), Map.of()); + assertFalse(req.isAuthenticated()); + assertNull(req.principal()); + req.principal(Principal.of("user-7")); + assertTrue(req.isAuthenticated()); + assertEquals("user-7", req.principal().id()); + } } diff --git a/src/test/java/dev/coph/nextusweb/server/tls/TlsConfigTest.java b/src/test/java/dev/coph/nextusweb/server/tls/TlsConfigTest.java new file mode 100644 index 0000000..52185fd --- /dev/null +++ b/src/test/java/dev/coph/nextusweb/server/tls/TlsConfigTest.java @@ -0,0 +1,22 @@ +package dev.coph.nextusweb.server.tls; + +import org.junit.jupiter.api.Test; + +import java.io.File; + +import static org.junit.jupiter.api.Assertions.*; + +class TlsConfigTest { + + @Test + void fromPemWithMissingFilesThrowsIllegalState() { + File missingCert = new File("does-not-exist-cert.pem"); + File missingKey = new File("does-not-exist-key.pem"); + assertThrows(IllegalStateException.class, () -> TlsConfig.fromPem(missingCert, missingKey)); + } + + @Test + void fromSslContextRejectsNull() { + assertThrows(NullPointerException.class, () -> TlsConfig.fromSslContext(null)); + } +} diff --git a/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketGroupTest.java b/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketGroupTest.java index 035093e..580abab 100644 --- a/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketGroupTest.java +++ b/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketGroupTest.java @@ -17,7 +17,7 @@ class WebSocketGroupTest { } private WebSocketSession session(EmbeddedChannel ch) { - return new WebSocketSession(ch, "/ws", Map.of()); + return new WebSocketSession(ch, "/ws", Map.of(), null); } @Test diff --git a/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketHandlerTest.java b/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketHandlerTest.java index 18c353f..376eea5 100644 --- a/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketHandlerTest.java +++ b/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketHandlerTest.java @@ -13,7 +13,7 @@ class WebSocketHandlerTest { void defaultMethodsDoNotThrow() { WebSocketHandler handler = new WebSocketHandler() {}; EmbeddedChannel ch = new EmbeddedChannel(); - WebSocketSession session = new WebSocketSession(ch, "/ws", Map.of()); + WebSocketSession session = new WebSocketSession(ch, "/ws", Map.of(), null); assertDoesNotThrow(() -> handler.onOpen(session)); assertDoesNotThrow(() -> handler.onMessage(session, "msg")); assertDoesNotThrow(() -> handler.onBinary(session, new byte[]{1})); diff --git a/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketSessionTest.java b/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketSessionTest.java index a851143..817f561 100644 --- a/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketSessionTest.java +++ b/src/test/java/dev/coph/nextusweb/server/websocket/WebSocketSessionTest.java @@ -15,7 +15,7 @@ import static org.junit.jupiter.api.Assertions.*; class WebSocketSessionTest { private WebSocketSession session(EmbeddedChannel ch) { - return new WebSocketSession(ch, "/ws/{id}", Map.of("id", "42")); + return new WebSocketSession(ch, "/ws/{id}", Map.of("id", "42"), null); } @Test