Files
openaccounting-server/core/ws/ws.go

340 lines
7.6 KiB
Go
Raw Normal View History

2018-10-19 15:31:41 -04:00
package ws
import (
"encoding/json"
"errors"
"github.com/Masterminds/semver"
"github.com/ant0ine/go-json-rest/rest"
"github.com/gorilla/websocket"
"github.com/mitchellh/mapstructure"
"github.com/openaccounting/oa-server/core/auth"
"github.com/openaccounting/oa-server/core/model/types"
"log"
"net/http"
"sync"
)
const version = "1.2.0"
2018-10-19 15:31:41 -04:00
//var upgrader = websocket.Upgrader{} // use default options
var txSubscriptions = make(map[string][]*websocket.Conn)
var accountSubscriptions = make(map[string][]*websocket.Conn)
var priceSubscriptions = make(map[string][]*websocket.Conn)
var userMap = make(map[*websocket.Conn]*types.User)
var sequenceNumbers = make(map[*websocket.Conn]int)
var locks = make(map[*websocket.Conn]*sync.Mutex)
// Message is the JSON envelope used for every websocket frame, in both
// directions.
type Message struct {
	// Version is the sender's protocol version; for incoming messages it is
	// parsed as a semver constraint by checkVersion.
	Version string `json:"version"`

	// SequenceNumber is the per-connection counter the server assigns to
	// outgoing messages.
	SequenceNumber int `json:"sequenceNumber"`

	// Type names the kind of resource or reply, e.g. "transaction",
	// "account", "price" or "pong".
	Type string `json:"type"`

	// Action names the operation, e.g. "authenticate", "subscribe",
	// "unsubscribe" or "ping".
	Action string `json:"action"`

	// Data is the payload; its shape depends on Type and Action.
	Data interface{} `json:"data"`
}
func Handler(w rest.ResponseWriter, r *rest.Request) {
c, err := websocket.Upgrade(w.(http.ResponseWriter), r.Request, nil, 0, 0)
if err != nil {
log.Print("upgrade:", err)
return
}
sequenceNumbers[c] = -1
locks[c] = &sync.Mutex{}
defer c.Close()
for {
mt, messageData, err := c.ReadMessage()
if err != nil {
log.Println("readerr:", err)
// remove connection from maps
unsubscribeAll(c)
break
}
if mt != websocket.TextMessage {
log.Println("Unsupported message type", mt)
continue
}
message := Message{}
err = json.Unmarshal(messageData, &message)
if err != nil {
log.Println("Could not parse message:", string(messageData))
continue
}
log.Printf("recv: %s", message)
// check version
err = checkVersion(message.Version)
if err != nil {
log.Println(err.Error())
2019-05-20 10:40:40 -04:00
writeMessageRaw(c, websocket.CloseMessage, websocket.FormatCloseMessage(4001, err.Error()))
2018-10-19 15:31:41 -04:00
break
}
// make sure they are authenticated
if userMap[c] == nil {
log.Println("checking message for authentication")
err = authenticate(message, c)
if err != nil {
log.Println("Authentication error " + err.Error())
2019-05-20 10:40:40 -04:00
writeMessageRaw(c, websocket.CloseMessage, websocket.FormatCloseMessage(4000, err.Error()))
2018-10-19 15:31:41 -04:00
break
}
continue
}
err = processMessage(message, c)
if err != nil {
log.Println(err)
continue
}
}
}
// getKey builds the subscription-map key for a user/organization pair by
// joining the two IDs with a hyphen.
func getKey(userId string, orgId string) string {
	const separator = "-"

	key := userId + separator
	key += orgId
	return key
}
func processMessage(message Message, conn *websocket.Conn) error {
var dataString string
err := mapstructure.Decode(message.Data, &dataString)
if err != nil {
return err
}
key := getKey(userMap[conn].Id, dataString)
log.Println(message.Action, message.Type, dataString)
switch message.Action {
case "subscribe":
switch message.Type {
case "transaction":
subscribe(conn, key, txSubscriptions)
case "account":
subscribe(conn, key, accountSubscriptions)
case "price":
subscribe(conn, key, priceSubscriptions)
default:
return errors.New("Unhandled message type: " + message.Type)
}
case "unsubscribe":
switch message.Type {
case "transaction":
unsubscribe(conn, key, txSubscriptions)
case "account":
unsubscribe(conn, key, accountSubscriptions)
case "price":
unsubscribe(conn, key, priceSubscriptions)
default:
return errors.New("Unhandled message type: " + message.Type)
}
case "ping":
sequenceNumbers[conn]++
response := Message{version, sequenceNumbers[conn], "pong", "pong", nil}
responseData, err := json.Marshal(response)
if err != nil {
return err
}
2019-05-20 10:40:40 -04:00
err = writeMessageRaw(conn, websocket.TextMessage, responseData)
2018-10-19 15:31:41 -04:00
if err != nil {
unsubscribeAll(conn)
return err
}
}
return nil
}
// subscribe registers conn under key in clientMap. A connection that is
// already listed for that key is not added a second time.
func subscribe(conn *websocket.Conn, key string, clientMap map[string][]*websocket.Conn) {
	for _, existing := range clientMap[key] {
		if existing == conn {
			// Already subscribed; nothing to do.
			return
		}
	}

	clientMap[key] = append(clientMap[key], conn)
}
func unsubscribe(conn *websocket.Conn, key string, clientMap map[string][]*websocket.Conn) {
2019-05-20 10:40:40 -04:00
locks[conn].Lock()
defer locks[conn].Unlock()
2018-10-19 15:31:41 -04:00
newConns := clientMap[key][:0]
for _, c := range clientMap[key] {
if conn != c {
newConns = append(newConns, c)
}
}
}
func unsubscribeAll(conn *websocket.Conn) {
2019-05-20 10:40:40 -04:00
locks[conn].Lock()
2018-10-19 15:31:41 -04:00
for key, conns := range txSubscriptions {
newConns := conns[:0]
for _, c := range conns {
if conn != c {
newConns = append(newConns, c)
}
}
txSubscriptions[key] = newConns
}
for key, conns := range accountSubscriptions {
newConns := conns[:0]
for _, c := range conns {
if conn != c {
newConns = append(newConns, c)
}
}
accountSubscriptions[key] = newConns
}
for key, conns := range priceSubscriptions {
newConns := conns[:0]
for _, c := range conns {
if conn != c {
newConns = append(newConns, c)
}
}
priceSubscriptions[key] = newConns
}
delete(userMap, conn)
delete(sequenceNumbers, conn)
delete(locks, conn)
}
func PushTransaction(transaction *types.Transaction, userIds []string, action string) {
log.Println(txSubscriptions)
message := Message{version, -1, "transaction", action, transaction}
for _, userId := range userIds {
key := getKey(userId, transaction.OrgId)
for _, conn := range txSubscriptions[key] {
2019-05-20 10:40:40 -04:00
err := writeMessage(conn, &message)
2018-10-19 15:31:41 -04:00
if err != nil {
log.Println("Cannot PushTransaction to client:", err)
unsubscribeAll(conn)
}
}
}
}
func PushAccount(account *types.Account, userIds []string, action string) {
message := Message{version, -1, "account", action, account}
for _, userId := range userIds {
key := getKey(userId, account.OrgId)
for _, conn := range accountSubscriptions[key] {
2019-05-20 10:40:40 -04:00
err := writeMessage(conn, &message)
2018-10-19 15:31:41 -04:00
if err != nil {
log.Println("Cannot PushAccount to client:", err)
unsubscribeAll(conn)
}
}
}
}
func PushPrice(price *types.Price, userIds []string, action string) {
message := Message{version, -1, "price", action, price}
for _, userId := range userIds {
key := getKey(userId, price.OrgId)
for _, conn := range priceSubscriptions[key] {
2019-05-20 10:40:40 -04:00
err := writeMessage(conn, &message)
2018-10-19 15:31:41 -04:00
if err != nil {
log.Println("Cannot PushPrice to client:", err)
unsubscribeAll(conn)
}
}
}
}
// authenticate tries to authenticate conn using message.
//
// message.Data is decoded as a string first, so undecodable data is reported
// before the action is checked. The action must be "authenticate". The
// decoded string is then passed to auth.Instance.Authenticate with an empty
// second argument, and on success the returned user is stored in userMap for
// conn.
func authenticate(message Message, conn *websocket.Conn) error {
	var credential string
	if decodeErr := mapstructure.Decode(message.Data, &credential); decodeErr != nil {
		return decodeErr
	}

	if message.Action != "authenticate" {
		return errors.New("Authentication required")
	}

	authenticated, authErr := auth.Instance.Authenticate(credential, "")
	if authErr != nil {
		return authErr
	}

	userMap[conn] = authenticated
	return nil
}
func checkVersion(clientVersion string) error {
constraint, err := semver.NewConstraint(clientVersion)
if err != nil {
return errors.New("Invalid version")
}
serverVersion, _ := semver.NewVersion(version)
2018-10-20 18:17:15 -04:00
compatVersion, _ := semver.NewVersion("0.1.8")
2018-10-19 15:31:41 -04:00
versionMatch := constraint.Check(serverVersion)
2018-10-20 18:17:15 -04:00
compatMatch := constraint.Check(compatVersion)
2018-10-19 15:31:41 -04:00
2018-10-20 18:17:15 -04:00
if versionMatch == false && compatMatch == false {
2018-10-19 15:31:41 -04:00
return errors.New("Invalid version")
}
return nil
}
2019-05-20 10:40:40 -04:00
// writeMessage assigns the connection's next sequence number to message,
// encodes it as JSON and sends it to conn as a text frame.
//
// message is modified in place (its SequenceNumber is overwritten). An error
// is returned when conn has no lock registered, when encoding fails, or when
// the write fails.
func writeMessage(conn *websocket.Conn, message *Message) error {
	// A connection that has been cleaned up by unsubscribeAll has no lock any
	// more; report that instead of dereferencing a nil mutex.
	mu := locks[conn]
	if mu == nil {
		return errors.New("websocket connection is not registered")
	}

	mu.Lock()
	sequenceNumbers[conn]++
	message.SequenceNumber = sequenceNumbers[conn]
	mu.Unlock()

	messageData, err := json.Marshal(message)
	if err != nil {
		log.Println("json error:", err)
		return err
	}

	return writeMessageRaw(conn, websocket.TextMessage, messageData)
}
func writeMessageRaw(conn *websocket.Conn, messageType int, data []byte) error {
2018-10-19 15:31:41 -04:00
locks[conn].Lock()
defer locks[conn].Unlock()
return conn.WriteMessage(messageType, data)
}