package main import ( "context" "encoding/json" "errors" "flag" "log" "net/http" "os" "os/signal" "path/filepath" "strings" "syscall" "time" ) type Config struct { Addr string `json:"addr"` DBPath string `json:"db_path"` RulesDir string `json:"rules_dir"` CORSOrigins []string `json:"cors_origins"` CrossSiteCookies bool `json:"cross_site_cookies"` } var corsOrigins []string var crossSiteCookies bool func defaultConfig() *Config { return &Config{ Addr: ":8080", DBPath: "penaltytracker.db", RulesDir: "rules", CORSOrigins: []string{}, CrossSiteCookies: false, } } func writeConfig(path string, cfg *Config) error { if dir := filepath.Dir(path); dir != "" { if err := os.MkdirAll(dir, 0o755); err != nil { return err } } data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return err } return os.WriteFile(path, data, 0o644) } func loadConfig(path string) (*Config, error) { cfg := defaultConfig() data, err := os.ReadFile(path) if err != nil { if errors.Is(err, os.ErrNotExist) { if err := writeConfig(path, cfg); err != nil { return nil, err } log.Printf("config file created at %s with defaults", path) return cfg, nil } return nil, err } if err := json.Unmarshal(data, cfg); err != nil { return nil, err } if cfg.Addr == "" { cfg.Addr = ":8080" } if cfg.DBPath == "" { cfg.DBPath = "penaltytracker.db" } if cfg.RulesDir == "" { cfg.RulesDir = "rules" } for i, o := range cfg.CORSOrigins { cfg.CORSOrigins[i] = strings.TrimSpace(o) } return cfg, nil } func ensureDir(dir string) error { if dir == "" || dir == "." { return nil } return os.MkdirAll(dir, 0o755) } func ensureDirectories(cfg *Config) error { if err := ensureDir(cfg.RulesDir); err != nil { return err } if dbDir := filepath.Dir(cfg.DBPath); dbDir != "" { if err := ensureDir(dbDir); err != nil { return err } } return nil } func main() { configPath := flag.String("config", "config.json", "path to config file") flag.Parse() cfg, err := loadConfig(*configPath) if err != nil { log.Fatalf("config load: %v", err) } if err := ensureDirectories(cfg); err != nil { log.Fatalf("ensure directories: %v", err) } corsOrigins = cfg.CORSOrigins crossSiteCookies = cfg.CrossSiteCookies if err := openDB(cfg.DBPath); err != nil { log.Fatalf("db open: %v", err) } if err := migrate(); err != nil { log.Fatalf("migrate: %v", err) } if err := ensureDefaultAdmin(); err != nil { log.Fatalf("default admin: %v", err) } if err := loadRules(cfg.RulesDir); err != nil { log.Printf("rules load warning: %v", err) } hub = newHub() go hub.run() mux := http.NewServeMux() registerAuthRoutes(mux) registerUserRoutes(mux) registerCompetitionRoutes(mux) registerPilotRoutes(mux) registerPenaltyRoutes(mux) registerRuleRoutes(mux) registerWSRoutes(mux) server := &http.Server{ Addr: cfg.Addr, Handler: withLog(withCORS(mux)), ReadHeaderTimeout: 10 * time.Second, } go func() { log.Printf("listening on %s (cors_origins=%v)", cfg.Addr, corsOrigins) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("listen: %v", err) } }() sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) <-sig ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() _ = server.Shutdown(ctx) } func withLog(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start)) }) } func originAllowed(origin string) bool { if origin == "" { return false } for _, o := range corsOrigins { if o == "*" { return true } if strings.EqualFold(o, origin) { return true } } return false } func withCORS(next http.Handler) http.Handler { if len(corsOrigins) == 0 { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if originAllowed(origin) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS") reqHeaders := r.Header.Get("Access-Control-Request-Headers") if reqHeaders == "" { reqHeaders = "Content-Type" } w.Header().Set("Access-Control-Allow-Headers", reqHeaders) w.Header().Set("Access-Control-Max-Age", "600") } if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) }