package main import ( "context" "encoding/json" "errors" "flag" "log" "net" "net/http" "net/url" "os" "os/signal" "path/filepath" "strings" "syscall" "time" ) type Config struct { Addr string `json:"addr"` DBPath string `json:"db_path"` RulesDir string `json:"rules_dir"` BackupDir string `json:"backup_dir"` CORSOrigins []string `json:"cors_origins"` CrossSiteCookies bool `json:"cross_site_cookies"` } var corsOrigins []string var crossSiteCookies bool var backupDir string func defaultConfig() *Config { return &Config{ Addr: ":8080", DBPath: "penaltytracker.db", RulesDir: "rules", BackupDir: "backup", CORSOrigins: []string{}, CrossSiteCookies: false, } } func writeConfig(path string, cfg *Config) error { if dir := filepath.Dir(path); dir != "" { if err := os.MkdirAll(dir, 0o755); err != nil { return err } } data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return err } return os.WriteFile(path, data, 0o644) } func loadConfig(path string) (*Config, error) { cfg := defaultConfig() data, err := os.ReadFile(path) if err != nil { if errors.Is(err, os.ErrNotExist) { if err := writeConfig(path, cfg); err != nil { return nil, err } log.Printf("config file created at %s with defaults", path) return cfg, nil } return nil, err } if err := json.Unmarshal(data, cfg); err != nil { return nil, err } if cfg.Addr == "" { cfg.Addr = ":8080" } if cfg.DBPath == "" { cfg.DBPath = "penaltytracker.db" } if cfg.RulesDir == "" { cfg.RulesDir = "rules" } if cfg.BackupDir == "" { cfg.BackupDir = "backup" } for i, o := range cfg.CORSOrigins { cfg.CORSOrigins[i] = strings.TrimSpace(o) } return cfg, nil } func ensureDir(dir string) error { if dir == "" || dir == "." { return nil } return os.MkdirAll(dir, 0o755) } func ensureDirectories(cfg *Config) error { if err := ensureDir(cfg.RulesDir); err != nil { return err } if err := ensureDir(cfg.BackupDir); err != nil { return err } if dbDir := filepath.Dir(cfg.DBPath); dbDir != "" { if err := ensureDir(dbDir); err != nil { return err } } return nil } func main() { configPath := flag.String("config", "config.json", "path to config file") flag.Parse() cfg, err := loadConfig(*configPath) if err != nil { log.Fatalf("config load: %v", err) } if err := ensureDirectories(cfg); err != nil { log.Fatalf("ensure directories: %v", err) } corsOrigins = cfg.CORSOrigins crossSiteCookies = cfg.CrossSiteCookies backupDir = cfg.BackupDir if err := openDB(cfg.DBPath); err != nil { log.Fatalf("db open: %v", err) } if err := migrate(); err != nil { log.Fatalf("migrate: %v", err) } if err := ensureDefaultAdmin(); err != nil { log.Fatalf("default admin: %v", err) } if err := loadRules(cfg.RulesDir); err != nil { log.Printf("rules load warning: %v", err) } hub = newHub() go hub.run() mux := http.NewServeMux() registerAuthRoutes(mux) registerUserRoutes(mux) registerCompetitionRoutes(mux) registerPilotRoutes(mux) registerPenaltyRoutes(mux) registerRuleRoutes(mux) registerWSRoutes(mux) handler := withSecurityHeaders(withLog(withCSRF(withCORS(mux)))) server := &http.Server{ Addr: cfg.Addr, Handler: handler, ReadHeaderTimeout: 10 * time.Second, } go func() { log.Printf("listening on %s (cors_origins=%v)", cfg.Addr, corsOrigins) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("listen: %v", err) } }() sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) <-sig ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() _ = server.Shutdown(ctx) } func withLog(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start)) }) } func withSecurityHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.Header() h.Set("X-Content-Type-Options", "nosniff") h.Set("X-Frame-Options", "DENY") h.Set("Referrer-Policy", "no-referrer") h.Set("Permissions-Policy", "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()") h.Set("Cross-Origin-Resource-Policy", "same-site") h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") if r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") { h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } next.ServeHTTP(w, r) }) } func originAllowed(origin string) bool { if origin == "" { return false } for _, o := range corsOrigins { if o == "*" { return true } if strings.EqualFold(o, origin) { return true } } return false } func sameOriginRequest(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { // Fall back to Referer ref := r.Header.Get("Referer") if ref == "" { // No origin info: only safe methods allowed. State-changing must have Origin. return false } u, err := url.Parse(ref) if err != nil { return false } origin = u.Scheme + "://" + u.Host } u, err := url.Parse(origin) if err != nil { return false } originHost, _, _ := net.SplitHostPort(u.Host) if originHost == "" { originHost = u.Host } reqHost, _, _ := net.SplitHostPort(r.Host) if reqHost == "" { reqHost = r.Host } return strings.EqualFold(originHost, reqHost) } func isStateChanging(method string) bool { switch method { case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: return true } return false } func withCSRF(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !isStateChanging(r.Method) { next.ServeHTTP(w, r) return } origin := r.Header.Get("Origin") // Allow if Origin matches a configured CORS origin. if origin != "" && originAllowed(origin) { next.ServeHTTP(w, r) return } // Allow same-origin requests (Origin or Referer host matches request Host). if sameOriginRequest(r) { next.ServeHTTP(w, r) return } writeError(w, http.StatusForbidden, "csrf_forbidden") }) } func withCORS(next http.Handler) http.Handler { if len(corsOrigins) == 0 { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if originAllowed(origin) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS") reqHeaders := r.Header.Get("Access-Control-Request-Headers") if reqHeaders == "" { reqHeaders = "Content-Type" } w.Header().Set("Access-Control-Allow-Headers", reqHeaders) w.Header().Set("Access-Control-Max-Age", "600") } if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) }