389 lines
10 KiB
Go
389 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// ctxKey is the private type for request-context keys set by this file, so
// they cannot collide with context keys defined in other packages.
type ctxKey string

const (
	// userCtxKey is the context key under which requireAuth stores the
	// authenticated *User for handlers to read via userFromCtx.
	userCtxKey ctxKey = "user"

	// sessionCookie is the name of the cookie that carries the session token.
	sessionCookie = "pt_session"
	// sessionDuration is how long a newly created session stays valid (14 days).
	sessionDuration = 14 * 24 * time.Hour

	// Input length limits. The username and password limits are compared
	// against len(), i.e. they count bytes, not characters.
	maxUsernameLen    = 64
	maxDisplayNameLen = 128
	maxPasswordLen    = 256
	minPasswordLen    = 8
	maxFieldLen       = 2000
)
|
|
|
|
// validatePassword returns "" if the password meets the policy, or a short
|
|
// error code suitable for the JSON `error` field if it doesn't.
|
|
func validatePassword(p string) string {
|
|
if len(p) < minPasswordLen {
|
|
return "password_too_short"
|
|
}
|
|
if len(p) > maxPasswordLen {
|
|
return "too_long"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func registerAuthRoutes(mux *http.ServeMux) {
|
|
mux.HandleFunc("POST /api/login", loginRateLimit(handleLogin))
|
|
mux.HandleFunc("POST /api/logout", handleLogout)
|
|
mux.HandleFunc("GET /api/me", requireAuth(handleMe))
|
|
mux.HandleFunc("PATCH /api/me", requireAuth(handleUpdateMe))
|
|
}
|
|
|
|
func ensureDefaultAdmin() error {
|
|
var count int
|
|
if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil {
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
return nil
|
|
}
|
|
hash, err := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = db.Exec(
|
|
"INSERT INTO users(username,password_hash,display_name,language,is_system_admin,must_change_password) VALUES(?,?,?,?,1,1)",
|
|
"admin", string(hash), "System Admin", "en",
|
|
)
|
|
return err
|
|
}
|
|
|
|
// newToken returns a fresh session token: 32 bytes read from crypto/rand,
// hex-encoded into a 64-character string. The only error it returns is one
// from the random source.
func newToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
// ---- login rate limiter ---------------------------------------------------
|
|
|
|
// loginAttempts holds the timestamps of the failed logins recorded for one
// client key, oldest first (loginRecord appends in time order).
type loginAttempts struct {
	times []time.Time
}

// Limiter state, keyed by the string clientIP returns. loginMu guards the
// map and every *loginAttempts stored in it.
var (
	loginMu        sync.Mutex
	loginAttempts_ = make(map[string]*loginAttempts)
)

// A client may accumulate at most loginMaxAttempts recorded failures within
// any sliding loginWindow before further login requests are refused.
const (
	loginMaxAttempts = 4
	loginWindow      = 5 * time.Minute
)
|
|
|
|
// clientIPIsInternal reports whether ip is a loopback or private (RFC 1918 /
// RFC 4193) address, i.e. one on this host or its local network rather than
// an arbitrary Internet client.
func clientIPIsInternal(ip net.IP) bool {
	return ip.IsLoopback() || ip.IsPrivate()
}

// clientIP returns the address used to key the login rate limiter.
//
// X-Forwarded-For is honoured only when the direct peer (r.RemoteAddr) is an
// internal address, i.e. the request plausibly arrived through a reverse
// proxy on this host or network. Otherwise any client could pick its own
// rate-limit key just by sending the header.
//
// Header entries are examined from the right, because each proxy appends the
// address it saw, while everything to the left was supplied by the client.
// The right-most non-internal entry wins. If every entry is internal, the
// left-most valid one is used. Unparseable entries are skipped, and the peer
// address is the final fallback. All X-Forwarded-For header lines are
// considered, not just the first, which the client may have injected.
//
// NOTE(review): if the app sits behind a proxy whose own address is public,
// clientIPIsInternal must be extended to trust it. Otherwise every client
// will share that proxy's key.
func clientIP(r *http.Request) string {
	peer, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		peer = r.RemoteAddr
	}
	peerIP := net.ParseIP(peer)
	if peerIP == nil || !clientIPIsInternal(peerIP) {
		return peer
	}

	xf := strings.Join(r.Header.Values("X-Forwarded-For"), ",")
	if strings.TrimSpace(xf) == "" {
		return peer
	}

	parts := strings.Split(xf, ",")
	leftmostValid := ""
	for i := len(parts) - 1; i >= 0; i-- {
		ip := net.ParseIP(strings.TrimSpace(parts[i]))
		if ip == nil {
			continue
		}
		if !clientIPIsInternal(ip) {
			return ip.String()
		}
		// Still inside the proxy chain; remember it in case no public
		// address is found (ends up as the left-most valid entry).
		leftmostValid = ip.String()
	}
	if leftmostValid != "" {
		return leftmostValid
	}
	return peer
}
|
|
|
|
// loginStatus returns (remaining, retryAfterSeconds). When remaining == 0 the
|
|
// caller must reject the request and use retryAfterSeconds for headers.
|
|
func loginStatus(ip string) (int, int) {
|
|
loginMu.Lock()
|
|
defer loginMu.Unlock()
|
|
a, ok := loginAttempts_[ip]
|
|
if !ok {
|
|
a = &loginAttempts{}
|
|
loginAttempts_[ip] = a
|
|
}
|
|
cutoff := time.Now().Add(-loginWindow)
|
|
kept := a.times[:0]
|
|
for _, t := range a.times {
|
|
if t.After(cutoff) {
|
|
kept = append(kept, t)
|
|
}
|
|
}
|
|
a.times = kept
|
|
remaining := loginMaxAttempts - len(a.times)
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
retryAfter := 0
|
|
if remaining == 0 && len(a.times) > 0 {
|
|
// Time until the oldest attempt falls out of the window.
|
|
next := a.times[0].Add(loginWindow).Sub(time.Now())
|
|
retryAfter = int(next.Seconds()) + 1
|
|
if retryAfter < 1 {
|
|
retryAfter = 1
|
|
}
|
|
}
|
|
return remaining, retryAfter
|
|
}
|
|
|
|
func loginRecord(ip string) {
|
|
loginMu.Lock()
|
|
defer loginMu.Unlock()
|
|
a, ok := loginAttempts_[ip]
|
|
if !ok {
|
|
a = &loginAttempts{}
|
|
loginAttempts_[ip] = a
|
|
}
|
|
a.times = append(a.times, time.Now())
|
|
}
|
|
|
|
func loginRateLimit(h http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ip := clientIP(r)
|
|
remaining, retryAfter := loginStatus(ip)
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(loginMaxAttempts))
|
|
w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
|
if remaining == 0 {
|
|
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
|
w.Header().Set("X-RateLimit-Reset", strconv.Itoa(retryAfter))
|
|
writeError(w, http.StatusTooManyRequests, "too_many_attempts")
|
|
return
|
|
}
|
|
h.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
// ---- handlers --------------------------------------------------------------
|
|
|
|
// normalizeUsername canonicalises a username for storage and lookup: it
// strips leading/trailing whitespace and lower-cases the rest.
func normalizeUsername(s string) string {
	trimmed := strings.TrimSpace(s)
	return strings.ToLower(trimmed)
}
|
|
|
|
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "invalid_body")
|
|
return
|
|
}
|
|
req.Username = normalizeUsername(req.Username)
|
|
if req.Username == "" || req.Password == "" {
|
|
writeError(w, http.StatusBadRequest, "missing_credentials")
|
|
return
|
|
}
|
|
if len(req.Username) > maxUsernameLen || len(req.Password) > maxPasswordLen {
|
|
writeError(w, http.StatusBadRequest, "too_long")
|
|
return
|
|
}
|
|
var id int64
|
|
var hash string
|
|
err := db.QueryRow("SELECT id,password_hash FROM users WHERE username=?", req.Username).Scan(&id, &hash)
|
|
if err != nil {
|
|
loginRecord(clientIP(r))
|
|
writeError(w, http.StatusUnauthorized, "invalid_credentials")
|
|
return
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)); err != nil {
|
|
loginRecord(clientIP(r))
|
|
writeError(w, http.StatusUnauthorized, "invalid_credentials")
|
|
return
|
|
}
|
|
token, err := newToken()
|
|
if err != nil {
|
|
writeError(w, http.StatusInternalServerError, "token_error")
|
|
return
|
|
}
|
|
expires := time.Now().Add(sessionDuration)
|
|
if _, err := db.Exec("INSERT INTO sessions(token,user_id,expires_at) VALUES(?,?,?)", token, id, expires.Format(time.RFC3339)); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "session_error")
|
|
return
|
|
}
|
|
cookie := &http.Cookie{
|
|
Name: sessionCookie,
|
|
Value: token,
|
|
Path: "/",
|
|
Expires: expires,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Secure: r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https"),
|
|
}
|
|
if crossSiteCookies {
|
|
cookie.SameSite = http.SameSiteNoneMode
|
|
cookie.Secure = true
|
|
}
|
|
http.SetCookie(w, cookie)
|
|
user, _ := loadUser(id)
|
|
writeJSON(w, http.StatusOK, user)
|
|
}
|
|
|
|
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
c, err := r.Cookie(sessionCookie)
|
|
if err == nil {
|
|
db.Exec("DELETE FROM sessions WHERE token=?", c.Value)
|
|
}
|
|
clear := &http.Cookie{
|
|
Name: sessionCookie,
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Secure: r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https"),
|
|
}
|
|
if crossSiteCookies {
|
|
clear.SameSite = http.SameSiteNoneMode
|
|
clear.Secure = true
|
|
}
|
|
http.SetCookie(w, clear)
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func handleMe(w http.ResponseWriter, r *http.Request) {
|
|
u := userFromCtx(r)
|
|
writeJSON(w, http.StatusOK, u)
|
|
}
|
|
|
|
// handleUpdateMe lets the logged-in user change ONLY their language and password.
|
|
// Username and display name are administrative attributes and can only be set
|
|
// through the admin user endpoints.
|
|
func handleUpdateMe(w http.ResponseWriter, r *http.Request) {
|
|
u := userFromCtx(r)
|
|
var req struct {
|
|
Language *string `json:"language"`
|
|
Password *string `json:"password"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "invalid_body")
|
|
return
|
|
}
|
|
if req.Language != nil {
|
|
lang := strings.TrimSpace(*req.Language)
|
|
if len(lang) > 8 {
|
|
writeError(w, http.StatusBadRequest, "invalid_language")
|
|
return
|
|
}
|
|
if _, err := db.Exec("UPDATE users SET language=? WHERE id=?", lang, u.ID); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "db_error")
|
|
return
|
|
}
|
|
}
|
|
if req.Password != nil && *req.Password != "" {
|
|
if msg := validatePassword(*req.Password); msg != "" {
|
|
writeError(w, http.StatusBadRequest, msg)
|
|
return
|
|
}
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
writeError(w, http.StatusInternalServerError, "hash_error")
|
|
return
|
|
}
|
|
if _, err := db.Exec("UPDATE users SET password_hash=?,must_change_password=0 WHERE id=?", string(hash), u.ID); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "db_error")
|
|
return
|
|
}
|
|
}
|
|
user, _ := loadUser(u.ID)
|
|
writeJSON(w, http.StatusOK, user)
|
|
}
|
|
|
|
func loadUser(id int64) (*User, error) {
|
|
u := &User{}
|
|
var admin, mustChange int
|
|
err := db.QueryRow("SELECT id,username,display_name,language,is_system_admin,must_change_password FROM users WHERE id=?", id).
|
|
Scan(&u.ID, &u.Username, &u.DisplayName, &u.Language, &admin, &mustChange)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
u.IsSystemAdmin = admin == 1
|
|
u.MustChangePassword = mustChange == 1
|
|
return u, nil
|
|
}
|
|
|
|
func authUser(r *http.Request) (*User, error) {
|
|
c, err := r.Cookie(sessionCookie)
|
|
if err != nil {
|
|
return nil, errors.New("no_session")
|
|
}
|
|
var userID int64
|
|
var expiresAt string
|
|
err = db.QueryRow("SELECT user_id, expires_at FROM sessions WHERE token=?", c.Value).Scan(&userID, &expiresAt)
|
|
if err != nil {
|
|
return nil, errors.New("invalid_session")
|
|
}
|
|
if t, err := time.Parse(time.RFC3339, expiresAt); err == nil && time.Now().After(t) {
|
|
db.Exec("DELETE FROM sessions WHERE token=?", c.Value)
|
|
return nil, errors.New("expired")
|
|
}
|
|
return loadUser(userID)
|
|
}
|
|
|
|
// passwordChangeExempt reports whether r targets one of the endpoints a user
// must still be able to reach while locked into "must change password":
// /api/me (to read the profile and set a new password) and /api/logout.
// The match is on the exact URL path, for any method.
func passwordChangeExempt(r *http.Request) bool {
	switch r.URL.Path {
	case "/api/me", "/api/logout":
		return true
	}
	return false
}
|
|
|
|
func requireAuth(h http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
u, err := authUser(r)
|
|
if err != nil {
|
|
writeError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
if u.MustChangePassword && !passwordChangeExempt(r) {
|
|
writeError(w, http.StatusForbidden, "password_change_required")
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), userCtxKey, u)
|
|
h.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func requireAdmin(h http.HandlerFunc) http.HandlerFunc {
|
|
return requireAuth(func(w http.ResponseWriter, r *http.Request) {
|
|
u := userFromCtx(r)
|
|
if !u.IsSystemAdmin {
|
|
writeError(w, http.StatusForbidden, "forbidden")
|
|
return
|
|
}
|
|
h.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func userFromCtx(r *http.Request) *User {
|
|
v := r.Context().Value(userCtxKey)
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
return v.(*User)
|
|
}
|
|
|
|
// writeJSON sends code with Content-Type application/json and, when body is
// non-nil, body encoded as JSON (json.Encoder appends a trailing newline).
func writeJSON(w http.ResponseWriter, code int, body any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if body == nil {
		return
	}
	// The status line has already been sent, so an encoding failure can no
	// longer be reported to the client; the error is dropped.
	_ = json.NewEncoder(w).Encode(body)
}
|
|
|
|
func writeError(w http.ResponseWriter, code int, msg string) {
|
|
writeJSON(w, code, map[string]string{"error": msg})
|
|
}
|