From e64f60fd0929c72ced869b2e7dcab69a07b07ca2 Mon Sep 17 00:00:00 2001 From: Patrick Nagurny Date: Mon, 20 May 2019 10:40:40 -0400 Subject: [PATCH] fix connection locks issues --- core/ws/ws.go | 63 ++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/core/ws/ws.go b/core/ws/ws.go index f714bd5..b669c56 100644 --- a/core/ws/ws.go +++ b/core/ws/ws.go @@ -73,7 +73,7 @@ func Handler(w rest.ResponseWriter, r *rest.Request) { if err != nil { log.Println(err.Error()) - writeMessage(c, websocket.CloseMessage, websocket.FormatCloseMessage(4001, err.Error())) + writeMessageRaw(c, websocket.CloseMessage, websocket.FormatCloseMessage(4001, err.Error())) break } @@ -83,7 +83,7 @@ func Handler(w rest.ResponseWriter, r *rest.Request) { err = authenticate(message, c) if err != nil { log.Println("Authentication error " + err.Error()) - writeMessage(c, websocket.CloseMessage, websocket.FormatCloseMessage(4000, err.Error())) + writeMessageRaw(c, websocket.CloseMessage, websocket.FormatCloseMessage(4000, err.Error())) break } continue @@ -146,7 +146,7 @@ func processMessage(message Message, conn *websocket.Conn) error { return err } - err = writeMessage(conn, websocket.TextMessage, responseData) + err = writeMessageRaw(conn, websocket.TextMessage, responseData) if err != nil { unsubscribeAll(conn) @@ -173,6 +173,9 @@ func subscribe(conn *websocket.Conn, key string, clientMap map[string][]*websock } func unsubscribe(conn *websocket.Conn, key string, clientMap map[string][]*websocket.Conn) { + locks[conn].Lock() + defer locks[conn].Unlock() + newConns := clientMap[key][:0] for _, c := range clientMap[key] { @@ -183,7 +186,8 @@ func unsubscribe(conn *websocket.Conn, key string, clientMap map[string][]*webso } func unsubscribeAll(conn *websocket.Conn) { - // TODO fix "concurrent map iteration and map write" error + locks[conn].Lock() + for key, conns := range txSubscriptions { newConns := conns[:0] for _, c := range conns { @@ -227,16 +231,7 @@ func PushTransaction(transaction *types.Transaction, userIds []string, action st for _, userId := range userIds { key := getKey(userId, transaction.OrgId) for _, conn := range txSubscriptions[key] { - sequenceNumbers[conn]++ - message.SequenceNumber = sequenceNumbers[conn] - messageData, err := json.Marshal(message) - - if err != nil { - log.Println("PushTransaction json error:", err) - return - } - - err = writeMessage(conn, websocket.TextMessage, messageData) + err := writeMessage(conn, &message) if err != nil { log.Println("Cannot PushTransaction to client:", err) @@ -252,15 +247,7 @@ func PushAccount(account *types.Account, userIds []string, action string) { for _, userId := range userIds { key := getKey(userId, account.OrgId) for _, conn := range accountSubscriptions[key] { - sequenceNumbers[conn]++ - message.SequenceNumber = sequenceNumbers[conn] - messageData, err := json.Marshal(message) - - if err != nil { - log.Println("PushAccount error:", err) - return - } - err = writeMessage(conn, websocket.TextMessage, messageData) + err := writeMessage(conn, &message) if err != nil { log.Println("Cannot PushAccount to client:", err) @@ -276,16 +263,7 @@ func PushPrice(price *types.Price, userIds []string, action string) { for _, userId := range userIds { key := getKey(userId, price.OrgId) for _, conn := range priceSubscriptions[key] { - sequenceNumbers[conn]++ - message.SequenceNumber = sequenceNumbers[conn] - messageData, err := json.Marshal(message) - - if err != nil { - log.Println("PushPrice error:", err) - return - } - - err = writeMessage(conn, websocket.TextMessage, messageData) + err := writeMessage(conn, &message) if err != nil { log.Println("Cannot PushPrice to client:", err) @@ -337,7 +315,24 @@ func checkVersion(clientVersion string) error { return nil } -func writeMessage(conn *websocket.Conn, messageType int, data []byte) error { +func writeMessage(conn *websocket.Conn, message *Message) error { + locks[conn].Lock() + sequenceNumbers[conn]++ + message.SequenceNumber = sequenceNumbers[conn] + locks[conn].Unlock() + + messageData, err := json.Marshal(message) + + if err != nil { + log.Println("json error:", err) + return err + } + + return writeMessageRaw(conn, websocket.TextMessage, messageData) + +} + +func writeMessageRaw(conn *websocket.Conn, messageType int, data []byte) error { locks[conn].Lock() defer locks[conn].Unlock() return conn.WriteMessage(messageType, data)