package main import ( "context" "crypto/rand" "encoding/hex" "encoding/json" "errors" "net/http" "strings" "time" "golang.org/x/crypto/bcrypt" ) type ctxKey string const userCtxKey ctxKey = "user" const sessionCookie = "pt_session" const sessionDuration = 24 * time.Hour * 14 func registerAuthRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /api/login", handleLogin) mux.HandleFunc("POST /api/logout", handleLogout) mux.HandleFunc("GET /api/me", requireAuth(handleMe)) mux.HandleFunc("PATCH /api/me", requireAuth(handleUpdateMe)) } func ensureDefaultAdmin() error { var count int if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil { return err } if count > 0 { return nil } hash, err := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost) if err != nil { return err } _, err = db.Exec( "INSERT INTO users(username,password_hash,display_name,language,is_system_admin) VALUES(?,?,?,?,1)", "admin", string(hash), "System Admin", "en", ) return err } func newToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } return hex.EncodeToString(b), nil } func handleLogin(w http.ResponseWriter, r *http.Request) { var req struct { Username string `json:"username"` Password string `json:"password"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "invalid_body") return } req.Username = strings.TrimSpace(req.Username) if req.Username == "" || req.Password == "" { writeError(w, http.StatusBadRequest, "missing_credentials") return } var id int64 var hash string err := db.QueryRow("SELECT id,password_hash FROM users WHERE username=?", req.Username).Scan(&id, &hash) if err != nil { writeError(w, http.StatusUnauthorized, "invalid_credentials") return } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)); err != nil { writeError(w, http.StatusUnauthorized, "invalid_credentials") return } token, err := newToken() if err != nil { writeError(w, http.StatusInternalServerError, "token_error") return } expires := time.Now().Add(sessionDuration) if _, err := db.Exec("INSERT INTO sessions(token,user_id,expires_at) VALUES(?,?,?)", token, id, expires.Format(time.RFC3339)); err != nil { writeError(w, http.StatusInternalServerError, "session_error") return } cookie := &http.Cookie{ Name: sessionCookie, Value: token, Path: "/", Expires: expires, HttpOnly: true, SameSite: http.SameSiteLaxMode, } if crossSiteCookies { cookie.SameSite = http.SameSiteNoneMode cookie.Secure = true } http.SetCookie(w, cookie) user, _ := loadUser(id) writeJSON(w, http.StatusOK, user) } func handleLogout(w http.ResponseWriter, r *http.Request) { c, err := r.Cookie(sessionCookie) if err == nil { db.Exec("DELETE FROM sessions WHERE token=?", c.Value) } clear := &http.Cookie{ Name: sessionCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: true, SameSite: http.SameSiteLaxMode, } if crossSiteCookies { clear.SameSite = http.SameSiteNoneMode clear.Secure = true } http.SetCookie(w, clear) w.WriteHeader(http.StatusNoContent) } func handleMe(w http.ResponseWriter, r *http.Request) { u := userFromCtx(r) writeJSON(w, http.StatusOK, u) } func handleUpdateMe(w http.ResponseWriter, r *http.Request) { u := userFromCtx(r) var req struct { Username *string `json:"username"` Language *string `json:"language"` DisplayName *string `json:"display_name"` Password *string `json:"password"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "invalid_body") return } if req.Username != nil { newName := strings.TrimSpace(*req.Username) if newName == "" { writeError(w, http.StatusBadRequest, "missing_username") return } if _, err := db.Exec("UPDATE users SET username=? WHERE id=?", newName, u.ID); err != nil { writeError(w, http.StatusConflict, "username_taken") return } } if req.Language != nil { if _, err := db.Exec("UPDATE users SET language=? WHERE id=?", *req.Language, u.ID); err != nil { writeError(w, http.StatusInternalServerError, "db_error") return } } if req.DisplayName != nil { if _, err := db.Exec("UPDATE users SET display_name=? WHERE id=?", *req.DisplayName, u.ID); err != nil { writeError(w, http.StatusInternalServerError, "db_error") return } } if req.Password != nil && *req.Password != "" { hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost) if err != nil { writeError(w, http.StatusInternalServerError, "hash_error") return } if _, err := db.Exec("UPDATE users SET password_hash=? WHERE id=?", string(hash), u.ID); err != nil { writeError(w, http.StatusInternalServerError, "db_error") return } } user, _ := loadUser(u.ID) writeJSON(w, http.StatusOK, user) } func loadUser(id int64) (*User, error) { u := &User{} var admin int err := db.QueryRow("SELECT id,username,display_name,language,is_system_admin FROM users WHERE id=?", id). Scan(&u.ID, &u.Username, &u.DisplayName, &u.Language, &admin) if err != nil { return nil, err } u.IsSystemAdmin = admin == 1 return u, nil } func authUser(r *http.Request) (*User, error) { c, err := r.Cookie(sessionCookie) if err != nil { return nil, errors.New("no_session") } var userID int64 var expiresAt string err = db.QueryRow("SELECT user_id, expires_at FROM sessions WHERE token=?", c.Value).Scan(&userID, &expiresAt) if err != nil { return nil, errors.New("invalid_session") } if t, err := time.Parse(time.RFC3339, expiresAt); err == nil && time.Now().After(t) { db.Exec("DELETE FROM sessions WHERE token=?", c.Value) return nil, errors.New("expired") } return loadUser(userID) } func requireAuth(h http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { u, err := authUser(r) if err != nil { writeError(w, http.StatusUnauthorized, "unauthorized") return } ctx := context.WithValue(r.Context(), userCtxKey, u) h.ServeHTTP(w, r.WithContext(ctx)) } } func requireAdmin(h http.HandlerFunc) http.HandlerFunc { return requireAuth(func(w http.ResponseWriter, r *http.Request) { u := userFromCtx(r) if !u.IsSystemAdmin { writeError(w, http.StatusForbidden, "forbidden") return } h.ServeHTTP(w, r) }) } func userFromCtx(r *http.Request) *User { v := r.Context().Value(userCtxKey) if v == nil { return nil } return v.(*User) } func writeJSON(w http.ResponseWriter, code int, body any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) if body != nil { _ = json.NewEncoder(w).Encode(body) } } func writeError(w http.ResponseWriter, code int, msg string) { writeJSON(w, code, map[string]string{"error": msg}) }