Files
PenaltyTracker/main.go
T

303 lines
7.1 KiB
Go

package main
import (
"context"
"encoding/json"
"errors"
"flag"
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
)
// Config is the JSON configuration file layout for the server. loadConfig
// replaces empty string fields with their built-in defaults.
type Config struct {
	// Addr is the listen address given to http.Server.
	Addr string `json:"addr"`

	// DBPath is the database location passed to openDB.
	DBPath string `json:"db_path"`

	// RulesDir is the directory passed to loadRules.
	RulesDir string `json:"rules_dir"`

	// BackupDir is created at startup and copied into the package-level backupDir.
	BackupDir string `json:"backup_dir"`

	// CORSOrigins lists origins accepted by the CORS and CSRF middleware;
	// the entry "*" matches any non-empty origin.
	CORSOrigins []string `json:"cors_origins"`

	// CrossSiteCookies is copied into the package-level crossSiteCookies flag.
	CrossSiteCookies bool `json:"cross_site_cookies"`
}
// Runtime settings copied from the loaded Config by main before the server
// starts listening.
var (
	// corsOrigins is consulted by originAllowed and withCORS.
	corsOrigins []string

	// crossSiteCookies mirrors Config.CrossSiteCookies.
	// NOTE(review): not read in this file; presumably used by cookie-setting
	// handlers elsewhere in the package — confirm.
	crossSiteCookies bool

	// backupDir mirrors Config.BackupDir.
	// NOTE(review): not read in this file; presumably used by backup code
	// elsewhere in the package — confirm.
	backupDir string
)
// defaultConfig returns a newly allocated Config holding the built-in
// defaults. CORSOrigins is an empty but non-nil slice, so serializing the
// defaults produces "cors_origins": [] rather than null.
func defaultConfig() *Config {
	cfg := new(Config)
	cfg.Addr = ":8080"
	cfg.DBPath = "penaltytracker.db"
	cfg.RulesDir = "rules"
	cfg.BackupDir = "backup"
	cfg.CORSOrigins = make([]string, 0)
	// CrossSiteCookies is left at its zero value (false).
	return cfg
}
// writeConfig stores cfg at path as indented JSON (file mode 0644). The
// parent directory of path is created first (mode 0755) if it is missing.
// Any directory-creation, encoding or write error is returned unchanged.
func writeConfig(path string, cfg *Config) error {
	parent := filepath.Dir(path)
	if parent != "" {
		if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
			return mkErr
		}
	}
	encoded, encErr := json.MarshalIndent(cfg, "", " ")
	if encErr != nil {
		return encErr
	}
	return os.WriteFile(path, encoded, 0o644)
}
// loadConfig reads the JSON configuration at path. If the file does not
// exist it is created from defaultConfig and those defaults are returned.
//
// Values in the file override the defaults. Addr, DBPath, RulesDir and
// BackupDir fall back to their defaults when present but empty. CORS origins
// are normalized: surrounding whitespace and trailing slashes are removed (a
// browser Origin header never ends in "/", so such entries could never
// match), and entries left empty are dropped.
//
// Read errors, errors creating the default file, and JSON decode errors are
// returned. Decode errors are wrapped in *os.PathError so the message names
// the offending file.
func loadConfig(path string) (*Config, error) {
	// Single source of truth for fallback values.
	defaults := defaultConfig()

	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		if err := writeConfig(path, defaults); err != nil {
			return nil, err
		}
		log.Printf("config file created at %s with defaults", path)
		return defaults, nil
	}
	if err != nil {
		return nil, err
	}

	// Decode on top of a fresh default set so keys missing from the file keep
	// their defaults.
	cfg := defaultConfig()
	if err := json.Unmarshal(data, cfg); err != nil {
		return nil, &os.PathError{Op: "parse", Path: path, Err: err}
	}

	if cfg.Addr == "" {
		cfg.Addr = defaults.Addr
	}
	if cfg.DBPath == "" {
		cfg.DBPath = defaults.DBPath
	}
	if cfg.RulesDir == "" {
		cfg.RulesDir = defaults.RulesDir
	}
	if cfg.BackupDir == "" {
		cfg.BackupDir = defaults.BackupDir
	}

	// Filter in place; cfg exclusively owns the decoded slice. A nil slice
	// (JSON null) stays nil.
	origins := cfg.CORSOrigins[:0]
	for _, o := range cfg.CORSOrigins {
		o = strings.TrimRight(strings.TrimSpace(o), "/")
		if o != "" {
			origins = append(origins, o)
		}
	}
	cfg.CORSOrigins = origins

	return cfg, nil
}
// ensureDir creates dir, including missing parents, with mode 0755. The
// empty string and "." are treated as already present and return nil.
func ensureDir(dir string) error {
	switch dir {
	case "", ".":
		return nil
	default:
		return os.MkdirAll(dir, 0o755)
	}
}
// ensureDirectories creates, in order, the rules directory, the backup
// directory, and the directory containing the database file. It stops at
// and returns the first error from ensureDir.
func ensureDirectories(cfg *Config) error {
	// filepath.Dir never returns "", so the database directory is always
	// handed to ensureDir (which itself skips ".").
	required := []string{cfg.RulesDir, cfg.BackupDir, filepath.Dir(cfg.DBPath)}
	for _, dir := range required {
		if err := ensureDir(dir); err != nil {
			return err
		}
	}
	return nil
}
// main loads the configuration named by -config, prepares directories, the
// database, rules and the hub, then serves HTTP until SIGINT or SIGTERM
// arrives. After a signal it shuts the server down, waiting at most 10s.
func main() {
	configPath := flag.String("config", "config.json", "path to config file")
	flag.Parse()

	cfg, err := loadConfig(*configPath)
	if err != nil {
		log.Fatalf("config load: %v", err)
	}
	if err := ensureDirectories(cfg); err != nil {
		log.Fatalf("ensure directories: %v", err)
	}

	// Publish the settings read by the middleware and handlers.
	corsOrigins = cfg.CORSOrigins
	crossSiteCookies = cfg.CrossSiteCookies
	backupDir = cfg.BackupDir
	for _, o := range corsOrigins {
		if o == "*" {
			// originAllowed accepts every origin for "*", and withCORS echoes
			// it back with Access-Control-Allow-Credentials: true.
			log.Printf("warning: cors_origins contains \"*\"; every origin passes the CORS and CSRF origin checks")
			break
		}
	}

	if err := openDB(cfg.DBPath); err != nil {
		log.Fatalf("db open: %v", err)
	}
	if err := migrate(); err != nil {
		log.Fatalf("migrate: %v", err)
	}
	if err := ensureDefaultAdmin(); err != nil {
		log.Fatalf("default admin: %v", err)
	}
	// A rules load error is only logged; the server still starts.
	if err := loadRules(cfg.RulesDir); err != nil {
		log.Printf("rules load warning: %v", err)
	}

	hub = newHub()
	go hub.run()

	mux := http.NewServeMux()
	registerAuthRoutes(mux)
	registerUserRoutes(mux)
	registerCompetitionRoutes(mux)
	registerPilotRoutes(mux)
	registerPenaltyRoutes(mux)
	registerRuleRoutes(mux)
	registerWSRoutes(mux)

	// Outermost first: security headers, logging, CSRF, then CORS.
	handler := withSecurityHeaders(withLog(withCSRF(withCORS(mux))))
	// NOTE(review): no ReadTimeout/WriteTimeout is set; if registerWSRoutes
	// hijacks long-lived connections, such deadlines would need care — confirm
	// before adding them.
	server := &http.Server{
		Addr:              cfg.Addr,
		Handler:           handler,
		ReadHeaderTimeout: 10 * time.Second,
	}

	go func() {
		log.Printf("listening on %s (cors_origins=%v)", cfg.Addr, corsOrigins)
		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Fatalf("listen: %v", err)
		}
	}()

	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	received := <-sig
	// Restore default signal handling so a second Ctrl-C during the graceful
	// shutdown terminates the process instead of being ignored.
	signal.Stop(sig)
	log.Printf("received %v, shutting down", received)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := server.Shutdown(ctx); err != nil {
		log.Printf("shutdown: %v", err)
	}
}
// withLog calls next, then logs one line with the request method, the URL
// path and the elapsed time. The path is formatted with %q because it is
// decoded from the client's request line. Escaped control characters such as
// "%0A" would otherwise split the entry or forge additional log lines.
func withLog(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		next.ServeHTTP(w, r)
		log.Printf("%s %q %s", r.Method, r.URL.Path, time.Since(start))
	})
}
// withSecurityHeaders sets a fixed group of defensive response headers on
// every response and then calls next. Strict-Transport-Security is added only
// when the request arrived over TLS or its X-Forwarded-Proto header equals
// "https" (case-insensitively).
// NOTE(review): X-Forwarded-Proto is taken at face value here; confirm a
// trusted proxy always sets or strips it.
func withSecurityHeaders(next http.Handler) http.Handler {
	staticHeaders := [][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "no-referrer"},
		{"Permissions-Policy", "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()"},
		{"Cross-Origin-Resource-Policy", "same-site"},
		{"Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range staticHeaders {
			hdr.Set(kv[0], kv[1])
		}
		secure := r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
		if secure {
			hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}
// originAllowed reports whether origin matches an entry of corsOrigins.
// Comparison is case-insensitive, the entry "*" accepts any origin, and an
// empty origin is never allowed.
func originAllowed(origin string) bool {
	if origin == "" {
		return false
	}
	for i := range corsOrigins {
		if entry := corsOrigins[i]; entry == "*" || strings.EqualFold(entry, origin) {
			return true
		}
	}
	return false
}
// sameOriginRequest reports whether the request's Origin header names the
// same host as the request's Host header. When Origin is absent, it uses the
// scheme and host of the Referer header instead.
//
// Only host names are compared (case-insensitively); scheme and port are
// ignored. The function returns false when there is no Origin or Referer, a
// value cannot be parsed, or either side yields an empty host name (for
// example Origin "null", or a request without a Host header).
func sameOriginRequest(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		ref := r.Header.Get("Referer")
		if ref == "" {
			// No origin info: only safe methods allowed. State-changing must have Origin.
			return false
		}
		u, err := url.Parse(ref)
		if err != nil {
			return false
		}
		origin = u.Scheme + "://" + u.Host
	}
	u, err := url.Parse(origin)
	if err != nil {
		return false
	}

	// hostOnly strips an optional port and, for IPv6 literals, the brackets,
	// so "[::1]" and "[::1]:8080" both yield "::1".
	hostOnly := func(hostport string) string {
		if h, _, err := net.SplitHostPort(hostport); err == nil {
			return h
		}
		return strings.TrimSuffix(strings.TrimPrefix(hostport, "["), "]")
	}

	originHost := hostOnly(u.Host)
	reqHost := hostOnly(r.Host)
	if originHost == "" || reqHost == "" {
		// Two empty hosts must not count as a match.
		return false
	}
	return strings.EqualFold(originHost, reqHost)
}
// isStateChanging reports whether method is one of the mutating HTTP verbs
// guarded by the CSRF middleware: POST, PUT, PATCH or DELETE. The comparison
// is case-sensitive.
func isStateChanging(method string) bool {
	return method == http.MethodPost ||
		method == http.MethodPut ||
		method == http.MethodPatch ||
		method == http.MethodDelete
}
// withCSRF lets requests with non-mutating methods (see isStateChanging)
// through untouched. A state-changing request passes only if its non-empty
// Origin is accepted by originAllowed, or sameOriginRequest judges it
// same-origin. Otherwise it is answered with writeError(403, "csrf_forbidden").
//
// NOTE(review): when corsOrigins contains "*", originAllowed accepts every
// non-empty Origin, which effectively disables this check for cross-site
// requests — confirm that is intended for deployments using "*".
func withCSRF(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		trusted := !isStateChanging(r.Method)
		if !trusted {
			o := r.Header.Get("Origin")
			trusted = (o != "" && originAllowed(o)) || sameOriginRequest(r)
		}
		if !trusted {
			writeError(w, http.StatusForbidden, "csrf_forbidden")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// withCORS adds CORS response headers for requests whose Origin passes
// originAllowed. It answers every OPTIONS request itself with 204 No Content
// and passes all other requests to next. When no CORS origins are configured
// it returns next unchanged.
//
// Allowed origins are echoed back, together with
// Access-Control-Allow-Credentials: true. "Vary: Origin" is added to every
// response, not only to allowed ones, because the presence of the CORS
// headers depends on Origin. Without it, a shared cache could store a
// response lacking CORS headers and replay it to an allowed origin. Add is
// used instead of Set so that existing Vary values are preserved.
//
// NOTE(review): with "*" in corsOrigins any site can make credentialed
// requests and read the responses — confirm that is intended.
func withCORS(next http.Handler) http.Handler {
	if len(corsOrigins) == 0 {
		return next
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Add("Vary", "Origin")
		origin := r.Header.Get("Origin")
		if originAllowed(origin) {
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Credentials", "true")
			// PUT is listed because isStateChanging treats it as a mutating
			// method; without it a cross-origin PUT would fail preflight.
			h.Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
			reqHeaders := r.Header.Get("Access-Control-Request-Headers")
			if reqHeaders == "" {
				reqHeaders = "Content-Type"
			}
			h.Set("Access-Control-Allow-Headers", reqHeaders)
			h.Set("Access-Control-Max-Age", "600")
		}
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}