141 lines
2.6 KiB
Go
141 lines
2.6 KiB
Go
package main
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
// Non-browser client without Origin header — allow.
|
|
return true
|
|
}
|
|
if originAllowed(origin) {
|
|
return true
|
|
}
|
|
return sameOriginRequest(r)
|
|
},
|
|
}
|
|
|
|
// wsMessage is the JSON envelope written to subscribers for every event.
// The payload key is left out of the encoding entirely when Payload is nil.
type wsMessage struct {
	// Type names the event; it is the kind passed to Hub.broadcast.
	Type string `json:"type"`

	// Payload is the event-specific body and may be any JSON-encodable value.
	Payload any `json:"payload,omitempty"`
}
|
|
|
|
type client struct {
|
|
conn *websocket.Conn
|
|
competitionID int64
|
|
send chan []byte
|
|
}
|
|
|
|
type Hub struct {
|
|
mu sync.RWMutex
|
|
clients map[int64]map[*client]struct{}
|
|
}
|
|
|
|
// hub is the package-level client registry used by handleWS (register) and
// readPump (unregister). Calling a method on a nil *Hub panics when it locks mu.
// NOTE(review): presumably assigned via newHub() during startup before any
// WebSocket route is served — confirm in main.
var hub *Hub
|
|
|
|
func newHub() *Hub {
|
|
return &Hub{clients: map[int64]map[*client]struct{}{}}
|
|
}
|
|
|
|
// run is a no-op. register, unregister and broadcast all do their bookkeeping
// synchronously under h.mu, so no background loop is needed.
// NOTE(review): presumably kept so existing callers (e.g. `go hub.run()`)
// still compile — confirm before removing.
func (h *Hub) run() {}
|
|
|
|
func (h *Hub) register(c *client) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
set, ok := h.clients[c.competitionID]
|
|
if !ok {
|
|
set = map[*client]struct{}{}
|
|
h.clients[c.competitionID] = set
|
|
}
|
|
set[c] = struct{}{}
|
|
}
|
|
|
|
func (h *Hub) unregister(c *client) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
if set, ok := h.clients[c.competitionID]; ok {
|
|
if _, ok := set[c]; ok {
|
|
delete(set, c)
|
|
close(c.send)
|
|
}
|
|
if len(set) == 0 {
|
|
delete(h.clients, c.competitionID)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) broadcast(competitionID int64, kind string, payload any) {
|
|
msg := wsMessage{Type: kind, Payload: payload}
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
return
|
|
}
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
set, ok := h.clients[competitionID]
|
|
if !ok {
|
|
return
|
|
}
|
|
for c := range set {
|
|
select {
|
|
case c.send <- data:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func registerWSRoutes(mux *http.ServeMux) {
|
|
mux.HandleFunc("GET /api/competitions/{id}/ws", requireAuth(handleWS))
|
|
}
|
|
|
|
func handleWS(w http.ResponseWriter, r *http.Request) {
|
|
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadRequest, "invalid_id")
|
|
return
|
|
}
|
|
u := userFromCtx(r)
|
|
if _, ok := canAccessCompetition(u, id); !ok {
|
|
writeError(w, http.StatusForbidden, "forbidden")
|
|
return
|
|
}
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
c := &client{conn: conn, competitionID: id, send: make(chan []byte, 16)}
|
|
hub.register(c)
|
|
go writePump(c)
|
|
readPump(c)
|
|
}
|
|
|
|
func readPump(c *client) {
|
|
defer func() {
|
|
hub.unregister(c)
|
|
c.conn.Close()
|
|
}()
|
|
c.conn.SetReadLimit(1024)
|
|
for {
|
|
if _, _, err := c.conn.ReadMessage(); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func writePump(c *client) {
|
|
defer c.conn.Close()
|
|
for msg := range c.send {
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|