Files
PenaltyTracker/auth.go
T
2026-05-16 20:39:27 +02:00

257 lines
6.7 KiB
Go

package main
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
// ctxKey is a private type for request-context keys; using a distinct type
// keeps this file's keys from colliding with keys set by other packages.
type ctxKey string

const (
	// userCtxKey is the context key under which requireAuth stores the
	// authenticated *User.
	userCtxKey ctxKey = "user"
	// sessionCookie is the name of the cookie that carries the session token.
	sessionCookie = "pt_session"
	// sessionDuration is how long a freshly created session remains valid
	// (14 days).
	sessionDuration = 14 * 24 * time.Hour
)
// registerAuthRoutes mounts the authentication endpoints on mux. Login and
// logout are public; the /api/me endpoints are wrapped in requireAuth.
func registerAuthRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"POST /api/login", handleLogin},
		{"POST /api/logout", handleLogout},
		{"GET /api/me", requireAuth(handleMe)},
		{"PATCH /api/me", requireAuth(handleUpdateMe)},
	}
	for _, rt := range routes {
		mux.HandleFunc(rt.pattern, rt.handler)
	}
}
// ensureDefaultAdmin seeds the users table with a system administrator when
// the table is empty. If at least one user already exists it does nothing.
// It returns any error from counting users, hashing, or inserting.
//
// NOTE(review): the seeded account is "admin" with the password "admin";
// confirm that operators are told to change it after first start.
func ensureDefaultAdmin() error {
	var existing int
	err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&existing)
	switch {
	case err != nil:
		return err
	case existing > 0:
		return nil
	}

	hash, err := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	const insertAdmin = "INSERT INTO users(username,password_hash,display_name,language,is_system_admin) VALUES(?,?,?,?,1)"
	_, err = db.Exec(insertAdmin, "admin", string(hash), "System Admin", "en")
	return err
}
func newToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
func handleLogin(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_body")
return
}
req.Username = strings.TrimSpace(req.Username)
if req.Username == "" || req.Password == "" {
writeError(w, http.StatusBadRequest, "missing_credentials")
return
}
var id int64
var hash string
err := db.QueryRow("SELECT id,password_hash FROM users WHERE username=?", req.Username).Scan(&id, &hash)
if err != nil {
writeError(w, http.StatusUnauthorized, "invalid_credentials")
return
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)); err != nil {
writeError(w, http.StatusUnauthorized, "invalid_credentials")
return
}
token, err := newToken()
if err != nil {
writeError(w, http.StatusInternalServerError, "token_error")
return
}
expires := time.Now().Add(sessionDuration)
if _, err := db.Exec("INSERT INTO sessions(token,user_id,expires_at) VALUES(?,?,?)", token, id, expires.Format(time.RFC3339)); err != nil {
writeError(w, http.StatusInternalServerError, "session_error")
return
}
cookie := &http.Cookie{
Name: sessionCookie,
Value: token,
Path: "/",
Expires: expires,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
if crossSiteCookies {
cookie.SameSite = http.SameSiteNoneMode
cookie.Secure = true
}
http.SetCookie(w, cookie)
user, _ := loadUser(id)
writeJSON(w, http.StatusOK, user)
}
func handleLogout(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie(sessionCookie)
if err == nil {
db.Exec("DELETE FROM sessions WHERE token=?", c.Value)
}
clear := &http.Cookie{
Name: sessionCookie,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
if crossSiteCookies {
clear.SameSite = http.SameSiteNoneMode
clear.Secure = true
}
http.SetCookie(w, clear)
w.WriteHeader(http.StatusNoContent)
}
// handleMe responds 200 with the authenticated user that requireAuth placed
// in the request context.
func handleMe(w http.ResponseWriter, r *http.Request) {
	writeJSON(w, http.StatusOK, userFromCtx(r))
}
// handleUpdateMe applies a partial profile update for the authenticated user
// and responds 200 with the refreshed profile. Every JSON field is optional;
// absent or null fields are left unchanged:
//   - username:     trimmed and must be non-empty (400 missing_username);
//     a failed UPDATE is reported as 409 username_taken
//   - language:     stored as given
//   - display_name: stored as given
//   - password:     re-hashed with bcrypt; an empty string is ignored
//
// All column updates run in a single transaction. A failure part-way
// through therefore leaves the stored profile unchanged rather than
// half-updated.
func handleUpdateMe(w http.ResponseWriter, r *http.Request) {
	u := userFromCtx(r)
	r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

	var req struct {
		Username    *string `json:"username"`
		Language    *string `json:"language"`
		DisplayName *string `json:"display_name"`
		Password    *string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_body")
		return
	}

	// Validate and prepare everything before touching the database.
	var newName string
	if req.Username != nil {
		newName = strings.TrimSpace(*req.Username)
		if newName == "" {
			writeError(w, http.StatusBadRequest, "missing_username")
			return
		}
	}
	// Hash outside the transaction: bcrypt is deliberately slow and the write
	// transaction should not be held open while it runs.
	var newHash string
	if req.Password != nil && *req.Password != "" {
		h, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "hash_error")
			return
		}
		newHash = string(h)
	}

	tx, err := db.Begin()
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error")
		return
	}
	// Undo any partial update on an early return. After a successful Commit,
	// Rollback does nothing and its error is irrelevant.
	defer func() { _ = tx.Rollback() }()

	if req.Username != nil {
		// NOTE(review): any UPDATE failure is reported as username_taken; the
		// driver's unique-constraint error is not distinguished from others.
		if _, err := tx.Exec("UPDATE users SET username=? WHERE id=?", newName, u.ID); err != nil {
			writeError(w, http.StatusConflict, "username_taken")
			return
		}
	}
	if req.Language != nil {
		if _, err := tx.Exec("UPDATE users SET language=? WHERE id=?", *req.Language, u.ID); err != nil {
			writeError(w, http.StatusInternalServerError, "db_error")
			return
		}
	}
	if req.DisplayName != nil {
		if _, err := tx.Exec("UPDATE users SET display_name=? WHERE id=?", *req.DisplayName, u.ID); err != nil {
			writeError(w, http.StatusInternalServerError, "db_error")
			return
		}
	}
	if newHash != "" {
		// NOTE(review): existing sessions for this user are not revoked when
		// the password changes.
		if _, err := tx.Exec("UPDATE users SET password_hash=? WHERE id=?", newHash, u.ID); err != nil {
			writeError(w, http.StatusInternalServerError, "db_error")
			return
		}
	}
	if err := tx.Commit(); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error")
		return
	}

	user, err := loadUser(u.ID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error")
		return
	}
	writeJSON(w, http.StatusOK, user)
}
// loadUser fetches the user with the given id. The integer is_system_admin
// column is mapped to IsSystemAdmin, which is true only when the stored
// value is exactly 1. It returns nil and the scan/query error if the row
// cannot be read, including when no such user exists.
func loadUser(id int64) (*User, error) {
	const query = "SELECT id,username,display_name,language,is_system_admin FROM users WHERE id=?"
	var (
		user      User
		adminFlag int
	)
	row := db.QueryRow(query, id)
	if err := row.Scan(&user.ID, &user.Username, &user.DisplayName, &user.Language, &adminFlag); err != nil {
		return nil, err
	}
	user.IsSystemAdmin = adminFlag == 1
	return &user, nil
}
// authUser resolves the request's session cookie to a user.
//
// It returns an error when:
//   - the cookie is absent ("no_session")
//   - the token has no session row ("invalid_session")
//   - the stored expiry cannot be parsed as RFC 3339 ("invalid_session")
//   - the session has expired ("expired")
//
// It also returns any error from loadUser. The malformed and expired cases
// delete the session row on a best-effort basis.
func authUser(r *http.Request) (*User, error) {
	c, err := r.Cookie(sessionCookie)
	if err != nil {
		return nil, errors.New("no_session")
	}

	var userID int64
	var expiresAt string
	err = db.QueryRow("SELECT user_id, expires_at FROM sessions WHERE token=?", c.Value).Scan(&userID, &expiresAt)
	if err != nil {
		return nil, errors.New("invalid_session")
	}

	// Fail closed: a session whose expiry cannot be read must not be treated
	// as never-expiring. handleLogin writes expires_at with time.RFC3339.
	expiry, err := time.Parse(time.RFC3339, expiresAt)
	switch {
	case err != nil:
		_, _ = db.Exec("DELETE FROM sessions WHERE token=?", c.Value) // best-effort cleanup
		return nil, errors.New("invalid_session")
	case time.Now().After(expiry):
		_, _ = db.Exec("DELETE FROM sessions WHERE token=?", c.Value) // best-effort cleanup
		return nil, errors.New("expired")
	}
	return loadUser(userID)
}
// requireAuth wraps h so it only runs for requests with a valid session.
// Unauthenticated requests receive 401 "unauthorized". Authenticated
// requests have their *User stored in the context under userCtxKey, where
// userFromCtx can retrieve it.
func requireAuth(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		user, err := authUser(r)
		if err == nil {
			h(w, r.WithContext(context.WithValue(r.Context(), userCtxKey, user)))
			return
		}
		writeError(w, http.StatusUnauthorized, "unauthorized")
	}
}
// requireAdmin wraps h so it only runs for authenticated system admins.
// Authentication is checked first via requireAuth, which responds 401.
// Authenticated non-admins then receive 403 "forbidden".
func requireAdmin(h http.HandlerFunc) http.HandlerFunc {
	adminOnly := func(w http.ResponseWriter, r *http.Request) {
		if u := userFromCtx(r); !u.IsSystemAdmin {
			writeError(w, http.StatusForbidden, "forbidden")
			return
		}
		h(w, r)
	}
	return requireAuth(adminOnly)
}
// userFromCtx returns the *User that requireAuth stored in the request
// context, or nil when no user was stored (the request did not pass
// through requireAuth).
func userFromCtx(r *http.Request) *User {
	if v := r.Context().Value(userCtxKey); v != nil {
		return v.(*User)
	}
	return nil
}
// writeJSON sends code with a Content-Type of application/json and, unless
// body is a nil interface, the JSON encoding of body followed by a newline.
func writeJSON(w http.ResponseWriter, code int, body any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if body == nil {
		return
	}
	// The status line has already been sent, so an encoding or write failure
	// can no longer be reported to the client; it is deliberately dropped.
	enc := json.NewEncoder(w)
	_ = enc.Encode(body)
}
// writeError sends code with the JSON body {"error": msg}.
func writeError(w http.ResponseWriter, code int, msg string) {
	payload := map[string]string{"error": msg}
	writeJSON(w, code, payload)
}