Files
PenaltyTracker/main.go
T
2026-05-16 20:39:27 +02:00

214 lines
4.7 KiB
Go

package main
import (
"context"
"encoding/json"
"errors"
"flag"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
)
// Config is the on-disk JSON configuration for the server. It is read by
// loadConfig and, when no file exists yet, written with defaults by
// writeConfig. Field order here is the key order of the generated file.
type Config struct {
	// Addr is the listen address handed to http.Server in main.
	Addr string `json:"addr"`
	// DBPath is the database file passed to openDB; its parent directory is
	// created by ensureDirectories.
	DBPath string `json:"db_path"`
	// RulesDir is the directory passed to loadRules; it is created by
	// ensureDirectories if missing.
	RulesDir string `json:"rules_dir"`
	// CORSOrigins lists origins that withCORS will reflect in CORS headers.
	// "*" matches any origin; an empty list disables the CORS middleware.
	CORSOrigins []string `json:"cors_origins"`
	// CrossSiteCookies is copied into the package-level crossSiteCookies in
	// main. NOTE(review): its consumer is outside this file's visible code —
	// presumably it controls cookie SameSite/Secure attributes; confirm.
	CrossSiteCookies bool `json:"cross_site_cookies"`
}
// Process-wide copies of the CORS/cookie settings. main assigns them from the
// loaded Config before the HTTP handler chain is constructed.
var (
	// corsOrigins is consulted by originAllowed and withCORS.
	corsOrigins []string
	// crossSiteCookies mirrors Config.CrossSiteCookies.
	// NOTE(review): read by code outside this view — verify its consumer.
	crossSiteCookies bool
)
// defaultConfig returns a newly allocated Config holding the built-in
// defaults: listen on ":8080", database file "penaltytracker.db", rules in
// "rules" (both relative to the working directory), no CORS origins and
// cross-site cookies disabled.
//
// CORSOrigins is an empty, non-nil slice so it is written to JSON as []
// rather than null.
func defaultConfig() *Config {
	cfg := new(Config)
	cfg.Addr = ":8080"
	cfg.DBPath = "penaltytracker.db"
	cfg.RulesDir = "rules"
	cfg.CORSOrigins = make([]string, 0)
	cfg.CrossSiteCookies = false // already the zero value; kept explicit for readability
	return cfg
}
// writeConfig serialises cfg as indented JSON and writes it to path,
// creating the parent directory first if necessary. The file is written
// with mode 0644 (subject to umask) and ends with a newline.
//
// It returns any error from directory creation, marshalling or writing.
func writeConfig(path string, cfg *Config) error {
	// filepath.Dir never returns ""; for a bare filename it returns ".",
	// which ensureDir treats as nothing to create.
	if err := ensureDir(filepath.Dir(path)); err != nil {
		return err
	}
	data, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return err
	}
	// Terminate the file with a newline so it is a well-formed text file.
	data = append(data, '\n')
	return os.WriteFile(path, data, 0o644)
}
// loadConfig reads the JSON configuration file at path.
//
// If the file does not exist, it is created with defaultConfig() values and
// those defaults are returned. Otherwise the file is decoded on top of a fresh
// set of defaults, so keys missing from the file keep their default value.
// Addr, DBPath and RulesDir also fall back to the default when present but
// empty.
//
// CORS origins are normalised: surrounding whitespace and a trailing "/" are
// removed, and entries that end up empty are dropped.
//
// Any read, create or JSON decode error is returned unchanged.
func loadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return nil, err
		}
		cfg := defaultConfig()
		if err := writeConfig(path, cfg); err != nil {
			return nil, err
		}
		log.Printf("config file created at %s with defaults", path)
		return cfg, nil
	}

	cfg := defaultConfig()
	if err := json.Unmarshal(data, cfg); err != nil {
		return nil, err
	}

	// An explicit "" in the file would otherwise override the default.
	// Take the fallbacks from defaultConfig so the literals live in one place.
	defaults := defaultConfig()
	if cfg.Addr == "" {
		cfg.Addr = defaults.Addr
	}
	if cfg.DBPath == "" {
		cfg.DBPath = defaults.DBPath
	}
	if cfg.RulesDir == "" {
		cfg.RulesDir = defaults.RulesDir
	}

	// Browsers send Origin as scheme://host[:port] with no path, so a trailing
	// "/" could never match. A blank entry would switch on the CORS middleware
	// (which then answers every OPTIONS request itself) without allowing any
	// origin, so blank entries are dropped.
	origins := make([]string, 0, len(cfg.CORSOrigins))
	for _, o := range cfg.CORSOrigins {
		o = strings.TrimSuffix(strings.TrimSpace(o), "/")
		if o != "" {
			origins = append(origins, o)
		}
	}
	cfg.CORSOrigins = origins
	return cfg, nil
}
// ensureDir creates dir, and any missing parents, with mode 0755.
// The empty string and "." (the current directory) need no creation and
// yield nil. Otherwise it returns the result of os.MkdirAll.
func ensureDir(dir string) error {
	switch dir {
	case "", ".":
		return nil
	default:
		return os.MkdirAll(dir, 0o755)
	}
}
// ensureDirectories creates, if missing, the rules directory and the
// directory that contains the database file, in that order. It stops at and
// returns the first error.
func ensureDirectories(cfg *Config) error {
	// filepath.Dir never returns ""; ensureDir already ignores "" and ".".
	required := []string{cfg.RulesDir, filepath.Dir(cfg.DBPath)}
	for _, dir := range required {
		if err := ensureDir(dir); err != nil {
			return err
		}
	}
	return nil
}
// main loads the configuration, prepares directories, the database and
// rules, starts the HTTP server and blocks until SIGINT/SIGTERM arrives or
// the listener fails. On a signal it shuts the server down gracefully, with
// a 10-second limit.
func main() {
	configPath := flag.String("config", "config.json", "path to config file")
	flag.Parse()

	cfg, err := loadConfig(*configPath)
	if err != nil {
		log.Fatalf("config load: %v", err)
	}
	if err := ensureDirectories(cfg); err != nil {
		log.Fatalf("ensure directories: %v", err)
	}

	// Publish settings to package state. This must happen before withCORS is
	// constructed below, because withCORS inspects corsOrigins once at build time.
	corsOrigins = cfg.CORSOrigins
	crossSiteCookies = cfg.CrossSiteCookies

	if err := openDB(cfg.DBPath); err != nil {
		log.Fatalf("db open: %v", err)
	}
	if err := migrate(); err != nil {
		log.Fatalf("migrate: %v", err)
	}
	if err := ensureDefaultAdmin(); err != nil {
		log.Fatalf("default admin: %v", err)
	}
	// Rule loading problems are reported but do not stop the server.
	if err := loadRules(cfg.RulesDir); err != nil {
		log.Printf("rules load warning: %v", err)
	}

	hub = newHub()
	go hub.run()

	mux := http.NewServeMux()
	registerAuthRoutes(mux)
	registerUserRoutes(mux)
	registerCompetitionRoutes(mux)
	registerPilotRoutes(mux)
	registerPenaltyRoutes(mux)
	registerRuleRoutes(mux)
	registerWSRoutes(mux)

	server := &http.Server{
		Addr:              cfg.Addr,
		Handler:           withLog(withCORS(mux)),
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Subscribe to termination signals before the listener starts, so a
	// signal that arrives right away still triggers a graceful shutdown.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)

	// ListenAndServe always returns a non-nil error. Hand it back to main so
	// a startup failure (e.g. address already in use) and a shutdown signal
	// are handled in one place.
	serveErr := make(chan error, 1)
	go func() {
		log.Printf("listening on %s (cors_origins=%v)", cfg.Addr, corsOrigins)
		serveErr <- server.ListenAndServe()
	}()

	select {
	case err := <-serveErr:
		if !errors.Is(err, http.ErrServerClosed) {
			log.Fatalf("listen: %v", err)
		}
		return
	case s := <-sig:
		log.Printf("received %v, shutting down", s)
	}
	// Restore the default signal disposition, so a second SIGINT/SIGTERM
	// terminates the process instead of waiting out a stuck shutdown.
	signal.Stop(sig)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Shutdown neither closes nor waits for hijacked connections (e.g.
	// WebSockets, if the WS routes hijack them). It returns ctx's error if
	// active requests outlast the timeout.
	if err := server.Shutdown(ctx); err != nil {
		log.Printf("shutdown: %v", err)
	}
}
// withLog wraps next so that, once next has finished serving a request, a
// line with the method, URL path and elapsed wall time is written to the
// standard logger. For long-lived requests the elapsed time covers the
// whole time the handler held the request.
func withLog(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)
		log.Printf("%s %s %s", r.Method, r.URL.Path, elapsed)
	}
	return http.HandlerFunc(logged)
}
// originAllowed reports whether a request's Origin header value may be
// reflected back in CORS response headers. An empty origin (no Origin header)
// is never allowed. Otherwise it is allowed when corsOrigins contains "*" or
// an entry equal to origin ignoring case.
//
// NOTE(review): "*" matches every origin, including the literal "null" that
// browsers send from sandboxed or file:// documents. withCORS pairs the
// reflected origin with Access-Control-Allow-Credentials: true, so "*"
// effectively permits credentialed cross-origin requests from anywhere.
func originAllowed(origin string) bool {
	if origin == "" {
		return false
	}
	allowed := false
	for i := 0; i < len(corsOrigins) && !allowed; i++ {
		candidate := corsOrigins[i]
		allowed = candidate == "*" || strings.EqualFold(candidate, origin)
	}
	return allowed
}
// withCORS wraps next with CORS handling driven by the package-level
// corsOrigins. That list is inspected once, when the wrapper is built.
//
// With no configured origins, next is returned unchanged. Otherwise every
// response carries "Vary: Origin". When the request's Origin is allowed (see
// originAllowed), the middleware also:
//   - reflects the origin in Access-Control-Allow-Origin;
//   - enables credentials;
//   - advertises the permitted methods;
//   - echoes the requested headers, or "Content-Type" if none were requested;
//   - sets a 600-second preflight cache.
//
// Every OPTIONS request is answered with 204 No Content and never reaches
// next, whether or not its origin is allowed.
func withCORS(next http.Handler) http.Handler {
	if len(corsOrigins) == 0 {
		return next
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// The CORS headers below depend on Origin for every response,
		// including those for disallowed origins (which get no ACAO). Caches
		// must therefore always key on Origin. Add rather than Set, so that
		// existing Vary values are preserved.
		h.Add("Vary", "Origin")

		origin := r.Header.Get("Origin")
		if originAllowed(origin) {
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Credentials", "true")
			h.Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS")
			reqHeaders := r.Header.Get("Access-Control-Request-Headers")
			if reqHeaders == "" {
				reqHeaders = "Content-Type"
			}
			h.Set("Access-Control-Allow-Headers", reqHeaders)
			h.Set("Access-Control-Max-Age", "600")
		}
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}