From 0c5aff496b9b7d6ee6df56cf08336a657b4722a9 Mon Sep 17 00:00:00 2001 From: nxshock Date: Wed, 3 Apr 2024 13:47:20 +0500 Subject: [PATCH] Prevent parallel websocket writes --- httpserver.go | 15 ++++++++++----- wconn.go | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 wconn.go diff --git a/httpserver.go b/httpserver.go index 3d1edc3..5926815 100644 --- a/httpserver.go +++ b/httpserver.go @@ -16,7 +16,7 @@ import ( ) type WsConnections struct { - connections map[*websocket.Conn]struct{} + connections map[*WsConnection]struct{} mutex sync.Mutex } @@ -24,19 +24,24 @@ func (wc *WsConnections) Add(c *websocket.Conn) { wc.mutex.Lock() defer wc.mutex.Unlock() - wc.connections[c] = struct{}{} + wc.connections[NewWsConnection(c)] = struct{}{} } func (wc *WsConnections) Delete(c *websocket.Conn) { wc.mutex.Lock() defer wc.mutex.Unlock() - delete(wc.connections, c) + for k := range wc.connections { + if k.w == c { + delete(wc.connections, k) + break + } + } } func (wc *WsConnections) Send(message interface{}) { for conn := range wc.connections { - go func(conn *websocket.Conn) { _ = conn.WriteJSON(message) }(conn) + go func(conn *WsConnection) { _ = conn.Send(message) }(conn) } } @@ -49,7 +54,7 @@ var upgrader = websocket.Upgrader{ } var wsConnections = &WsConnections{ - connections: make(map[*websocket.Conn]struct{})} + connections: make(map[*WsConnection]struct{})} func httpServer(listenAddress string) { if listenAddress == "none" { diff --git a/wconn.go b/wconn.go new file mode 100644 index 0000000..ea1c72a --- /dev/null +++ b/wconn.go @@ -0,0 +1,23 @@ +package main + +import ( + "sync" + + "github.com/gorilla/websocket" +) + +type WsConnection struct { + w *websocket.Conn + mu sync.Mutex +} + +func NewWsConnection(w *websocket.Conn) *WsConnection { + return &WsConnection{w: w} +} + +func (w *WsConnection) Send(message any) error { + w.mu.Lock() + defer w.mu.Unlock() + + return w.w.WriteJSON(message) +}