257 lines
6.7 KiB
Go
257 lines
6.7 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"net/http"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)
|
|
|
|
// ctxKey is an unexported key type for request-context values, so keys set
// here cannot collide with string keys set by other packages.
type ctxKey string

const (
	// userCtxKey is the context key under which requireAuth stores the
	// authenticated *User.
	userCtxKey ctxKey = "user"

	// sessionCookie is the name of the cookie that carries the session token.
	sessionCookie = "pt_session"

	// sessionDuration is the lifetime given to a session when it is created
	// at login: fourteen days.
	sessionDuration = 14 * 24 * time.Hour
)
|
|
|
|
func registerAuthRoutes(mux *http.ServeMux) {
|
|
mux.HandleFunc("POST /api/login", handleLogin)
|
|
mux.HandleFunc("POST /api/logout", handleLogout)
|
|
mux.HandleFunc("GET /api/me", requireAuth(handleMe))
|
|
mux.HandleFunc("PATCH /api/me", requireAuth(handleUpdateMe))
|
|
}
|
|
|
|
func ensureDefaultAdmin() error {
|
|
var count int
|
|
if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil {
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
return nil
|
|
}
|
|
hash, err := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = db.Exec(
|
|
"INSERT INTO users(username,password_hash,display_name,language,is_system_admin) VALUES(?,?,?,?,1)",
|
|
"admin", string(hash), "System Admin", "en",
|
|
)
|
|
return err
|
|
}
|
|
|
|
// newToken returns a new session token: 32 bytes from crypto/rand,
// hex-encoded into a 64-character lowercase string.
// It returns the error from the random source if reading fails.
func newToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "invalid_body")
|
|
return
|
|
}
|
|
req.Username = strings.TrimSpace(req.Username)
|
|
if req.Username == "" || req.Password == "" {
|
|
writeError(w, http.StatusBadRequest, "missing_credentials")
|
|
return
|
|
}
|
|
var id int64
|
|
var hash string
|
|
err := db.QueryRow("SELECT id,password_hash FROM users WHERE username=?", req.Username).Scan(&id, &hash)
|
|
if err != nil {
|
|
writeError(w, http.StatusUnauthorized, "invalid_credentials")
|
|
return
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)); err != nil {
|
|
writeError(w, http.StatusUnauthorized, "invalid_credentials")
|
|
return
|
|
}
|
|
token, err := newToken()
|
|
if err != nil {
|
|
writeError(w, http.StatusInternalServerError, "token_error")
|
|
return
|
|
}
|
|
expires := time.Now().Add(sessionDuration)
|
|
if _, err := db.Exec("INSERT INTO sessions(token,user_id,expires_at) VALUES(?,?,?)", token, id, expires.Format(time.RFC3339)); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "session_error")
|
|
return
|
|
}
|
|
cookie := &http.Cookie{
|
|
Name: sessionCookie,
|
|
Value: token,
|
|
Path: "/",
|
|
Expires: expires,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
if crossSiteCookies {
|
|
cookie.SameSite = http.SameSiteNoneMode
|
|
cookie.Secure = true
|
|
}
|
|
http.SetCookie(w, cookie)
|
|
user, _ := loadUser(id)
|
|
writeJSON(w, http.StatusOK, user)
|
|
}
|
|
|
|
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
c, err := r.Cookie(sessionCookie)
|
|
if err == nil {
|
|
db.Exec("DELETE FROM sessions WHERE token=?", c.Value)
|
|
}
|
|
clear := &http.Cookie{
|
|
Name: sessionCookie,
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
if crossSiteCookies {
|
|
clear.SameSite = http.SameSiteNoneMode
|
|
clear.Secure = true
|
|
}
|
|
http.SetCookie(w, clear)
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func handleMe(w http.ResponseWriter, r *http.Request) {
|
|
u := userFromCtx(r)
|
|
writeJSON(w, http.StatusOK, u)
|
|
}
|
|
|
|
func handleUpdateMe(w http.ResponseWriter, r *http.Request) {
|
|
u := userFromCtx(r)
|
|
var req struct {
|
|
Username *string `json:"username"`
|
|
Language *string `json:"language"`
|
|
DisplayName *string `json:"display_name"`
|
|
Password *string `json:"password"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "invalid_body")
|
|
return
|
|
}
|
|
if req.Username != nil {
|
|
newName := strings.TrimSpace(*req.Username)
|
|
if newName == "" {
|
|
writeError(w, http.StatusBadRequest, "missing_username")
|
|
return
|
|
}
|
|
if _, err := db.Exec("UPDATE users SET username=? WHERE id=?", newName, u.ID); err != nil {
|
|
writeError(w, http.StatusConflict, "username_taken")
|
|
return
|
|
}
|
|
}
|
|
if req.Language != nil {
|
|
if _, err := db.Exec("UPDATE users SET language=? WHERE id=?", *req.Language, u.ID); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "db_error")
|
|
return
|
|
}
|
|
}
|
|
if req.DisplayName != nil {
|
|
if _, err := db.Exec("UPDATE users SET display_name=? WHERE id=?", *req.DisplayName, u.ID); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "db_error")
|
|
return
|
|
}
|
|
}
|
|
if req.Password != nil && *req.Password != "" {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
writeError(w, http.StatusInternalServerError, "hash_error")
|
|
return
|
|
}
|
|
if _, err := db.Exec("UPDATE users SET password_hash=? WHERE id=?", string(hash), u.ID); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "db_error")
|
|
return
|
|
}
|
|
}
|
|
user, _ := loadUser(u.ID)
|
|
writeJSON(w, http.StatusOK, user)
|
|
}
|
|
|
|
func loadUser(id int64) (*User, error) {
|
|
u := &User{}
|
|
var admin int
|
|
err := db.QueryRow("SELECT id,username,display_name,language,is_system_admin FROM users WHERE id=?", id).
|
|
Scan(&u.ID, &u.Username, &u.DisplayName, &u.Language, &admin)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
u.IsSystemAdmin = admin == 1
|
|
return u, nil
|
|
}
|
|
|
|
func authUser(r *http.Request) (*User, error) {
|
|
c, err := r.Cookie(sessionCookie)
|
|
if err != nil {
|
|
return nil, errors.New("no_session")
|
|
}
|
|
var userID int64
|
|
var expiresAt string
|
|
err = db.QueryRow("SELECT user_id, expires_at FROM sessions WHERE token=?", c.Value).Scan(&userID, &expiresAt)
|
|
if err != nil {
|
|
return nil, errors.New("invalid_session")
|
|
}
|
|
if t, err := time.Parse(time.RFC3339, expiresAt); err == nil && time.Now().After(t) {
|
|
db.Exec("DELETE FROM sessions WHERE token=?", c.Value)
|
|
return nil, errors.New("expired")
|
|
}
|
|
return loadUser(userID)
|
|
}
|
|
|
|
func requireAuth(h http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
u, err := authUser(r)
|
|
if err != nil {
|
|
writeError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), userCtxKey, u)
|
|
h.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func requireAdmin(h http.HandlerFunc) http.HandlerFunc {
|
|
return requireAuth(func(w http.ResponseWriter, r *http.Request) {
|
|
u := userFromCtx(r)
|
|
if !u.IsSystemAdmin {
|
|
writeError(w, http.StatusForbidden, "forbidden")
|
|
return
|
|
}
|
|
h.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func userFromCtx(r *http.Request) *User {
|
|
v := r.Context().Value(userCtxKey)
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
return v.(*User)
|
|
}
|
|
|
|
// writeJSON sends an application/json response with the given status code.
// A nil body produces a header-only response. Otherwise the body is
// JSON-encoded, followed by the encoder's trailing newline.
func writeJSON(w http.ResponseWriter, code int, body any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if body == nil {
		return
	}
	// The status line is already sent, so an encoding failure cannot be
	// reported to the client; it is deliberately discarded.
	_ = json.NewEncoder(w).Encode(body)
}
|
|
|
|
func writeError(w http.ResponseWriter, code int, msg string) {
|
|
writeJSON(w, code, map[string]string{"error": msg})
|
|
}
|