214 lines
4.7 KiB
Go
214 lines
4.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// Config holds the server settings read from the JSON config file. The json
// tags give the key names used in that file; field order here is also the
// key order used when the file is written out.
type Config struct {
	// Addr is the listen address given to http.Server (for example ":8080").
	Addr string `json:"addr"`
	// DBPath is the database file path handed to openDB.
	DBPath string `json:"db_path"`
	// RulesDir is the directory handed to loadRules.
	RulesDir string `json:"rules_dir"`
	// CORSOrigins is the cross-origin allow-list; a "*" entry matches any
	// non-empty Origin.
	CORSOrigins []string `json:"cors_origins"`
	// CrossSiteCookies is copied into the package-level crossSiteCookies
	// flag at startup.
	CrossSiteCookies bool `json:"cross_site_cookies"`
}
|
|
|
|
// Process-wide CORS/cookie settings. main copies them from the loaded Config
// before the HTTP handler chain is constructed.
var (
	// corsOrigins is the allow-list consulted by originAllowed and withCORS.
	corsOrigins []string
	// crossSiteCookies mirrors Config.CrossSiteCookies. It is not read in this
	// file; presumably cookie-setting code elsewhere in the package uses it.
	crossSiteCookies bool
)
|
|
|
|
func defaultConfig() *Config {
|
|
return &Config{
|
|
Addr: ":8080",
|
|
DBPath: "penaltytracker.db",
|
|
RulesDir: "rules",
|
|
CORSOrigins: []string{},
|
|
CrossSiteCookies: false,
|
|
}
|
|
}
|
|
|
|
func writeConfig(path string, cfg *Config) error {
|
|
if dir := filepath.Dir(path); dir != "" {
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
data, err := json.MarshalIndent(cfg, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(path, data, 0o644)
|
|
}
|
|
|
|
func loadConfig(path string) (*Config, error) {
|
|
cfg := defaultConfig()
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
if err := writeConfig(path, cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
log.Printf("config file created at %s with defaults", path)
|
|
return cfg, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if err := json.Unmarshal(data, cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
if cfg.Addr == "" {
|
|
cfg.Addr = ":8080"
|
|
}
|
|
if cfg.DBPath == "" {
|
|
cfg.DBPath = "penaltytracker.db"
|
|
}
|
|
if cfg.RulesDir == "" {
|
|
cfg.RulesDir = "rules"
|
|
}
|
|
for i, o := range cfg.CORSOrigins {
|
|
cfg.CORSOrigins[i] = strings.TrimSpace(o)
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
// ensureDir creates dir and any missing parents with mode 0755. The empty
// path and "." (the working directory) are accepted as-is and create nothing.
func ensureDir(dir string) error {
	switch dir {
	case "", ".":
		return nil
	default:
		return os.MkdirAll(dir, 0o755)
	}
}
|
|
|
|
func ensureDirectories(cfg *Config) error {
|
|
if err := ensureDir(cfg.RulesDir); err != nil {
|
|
return err
|
|
}
|
|
if dbDir := filepath.Dir(cfg.DBPath); dbDir != "" {
|
|
if err := ensureDir(dbDir); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
configPath := flag.String("config", "config.json", "path to config file")
|
|
flag.Parse()
|
|
|
|
cfg, err := loadConfig(*configPath)
|
|
if err != nil {
|
|
log.Fatalf("config load: %v", err)
|
|
}
|
|
|
|
if err := ensureDirectories(cfg); err != nil {
|
|
log.Fatalf("ensure directories: %v", err)
|
|
}
|
|
|
|
corsOrigins = cfg.CORSOrigins
|
|
crossSiteCookies = cfg.CrossSiteCookies
|
|
|
|
if err := openDB(cfg.DBPath); err != nil {
|
|
log.Fatalf("db open: %v", err)
|
|
}
|
|
if err := migrate(); err != nil {
|
|
log.Fatalf("migrate: %v", err)
|
|
}
|
|
if err := ensureDefaultAdmin(); err != nil {
|
|
log.Fatalf("default admin: %v", err)
|
|
}
|
|
if err := loadRules(cfg.RulesDir); err != nil {
|
|
log.Printf("rules load warning: %v", err)
|
|
}
|
|
|
|
hub = newHub()
|
|
go hub.run()
|
|
|
|
mux := http.NewServeMux()
|
|
registerAuthRoutes(mux)
|
|
registerUserRoutes(mux)
|
|
registerCompetitionRoutes(mux)
|
|
registerPilotRoutes(mux)
|
|
registerPenaltyRoutes(mux)
|
|
registerRuleRoutes(mux)
|
|
registerWSRoutes(mux)
|
|
|
|
server := &http.Server{
|
|
Addr: cfg.Addr,
|
|
Handler: withLog(withCORS(mux)),
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
log.Printf("listening on %s (cors_origins=%v)", cfg.Addr, corsOrigins)
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("listen: %v", err)
|
|
}
|
|
}()
|
|
|
|
sig := make(chan os.Signal, 1)
|
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sig
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
_ = server.Shutdown(ctx)
|
|
}
|
|
|
|
// withLog wraps next and, once next returns, logs one line per request:
// the method, the request path, and the time spent inside next.
//
// The path is logged in escaped form (r.URL.EscapedPath) rather than the
// decoded r.URL.Path. A client-supplied escape such as %0A would otherwise
// decode to a newline and split the line or forge extra log entries. Paths
// that contain nothing needing escaping appear exactly as requested.
func withLog(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		next.ServeHTTP(w, r)
		log.Printf("%s %s %s", r.Method, r.URL.EscapedPath(), time.Since(start))
	})
}
|
|
|
|
func originAllowed(origin string) bool {
|
|
if origin == "" {
|
|
return false
|
|
}
|
|
for _, o := range corsOrigins {
|
|
if o == "*" {
|
|
return true
|
|
}
|
|
if strings.EqualFold(o, origin) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func withCORS(next http.Handler) http.Handler {
|
|
if len(corsOrigins) == 0 {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get("Origin")
|
|
if originAllowed(origin) {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
w.Header().Set("Vary", "Origin")
|
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS")
|
|
reqHeaders := r.Header.Get("Access-Control-Request-Headers")
|
|
if reqHeaders == "" {
|
|
reqHeaders = "Content-Type"
|
|
}
|
|
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
|
|
w.Header().Set("Access-Control-Max-Age", "600")
|
|
}
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|