// File-viewer metadata captured with this listing: 141 lines, 2.6 KiB, Go.

package main
import (
	"encoding/json"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
// upgrader turns HTTP requests into WebSocket connections. Its CheckOrigin
// hook accepts a handshake as soon as one of these holds, tested in order:
//  1. the request carries no Origin header,
//  2. originAllowed approves the Origin value,
//  3. sameOriginRequest approves the request.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		switch origin := r.Header.Get("Origin"); {
		case origin == "":
			// No Origin header, which is what non-browser clients send: accept.
			return true
		case originAllowed(origin):
			return true
		default:
			// Last resort: defer to the same-origin check on the request.
			return sameOriginRequest(r)
		}
	},
}
// wsMessage is the JSON envelope that Hub.broadcast marshals and pushes to
// clients.
type wsMessage struct {
	// Type names the kind of event; it is encoded under the "type" key.
	Type string `json:"type"`
	// Payload is the event body, encoded under the "payload" key. The
	// omitempty option drops the key when Payload is nil.
	Payload any `json:"payload,omitempty"`
}
// client is one WebSocket connection subscribed to a single competition.
type client struct {
	// conn is the upgraded WebSocket connection.
	conn *websocket.Conn
	// competitionID is the key under which the Hub files this client.
	competitionID int64
	// send queues encoded messages for writePump to deliver. Hub.broadcast
	// writes to it and Hub.unregister closes it.
	send chan []byte
}
// Hub keeps track of the connected clients, grouped by competition.
type Hub struct {
	// mu guards clients: register and unregister take the write lock,
	// broadcast takes the read lock.
	mu sync.RWMutex
	// clients maps a competition ID to the set of clients registered for it.
	clients map[int64]map[*client]struct{}
}
// hub is the package-level Hub that handleWS and readPump use to register and
// unregister clients. It is declared without an initializer, so it is nil
// until something assigns it (presumably the result of newHub at startup —
// the assignment is not in this part of the file).
var hub *Hub
// newHub returns a Hub whose client registry is allocated and empty, ready
// for register, unregister and broadcast calls.
func newHub() *Hub {
	h := new(Hub)
	h.clients = make(map[int64]map[*client]struct{})
	return h
}
// run has an empty body: register, unregister and broadcast all work directly
// on the mutex-guarded map, so there is no event loop to drive.
// NOTE(review): presumably retained so an existing `go hub.run()` style call
// keeps compiling — confirm before removing.
func (h *Hub) run() {}
// register adds c to the set of clients for c.competitionID, creating that
// set on first use. It holds the write lock for the whole operation.
func (h *Hub) register(c *client) {
	h.mu.Lock()
	defer h.mu.Unlock()

	id := c.competitionID
	if _, exists := h.clients[id]; !exists {
		// First client for this competition: start a fresh set.
		h.clients[id] = make(map[*client]struct{})
	}
	h.clients[id][c] = struct{}{}
}
// unregister removes c from its competition's client set and closes c.send.
// The channel is closed only if c was still a member, so calling unregister
// again for the same client does not close it twice. A competition whose set
// becomes empty is dropped from the registry.
func (h *Hub) unregister(c *client) {
	h.mu.Lock()
	defer h.mu.Unlock()

	room, found := h.clients[c.competitionID]
	if !found {
		return
	}
	if _, member := room[c]; member {
		delete(room, c)
		// Closing send ends the range loop in writePump. It happens under the
		// write lock, so broadcast (read lock) can never send on a closed
		// channel.
		close(c.send)
	}
	if len(room) == 0 {
		delete(h.clients, c.competitionID)
	}
}
// broadcast encodes a wsMessage with the given kind and payload and offers it
// to every client registered for competitionID.
//
// The send is non-blocking: a client whose send buffer is full simply misses
// this message. If the payload cannot be marshalled to JSON the call returns
// without sending anything.
func (h *Hub) broadcast(competitionID int64, kind string, payload any) {
	data, err := json.Marshal(wsMessage{Type: kind, Payload: payload})
	if err != nil {
		return
	}

	h.mu.RLock()
	defer h.mu.RUnlock()

	// Ranging over a missing key yields a nil map, i.e. zero iterations.
	for c := range h.clients[competitionID] {
		select {
		case c.send <- data:
		default:
			// Buffer full: drop the message for this client rather than block
			// while holding the read lock.
		}
	}
}
// registerWSRoutes mounts the competition WebSocket endpoint on mux, wrapping
// handleWS in requireAuth.
func registerWSRoutes(mux *http.ServeMux) {
	const pattern = "GET /api/competitions/{id}/ws"
	mux.HandleFunc(pattern, requireAuth(handleWS))
}
// handleWS upgrades an HTTP request to a WebSocket subscribed to the
// competition named by the {id} path value.
//
// It responds 400 "invalid_id" when the id is not a base-10 int64 and 403
// "forbidden" when canAccessCompetition rejects the request's user. After a
// successful upgrade the client is registered with the hub, a writePump
// goroutine is started, and the handler blocks in readPump until the
// connection ends.
func handleWS(w http.ResponseWriter, r *http.Request) {
	competitionID, parseErr := strconv.ParseInt(r.PathValue("id"), 10, 64)
	if parseErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_id")
		return
	}

	if _, allowed := canAccessCompetition(userFromCtx(r), competitionID); !allowed {
		writeError(w, http.StatusForbidden, "forbidden")
		return
	}

	conn, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		// Nothing is written here: per the gorilla/websocket docs, Upgrade
		// has already replied to the client with an HTTP error.
		return
	}

	c := &client{
		conn:          conn,
		competitionID: competitionID,
		send:          make(chan []byte, 16),
	}
	hub.register(c)

	// Writes run on their own goroutine; reads stay on the handler goroutine
	// so the handler returns only once the connection is finished.
	go writePump(c)
	readPump(c)
}
// readPump reads and discards incoming messages until ReadMessage fails, then
// unregisters c from the hub and closes the connection. Incoming messages are
// capped at 1024 bytes.
func readPump(c *client) {
	defer func() {
		hub.unregister(c)
		c.conn.Close()
	}()

	c.conn.SetReadLimit(1024)

	// Message contents are ignored; the loop exists to notice when the read
	// side fails so the client can be cleaned up.
	var err error
	for err == nil {
		_, _, err = c.conn.ReadMessage()
	}
}
// writePump delivers every message queued on c.send to the peer as a text
// frame. It returns, closing the connection, when c.send is closed (by
// Hub.unregister) or when a write fails.
//
// Each write is given a deadline. Without one, a peer that stops reading
// would block WriteMessage indefinitely, keeping this goroutine and the
// client's hub registration alive while every broadcast to it is dropped.
// When the deadline expires the write fails, the connection is closed, and
// the blocked ReadMessage in readPump errors out so the client is
// unregistered.
func writePump(c *client) {
	// writeWait bounds how long a single message may take to be written.
	const writeWait = 10 * time.Second

	defer c.conn.Close()
	for msg := range c.send {
		if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
			return
		}
		if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
			return
		}
	}
}