293 lines
6.9 KiB
Go
293 lines
6.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// Config is the server configuration, stored as JSON (see loadConfig and
// writeConfig). encoding/json emits struct fields in declaration order, so
// the field order here is the key order of a generated config file.
type Config struct {
	// Addr is the listen address given to http.Server (default ":8080").
	Addr string `json:"addr"`
	// DBPath is the database location passed to openDB; its parent
	// directory is created at startup (default "penaltytracker.db").
	DBPath string `json:"db_path"`
	// RulesDir is created at startup and passed to loadRules
	// (default "rules").
	RulesDir string `json:"rules_dir"`
	// CORSOrigins lists origins allowed to make cross-origin requests.
	// Entries are compared case-insensitively with the Origin header; "*"
	// allows every origin.
	CORSOrigins []string `json:"cors_origins"`
	// CrossSiteCookies is copied into the package-level crossSiteCookies at
	// startup. It is not read in the code shown here; presumably it controls
	// cookie SameSite behavior elsewhere (TODO confirm).
	CrossSiteCookies bool `json:"cross_site_cookies"`
}
|
|
|
|
// Settings copied out of Config by main before the HTTP middleware is built.
var (
	// corsOrigins is the list of allowed CORS origins consulted by
	// originAllowed and withCORS.
	corsOrigins []string
	// crossSiteCookies mirrors Config.CrossSiteCookies; presumably read by
	// cookie-setting code elsewhere in the package (TODO confirm).
	crossSiteCookies bool
)
|
|
|
|
func defaultConfig() *Config {
|
|
return &Config{
|
|
Addr: ":8080",
|
|
DBPath: "penaltytracker.db",
|
|
RulesDir: "rules",
|
|
CORSOrigins: []string{},
|
|
CrossSiteCookies: false,
|
|
}
|
|
}
|
|
|
|
func writeConfig(path string, cfg *Config) error {
|
|
if dir := filepath.Dir(path); dir != "" {
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
data, err := json.MarshalIndent(cfg, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(path, data, 0o644)
|
|
}
|
|
|
|
func loadConfig(path string) (*Config, error) {
|
|
cfg := defaultConfig()
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
if err := writeConfig(path, cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
log.Printf("config file created at %s with defaults", path)
|
|
return cfg, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if err := json.Unmarshal(data, cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
if cfg.Addr == "" {
|
|
cfg.Addr = ":8080"
|
|
}
|
|
if cfg.DBPath == "" {
|
|
cfg.DBPath = "penaltytracker.db"
|
|
}
|
|
if cfg.RulesDir == "" {
|
|
cfg.RulesDir = "rules"
|
|
}
|
|
for i, o := range cfg.CORSOrigins {
|
|
cfg.CORSOrigins[i] = strings.TrimSpace(o)
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
// ensureDir creates dir and any missing parents with mode 0755. The empty
// path and "." (the current directory) are treated as already present, and
// nothing is created for them.
func ensureDir(dir string) error {
	switch dir {
	case "", ".":
		return nil
	default:
		return os.MkdirAll(dir, 0o755)
	}
}
|
|
|
|
func ensureDirectories(cfg *Config) error {
|
|
if err := ensureDir(cfg.RulesDir); err != nil {
|
|
return err
|
|
}
|
|
if dbDir := filepath.Dir(cfg.DBPath); dbDir != "" {
|
|
if err := ensureDir(dbDir); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
configPath := flag.String("config", "config.json", "path to config file")
|
|
flag.Parse()
|
|
|
|
cfg, err := loadConfig(*configPath)
|
|
if err != nil {
|
|
log.Fatalf("config load: %v", err)
|
|
}
|
|
|
|
if err := ensureDirectories(cfg); err != nil {
|
|
log.Fatalf("ensure directories: %v", err)
|
|
}
|
|
|
|
corsOrigins = cfg.CORSOrigins
|
|
crossSiteCookies = cfg.CrossSiteCookies
|
|
|
|
if err := openDB(cfg.DBPath); err != nil {
|
|
log.Fatalf("db open: %v", err)
|
|
}
|
|
if err := migrate(); err != nil {
|
|
log.Fatalf("migrate: %v", err)
|
|
}
|
|
if err := ensureDefaultAdmin(); err != nil {
|
|
log.Fatalf("default admin: %v", err)
|
|
}
|
|
if err := loadRules(cfg.RulesDir); err != nil {
|
|
log.Printf("rules load warning: %v", err)
|
|
}
|
|
|
|
hub = newHub()
|
|
go hub.run()
|
|
|
|
mux := http.NewServeMux()
|
|
registerAuthRoutes(mux)
|
|
registerUserRoutes(mux)
|
|
registerCompetitionRoutes(mux)
|
|
registerPilotRoutes(mux)
|
|
registerPenaltyRoutes(mux)
|
|
registerRuleRoutes(mux)
|
|
registerWSRoutes(mux)
|
|
|
|
handler := withSecurityHeaders(withLog(withCSRF(withCORS(mux))))
|
|
|
|
server := &http.Server{
|
|
Addr: cfg.Addr,
|
|
Handler: handler,
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
log.Printf("listening on %s (cors_origins=%v)", cfg.Addr, corsOrigins)
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("listen: %v", err)
|
|
}
|
|
}()
|
|
|
|
sig := make(chan os.Signal, 1)
|
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sig
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
_ = server.Shutdown(ctx)
|
|
}
|
|
|
|
// withLog wraps next and logs one line per request once next has returned.
// The line contains the method, the request path and the time spent in next.
//
// The path is logged in escaped form (URL.EscapedPath) rather than the
// decoded r.URL.Path. Otherwise a percent-encoded newline such as "%0A" in a
// request URL would be written raw and could forge additional log lines.
func withLog(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		next.ServeHTTP(w, r)
		log.Printf("%s %s %s", r.Method, r.URL.EscapedPath(), time.Since(start))
	})
}
|
|
|
|
// withSecurityHeaders sets a fixed group of defensive response headers on
// every response before delegating to next. The group covers:
//   - content sniffing and framing
//   - the Referrer-Policy and a restrictive Permissions-Policy
//   - the Cross-Origin-Resource-Policy and a deny-all CSP
//
// Strict-Transport-Security is added only when the request came in over TLS
// or carries X-Forwarded-Proto: https (case-insensitive). That header is
// client-controlled when no proxy strips it, but the only consequence is an
// HSTS header on a plain-HTTP response, which browsers ignore.
func withSecurityHeaders(next http.Handler) http.Handler {
	fixed := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "no-referrer"},
		{"Permissions-Policy", "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()"},
		{"Cross-Origin-Resource-Policy", "same-site"},
		{"Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range fixed {
			hdr.Set(kv[0], kv[1])
		}
		overHTTPS := r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
		if overHTTPS {
			hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
func originAllowed(origin string) bool {
|
|
if origin == "" {
|
|
return false
|
|
}
|
|
for _, o := range corsOrigins {
|
|
if o == "*" {
|
|
return true
|
|
}
|
|
if strings.EqualFold(o, origin) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// sameOriginRequest reports whether the request's Origin header (or, when
// Origin is absent, its Referer) names the same hostname as the request's
// Host header.
//
// Only hostnames are compared, case-insensitively; scheme and port are
// ignored. It returns false in these cases:
//   - neither header is present
//   - the header value cannot be parsed
//   - the value has no host component, as with the opaque origin "null"
//     sent by sandboxed documents
func sameOriginRequest(r *http.Request) bool {
	// hostname strips an optional ":port" and IPv6 brackets so that
	// "[::1]:8080", "[::1]" and "::1" all compare equal. Without this, a
	// port-less IPv6 host kept its brackets while a ported one lost them.
	hostname := func(hostport string) string {
		if host, _, err := net.SplitHostPort(hostport); err == nil {
			return host
		}
		return strings.TrimSuffix(strings.TrimPrefix(hostport, "["), "]")
	}

	source := r.Header.Get("Origin")
	if source == "" {
		source = r.Header.Get("Referer")
	}
	if source == "" {
		// No origin information: state-changing requests must carry some.
		return false
	}
	u, err := url.Parse(source)
	if err != nil {
		return false
	}
	originHost := hostname(u.Host)
	if originHost == "" {
		// Fail closed: a host-less origin such as "null" must never "match"
		// an empty Host header.
		return false
	}
	return strings.EqualFold(originHost, hostname(r.Host))
}
|
|
|
|
// isStateChanging reports whether method is one of the HTTP methods this
// server treats as mutating: POST, PUT, PATCH or DELETE. The comparison is
// exact, so lower-case spellings are not matched.
func isStateChanging(method string) bool {
	return method == http.MethodPost ||
		method == http.MethodPut ||
		method == http.MethodPatch ||
		method == http.MethodDelete
}
|
|
|
|
func withCSRF(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !isStateChanging(r.Method) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
origin := r.Header.Get("Origin")
|
|
// Allow if Origin matches a configured CORS origin.
|
|
if origin != "" && originAllowed(origin) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
// Allow same-origin requests (Origin or Referer host matches request Host).
|
|
if sameOriginRequest(r) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
writeError(w, http.StatusForbidden, "csrf_forbidden")
|
|
})
|
|
}
|
|
|
|
func withCORS(next http.Handler) http.Handler {
|
|
if len(corsOrigins) == 0 {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get("Origin")
|
|
if originAllowed(origin) {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
w.Header().Set("Vary", "Origin")
|
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS")
|
|
reqHeaders := r.Header.Get("Access-Control-Request-Headers")
|
|
if reqHeaders == "" {
|
|
reqHeaders = "Content-Type"
|
|
}
|
|
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
|
|
w.Header().Set("Access-Control-Max-Age", "600")
|
|
}
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|