initial commit

This commit is contained in:
Patrick Nagurny
2018-10-19 15:31:41 -04:00
commit e2dd29259f
203 changed files with 44839 additions and 0 deletions

377
core/model/account.go Normal file
View File

@@ -0,0 +1,377 @@
package model
import (
"errors"
"fmt"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/ws"
"sort"
"time"
)
type AccountInterface interface {
CreateAccount(account *types.Account, userId string) error
UpdateAccount(account *types.Account, userId string) error
DeleteAccount(id string, userId string, orgId string) error
GetAccounts(orgId string, userId string, tokenId string) ([]*types.Account, error)
GetAccountsWithBalances(orgId string, userId string, tokenId string, date time.Time) ([]*types.Account, error)
}
type ByName []*types.Account
func (a ByName) Len() int { return len(a) }
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByName) Less(i, j int) bool { return a[i].Name < a[j].Name }
func (model *Model) CreateAccount(account *types.Account, userId string) (err error) {
if account.Id == "" {
return errors.New("id required")
}
if account.OrgId == "" {
return errors.New("orgId required")
}
if account.Name == "" {
return errors.New("name required")
}
if account.Currency == "" {
return errors.New("currency required")
}
userAccounts, err := model.GetAccounts(account.OrgId, userId, "")
if err != nil {
return
}
if !model.accountsContainWriteAccess(userAccounts, account.Parent) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", account.Parent))
}
err = model.db.InsertAccount(account)
if err != nil {
return
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access account
userIds, err2 := model.db.GetOrgUserIds(account.OrgId)
if err2 == nil {
ws.PushAccount(account, userIds, "create")
}
return
}
func (model *Model) UpdateAccount(account *types.Account, userId string) (err error) {
if account.Id == "" {
return errors.New("id required")
}
if account.OrgId == "" {
return errors.New("orgId required")
}
if account.Name == "" {
return errors.New("name required")
}
if account.Currency == "" {
return errors.New("currency required")
}
if account.Parent == account.Id {
return errors.New("account cannot be its own parent")
}
userAccounts, err := model.GetAccounts(account.OrgId, userId, "")
if err != nil {
return
}
if !model.accountsContainWriteAccess(userAccounts, account.Parent) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", account.Parent))
}
err = model.db.UpdateAccount(account)
if err != nil {
return
}
err = model.db.AddBalance(account, time.Now())
if err != nil {
return
}
err = model.db.AddNativeBalanceCost(account, time.Now())
if err != nil {
return
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access account
userIds, err2 := model.db.GetOrgUserIds(account.OrgId)
if err2 == nil {
ws.PushAccount(account, userIds, "update")
}
return
}
func (model *Model) DeleteAccount(id string, userId string, orgId string) (err error) {
// TODO make sure user is part of org
// check to make sure user has permission
userAccounts, err := model.GetAccounts(orgId, userId, "")
if err != nil {
return
}
if !model.accountsContainWriteAccess(userAccounts, id) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", id))
}
// don't allow deleting of accounts that have transactions or child accounts
count, err := model.db.GetSplitCountByAccountId(id)
if err != nil {
return
}
if count != 0 {
return errors.New("Cannot delete an account that has transactions")
}
count, err = model.db.GetChildCountByAccountId(id)
if err != nil {
return
}
if count != 0 {
return errors.New("Cannot delete an account that has children")
}
account, err := model.db.GetAccount(id)
if err != nil {
return
}
err = model.db.DeleteAccount(id)
if err != nil {
return
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access account
userIds, err2 := model.db.GetOrgUserIds(account.OrgId)
if err2 == nil {
ws.PushAccount(account, userIds, "delete")
}
return
}
func (model *Model) getAccounts(orgId string, userId string, tokenId string, date time.Time, withBalances bool) ([]*types.Account, error) {
permissionedAccounts, err := model.db.GetPermissionedAccountIds(orgId, userId, "")
if err != nil {
return nil, err
}
var allAccounts []*types.Account
if withBalances == true {
allAccounts, err = model.getAllAccountsWithBalances(orgId, date)
} else {
allAccounts, err = model.getAllAccounts(orgId)
}
if err != nil {
return nil, err
}
accountMap := model.makeAccountMap(allAccounts)
writeAccessMap := make(map[string]*types.Account)
readAccessMap := make(map[string]*types.Account)
for _, accountId := range permissionedAccounts {
writeAccessMap[accountId] = accountMap[accountId].Account
// parents are read only
parents := model.getParents(accountId, accountMap)
for _, parentAccount := range parents {
readAccessMap[parentAccount.Id] = parentAccount
}
// top level accounts are initially read only unless user has permission
topLevelAccounts := model.getTopLevelAccounts(accountMap)
for _, topLevelAccount := range topLevelAccounts {
readAccessMap[topLevelAccount.Id] = topLevelAccount
}
// Children have write access
children := model.getChildren(accountId, accountMap)
for _, childAccount := range children {
writeAccessMap[childAccount.Id] = childAccount
}
}
filtered := make([]*types.Account, 0)
for _, account := range writeAccessMap {
filtered = append(filtered, account)
}
for id, account := range readAccessMap {
_, ok := writeAccessMap[id]
if ok == false {
account.ReadOnly = true
filtered = append(filtered, account)
}
}
// TODO sort by inserted
sort.Sort(ByName(filtered))
return filtered, nil
}
func (model *Model) GetAccounts(orgId string, userId string, tokenId string) ([]*types.Account, error) {
return model.getAccounts(orgId, userId, tokenId, time.Time{}, false)
}
func (model *Model) GetAccountsWithBalances(orgId string, userId string, tokenId string, date time.Time) ([]*types.Account, error) {
return model.getAccounts(orgId, userId, tokenId, date, true)
}
func (model *Model) getAllAccounts(orgId string) ([]*types.Account, error) {
return model.db.GetAccountsByOrgId(orgId)
}
func (model *Model) getAllAccountsWithBalances(orgId string, date time.Time) ([]*types.Account, error) {
accounts, err := model.db.GetAccountsByOrgId(orgId)
if err != nil {
return nil, err
}
err = model.db.AddBalances(accounts, date)
if err != nil {
return nil, err
}
err = model.db.AddNativeBalancesCost(accounts, date)
if err != nil {
return nil, err
}
return accounts, nil
}
func (model *Model) makeAccountMap(accounts []*types.Account) map[string]*types.AccountNode {
m := make(map[string]*types.AccountNode)
for _, account := range accounts {
m[account.Id] = &types.AccountNode{
Account: account,
Parent: nil,
Children: nil,
}
}
for _, account := range accounts {
m[account.Id].Parent = m[account.Parent]
if value, ok := m[account.Parent]; ok {
value.Children = append(value.Children, m[account.Id])
value.Account.HasChildren = true
}
}
return m
}
func (model *Model) getChildren(parentId string, accountMap map[string]*types.AccountNode) []*types.Account {
if _, ok := accountMap[parentId]; !ok {
return nil
}
children := make([]*types.Account, 0)
for _, childAccountNode := range accountMap[parentId].Children {
children = append(children, childAccountNode.Account)
grandChildren := model.getChildren(childAccountNode.Account.Id, accountMap)
children = append(children, grandChildren...)
}
return children
}
func (model *Model) getParents(accountId string, accountMap map[string]*types.AccountNode) []*types.Account {
node, ok := accountMap[accountId]
if !ok {
return nil
}
if node.Parent == nil {
return make([]*types.Account, 0)
}
parents := model.getParents(node.Parent.Account.Id, accountMap)
return append(parents, node.Parent.Account)
}
func (model *Model) accountsContainWriteAccess(accounts []*types.Account, accountId string) bool {
for _, account := range accounts {
if account.Id == accountId && !account.ReadOnly {
return true
}
}
return false
}
func (model *Model) getAccountFromList(accounts []*types.Account, accountId string) *types.Account {
for _, account := range accounts {
if account.Id == accountId {
return account
}
}
return nil
}
func (model *Model) getTopLevelAccounts(accountMap map[string]*types.AccountNode) []*types.Account {
accounts := make([]*types.Account, 0)
for _, node := range accountMap {
if node.Parent == nil {
accounts = append(accounts, node.Account)
for _, child := range node.Children {
accounts = append(accounts, child.Account)
}
break
}
}
return accounts
}

330
core/model/account_test.go Normal file
View File

@@ -0,0 +1,330 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
"time"
)
type TdAccount struct {
db.Datastore
mock.Mock
}
func (td *TdAccount) GetPermissionedAccountIds(userId string, orgId string, tokenId string) ([]string, error) {
// User has permission to only "Assets" account
return []string{"2"}, nil
}
func (td *TdAccount) GetAccountsByOrgId(orgId string) ([]*types.Account, error) {
args := td.Called(orgId)
return args.Get(0).([]*types.Account), args.Error(1)
}
func (td *TdAccount) InsertAccount(account *types.Account) error {
return nil
}
func (td *TdAccount) UpdateAccount(account *types.Account) error {
return nil
}
func (td *TdAccount) AddBalance(account *types.Account, date time.Time) error {
return nil
}
func (td *TdAccount) AddNativeBalanceNearestInTime(account *types.Account, date time.Time) error {
return nil
}
func (td *TdAccount) AddNativeBalanceCost(account *types.Account, date time.Time) error {
return nil
}
func (td *TdAccount) AddBalances(accounts []*types.Account, date time.Time) error {
balance := int64(1000)
for _, account := range accounts {
account.Balance = &balance
}
return nil
}
func (td *TdAccount) AddNativeBalancesNearestInTime(accounts []*types.Account, date time.Time) error {
balance := int64(1000)
for _, account := range accounts {
account.NativeBalance = &balance
}
return nil
}
func (td *TdAccount) AddNativeBalancesCost(accounts []*types.Account, date time.Time) error {
balance := int64(1000)
for _, account := range accounts {
account.NativeBalance = &balance
}
return nil
}
func (td *TdAccount) GetSplitCountByAccountId(id string) (int64, error) {
args := td.Called(id)
return args.Get(0).(int64), args.Error(1)
}
func (td *TdAccount) GetChildCountByAccountId(id string) (int64, error) {
args := td.Called(id)
return args.Get(0).(int64), args.Error(1)
}
func (td *TdAccount) DeleteAccount(id string) error {
return nil
}
func (td *TdAccount) GetOrgUserIds(id string) ([]string, error) {
return []string{"1"}, nil
}
func (td *TdAccount) GetAccount(id string) (*types.Account, error) {
return &types.Account{}, nil
}
func getTestAccounts() []*types.Account {
return []*types.Account{
&types.Account{
Id: "2",
OrgId: "1",
Name: "Assets",
Parent: "1",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
&types.Account{
Id: "3",
OrgId: "1",
Name: "Current Assets",
Parent: "2",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
&types.Account{
Id: "1",
OrgId: "1",
Name: "Root",
Parent: "",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
}
}
func TestCreateAccount(t *testing.T) {
tests := map[string]struct {
err error
account *types.Account
}{
"success": {
err: nil,
account: &types.Account{
Id: "1",
OrgId: "1",
Name: "Cash",
Parent: "3",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
},
"permission error": {
err: errors.New("user does not have permission to access account 1"),
account: &types.Account{
Id: "1",
OrgId: "1",
Name: "Cash",
Parent: "1",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
model := NewModel(td, nil, types.Config{})
err := model.CreateAccount(test.account, "1")
assert.Equal(t, test.err, err)
}
}
func TestUpdateAccount(t *testing.T) {
tests := map[string]struct {
err error
account *types.Account
}{
"success": {
err: nil,
account: &types.Account{
Id: "3",
OrgId: "1",
Name: "Current Assets2",
Parent: "2",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
},
"error": {
err: errors.New("account cannot be its own parent"),
account: &types.Account{
Id: "3",
OrgId: "1",
Name: "Current Assets",
Parent: "3",
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
model := NewModel(td, nil, types.Config{})
err := model.UpdateAccount(test.account, "1")
assert.Equal(t, test.err, err)
if err == nil {
td.AssertExpectations(t)
}
}
}
func TestDeleteAccount(t *testing.T) {
tests := map[string]struct {
err error
accountId string
count int64
}{
"success": {
err: nil,
accountId: "3",
count: 0,
},
"error": {
err: errors.New("Cannot delete an account that has transactions"),
accountId: "3",
count: 1,
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
td.On("GetSplitCountByAccountId", test.accountId).Return(test.count, nil)
td.On("GetChildCountByAccountId", test.accountId).Return(test.count, nil)
model := NewModel(td, nil, types.Config{})
err := model.DeleteAccount(test.accountId, "1", "1")
assert.Equal(t, test.err, err)
if err == nil {
td.AssertExpectations(t)
}
}
}
func TestGetAccounts(t *testing.T) {
tests := map[string]struct {
err error
}{
"success": {
err: nil,
},
// "error": {
// err: errors.New("db error"),
// },
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), test.err)
model := NewModel(td, nil, types.Config{})
accounts, err := model.GetAccounts("1", "1", "")
assert.Equal(t, test.err, err)
if err == nil {
td.AssertExpectations(t)
assert.Equal(t, 3, len(accounts))
assert.Equal(t, false, accounts[0].ReadOnly)
assert.Equal(t, false, accounts[1].ReadOnly)
assert.Equal(t, true, accounts[2].ReadOnly)
}
}
}
func TestGetAccountsWithBalances(t *testing.T) {
tests := map[string]struct {
err error
}{
"success": {
err: nil,
},
"error": {
err: errors.New("db error"),
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), test.err)
model := NewModel(td, nil, types.Config{})
accounts, err := model.GetAccountsWithBalances("1", "1", "", time.Now())
assert.Equal(t, test.err, err)
if err == nil {
td.AssertExpectations(t)
assert.Equal(t, 3, len(accounts))
assert.Equal(t, false, accounts[0].ReadOnly)
assert.Equal(t, false, accounts[1].ReadOnly)
assert.Equal(t, true, accounts[2].ReadOnly)
assert.Equal(t, int64(1000), *accounts[0].Balance)
assert.Equal(t, int64(1000), *accounts[1].Balance)
assert.Equal(t, int64(1000), *accounts[0].NativeBalance)
assert.Equal(t, int64(1000), *accounts[1].NativeBalance)
}
}
}

37
core/model/apikey.go Normal file
View File

@@ -0,0 +1,37 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
)
type ApiKeyInterface interface {
CreateApiKey(*types.ApiKey) error
UpdateApiKey(*types.ApiKey) error
DeleteApiKey(string, string) error
GetApiKeys(string) ([]*types.ApiKey, error)
}
func (model *Model) CreateApiKey(key *types.ApiKey) error {
if key.Id == "" {
return errors.New("id required")
}
return model.db.InsertApiKey(key)
}
func (model *Model) UpdateApiKey(key *types.ApiKey) error {
if key.Id == "" {
return errors.New("id required")
}
return model.db.UpdateApiKey(key)
}
func (model *Model) DeleteApiKey(id string, userId string) error {
return model.db.DeleteApiKey(id, userId)
}
func (model *Model) GetApiKeys(userId string) ([]*types.ApiKey, error) {
return model.db.GetApiKeys(userId)
}

391
core/model/db/account.go Normal file
View File

@@ -0,0 +1,391 @@
package db
import (
"database/sql"
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"math"
"strings"
"time"
)
const emptyAccountId = "00000000000000000000000000000000"
type AccountInterface interface {
InsertAccount(account *types.Account) error
UpdateAccount(account *types.Account) error
GetAccount(string) (*types.Account, error)
GetAccountsByOrgId(orgId string) ([]*types.Account, error)
GetPermissionedAccountIds(string, string, string) ([]string, error)
GetSplitCountByAccountId(id string) (int64, error)
GetChildCountByAccountId(id string) (int64, error)
DeleteAccount(id string) error
AddBalances([]*types.Account, time.Time) error
AddNativeBalancesCost([]*types.Account, time.Time) error
AddNativeBalancesNearestInTime([]*types.Account, time.Time) error
AddBalance(*types.Account, time.Time) error
AddNativeBalanceCost(*types.Account, time.Time) error
AddNativeBalanceNearestInTime(*types.Account, time.Time) error
GetRootAccount(string) (*types.Account, error)
}
func (db *DB) InsertAccount(account *types.Account) error {
account.Inserted = time.Now()
account.Updated = account.Inserted
query := "INSERT INTO account(id,orgId,inserted,updated,name,parent,currency,`precision`,debitBalance) VALUES(UNHEX(?),UNHEX(?),?,?,?,UNHEX(?),?,?,?)"
_, err := db.Exec(
query,
account.Id,
account.OrgId,
util.TimeToMs(account.Inserted),
util.TimeToMs(account.Updated),
account.Name,
account.Parent,
account.Currency,
account.Precision,
account.DebitBalance)
return err
}
func (db *DB) UpdateAccount(account *types.Account) error {
account.Updated = time.Now()
query := "UPDATE account SET updated = ?, name = ?, parent = UNHEX(?), currency = ?, `precision` = ?, debitBalance = ? WHERE id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(account.Updated),
account.Name,
account.Parent,
account.Currency,
account.Precision,
account.DebitBalance,
account.Id)
return err
}
func (db *DB) GetAccount(id string) (*types.Account, error) {
a := types.Account{}
var inserted int64
var updated int64
err := db.QueryRow("SELECT LOWER(HEX(id)),LOWER(HEX(orgId)),inserted,updated,name,LOWER(HEX(parent)),currency,`precision`,debitBalance FROM account WHERE id = UNHEX(?)", id).
Scan(&a.Id, &a.OrgId, &inserted, &updated, &a.Name, &a.Parent, &a.Currency, &a.Precision, &a.DebitBalance)
if a.Parent == emptyAccountId {
a.Parent = ""
}
switch {
case err == sql.ErrNoRows:
return nil, errors.New("Account not found")
case err != nil:
return nil, err
default:
a.Inserted = util.MsToTime(inserted)
a.Updated = util.MsToTime(updated)
return &a, nil
}
}
func (db *DB) GetAccountsByOrgId(orgId string) ([]*types.Account, error) {
rows, err := db.Query("SELECT LOWER(HEX(id)),LOWER(HEX(orgId)),inserted,updated,name,LOWER(HEX(parent)),currency,`precision`,debitBalance FROM account WHERE orgId = UNHEX(?)", orgId)
if err != nil {
return nil, err
}
defer rows.Close()
accounts := make([]*types.Account, 0)
for rows.Next() {
a := new(types.Account)
var inserted int64
var updated int64
err = rows.Scan(&a.Id, &a.OrgId, &inserted, &updated, &a.Name, &a.Parent, &a.Currency, &a.Precision, &a.DebitBalance)
if err != nil {
return nil, err
}
if a.Parent == emptyAccountId {
a.Parent = ""
}
a.Inserted = util.MsToTime(inserted)
a.Updated = util.MsToTime(updated)
accounts = append(accounts, a)
}
err = rows.Err()
if err != nil {
return nil, err
}
return accounts, nil
}
func (db *DB) GetPermissionedAccountIds(orgId string, userId string, tokenId string) ([]string, error) {
// Get user permissions
// TODO incorporate tokens
rows, err := db.Query("SELECT LOWER(HEX(accountId)) FROM permission WHERE orgId = UNHEX(?) AND userId = UNHEX(?)", orgId, userId)
if err != nil {
return nil, err
}
defer rows.Close()
var permissionedAccounts []string
var id string
for rows.Next() {
err := rows.Scan(&id)
if err != nil {
return nil, err
}
permissionedAccounts = append(permissionedAccounts, id)
}
err = rows.Err()
if err != nil {
return nil, err
}
return permissionedAccounts, nil
}
func (db *DB) GetSplitCountByAccountId(id string) (int64, error) {
var count int64
query := "SELECT COUNT(*) FROM split WHERE deleted = false AND accountId = UNHEX(?)"
err := db.QueryRow(query, id).Scan(&count)
return count, err
}
func (db *DB) GetChildCountByAccountId(id string) (int64, error) {
var count int64
query := "SELECT COUNT(*) FROM account WHERE parent = UNHEX(?)"
err := db.QueryRow(query, id).Scan(&count)
return count, err
}
func (db *DB) DeleteAccount(id string) error {
query := "DELETE FROM account WHERE id = UNHEX(?)"
_, err := db.Exec(query, id)
return err
}
func (db *DB) AddBalances(accounts []*types.Account, date time.Time) error {
// TODO optimize
ids := make([]string, len(accounts))
for i, account := range accounts {
ids[i] = "UNHEX(\"" + account.Id + "\")"
}
balanceMap := make(map[string]*int64)
query := "SELECT LOWER(HEX(accountId)), SUM(amount) FROM split WHERE deleted = false AND accountId IN (" +
strings.Join(ids, ",") + ")" +
" AND date < ? GROUP BY accountId"
rows, err := db.Query(query, util.TimeToMs(date))
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var id string
var balance int64
err := rows.Scan(&id, &balance)
if err != nil {
return err
}
balanceMap[id] = &balance
}
err = rows.Err()
if err != nil {
return err
}
for _, account := range accounts {
account.Balance = balanceMap[account.Id]
}
return nil
}
func (db *DB) AddNativeBalancesCost(accounts []*types.Account, date time.Time) error {
// TODO optimize
ids := make([]string, len(accounts))
for i, account := range accounts {
ids[i] = "UNHEX(\"" + account.Id + "\")"
}
balanceMap := make(map[string]*int64)
query := "SELECT LOWER(HEX(accountId)), SUM(nativeAmount) FROM split WHERE deleted = false AND accountId IN (" +
strings.Join(ids, ",") + ")" +
" AND date < ? GROUP BY accountId"
rows, err := db.Query(query, util.TimeToMs(date))
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var id string
var balance int64
err := rows.Scan(&id, &balance)
if err != nil {
return err
}
balanceMap[id] = &balance
}
err = rows.Err()
if err != nil {
return err
}
for _, account := range accounts {
account.NativeBalance = balanceMap[account.Id]
}
return nil
}
func (db *DB) AddNativeBalancesNearestInTime(accounts []*types.Account, date time.Time) error {
// TODO Don't look up org currency every single time
for _, account := range accounts {
err := db.AddNativeBalanceNearestInTime(account, date)
if err != nil {
return err
}
}
return nil
}
func (db *DB) AddBalance(account *types.Account, date time.Time) error {
var balance sql.NullInt64
query := "SELECT SUM(amount) FROM split WHERE deleted = false AND accountId = UNHEX(?) AND date < ?"
err := db.QueryRow(query, account.Id, util.TimeToMs(date)).Scan(&balance)
if err != nil {
return err
}
account.Balance = &balance.Int64
return nil
}
func (db *DB) AddNativeBalanceCost(account *types.Account, date time.Time) error {
var nativeBalance sql.NullInt64
query := "SELECT SUM(nativeAmount) FROM split WHERE deleted = false AND accountId = UNHEX(?) AND date < ?"
err := db.QueryRow(query, account.Id, util.TimeToMs(date)).Scan(&nativeBalance)
if err != nil {
return err
}
account.NativeBalance = &nativeBalance.Int64
return nil
}
func (db *DB) AddNativeBalanceNearestInTime(account *types.Account, date time.Time) error {
var orgCurrency string
var orgPrecision int
query1 := "SELECT currency,`precision` FROM org WHERE id = UNHEX(?)"
err := db.QueryRow(query1, account.OrgId).Scan(&orgCurrency, &orgPrecision)
if err != nil {
return err
}
if account.Balance == nil {
return nil
}
if orgCurrency == account.Currency {
nativeBalance := int64(*account.Balance)
account.NativeBalance = &nativeBalance
return nil
}
var tmp sql.NullInt64
var price float64
query2 := "SELECT ABS(CAST(date AS SIGNED) - ?) AS datediff, price FROM price WHERE currency = ? ORDER BY datediff ASC LIMIT 1"
err = db.QueryRow(query2, util.TimeToMs(date), account.Currency).Scan(&tmp, &price)
if err == sql.ErrNoRows {
nativeBalance := int64(0)
account.NativeBalance = &nativeBalance
} else if err != nil {
return err
}
precisionAdj := math.Pow(10, float64(account.Precision-orgPrecision))
nativeBalance := int64(float64(*account.Balance) * price / precisionAdj)
account.NativeBalance = &nativeBalance
return nil
}
func (db *DB) GetRootAccount(orgId string) (*types.Account, error) {
a := types.Account{}
var inserted int64
var updated int64
err := db.QueryRow(
"SELECT LOWER(HEX(id)),LOWER(HEX(orgId)),inserted,updated,name,LOWER(HEX(parent)),currency,`precision`,debitBalance FROM account WHERE orgId = UNHEX(?) AND parent = UNHEX(?)",
orgId,
emptyAccountId).
Scan(&a.Id, &a.OrgId, &inserted, &updated, &a.Name, &a.Parent, &a.Currency, &a.Precision, &a.DebitBalance)
a.Parent = ""
switch {
case err == sql.ErrNoRows:
return nil, errors.New("Account not found")
case err != nil:
return nil, err
default:
a.Inserted = util.MsToTime(inserted)
a.Updated = util.MsToTime(updated)
return &a, nil
}
}

132
core/model/db/apikey.go Normal file
View File

@@ -0,0 +1,132 @@
package db
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"time"
)
type ApiKeyInterface interface {
InsertApiKey(*types.ApiKey) error
UpdateApiKey(*types.ApiKey) error
DeleteApiKey(string, string) error
GetApiKeys(string) ([]*types.ApiKey, error)
UpdateApiKeyActivity(string) error
}
const apiKeyFields = "LOWER(HEX(id)),inserted,updated,LOWER(HEX(userId)),label"
func (db *DB) InsertApiKey(key *types.ApiKey) error {
key.Inserted = time.Now()
key.Updated = key.Inserted
query := "INSERT INTO apikey(id,inserted,updated,userId,label) VALUES(UNHEX(?),?,?,UNHEX(?),?)"
res, err := db.Exec(
query,
key.Id,
util.TimeToMs(key.Inserted),
util.TimeToMs(key.Updated),
key.UserId,
key.Label,
)
if err != nil {
return err
}
rowCnt, err := res.RowsAffected()
if err != nil {
return err
}
if rowCnt < 1 {
return errors.New("Unable to insert apikey into db")
}
return nil
}
func (db *DB) UpdateApiKey(key *types.ApiKey) error {
key.Updated = time.Now()
query := "UPDATE apikey SET updated = ?, label = ? WHERE deleted IS NULL AND id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(key.Updated),
key.Label,
key.Id,
)
if err != nil {
return err
}
var inserted int64
err = db.QueryRow("SELECT inserted FROM apikey WHERE id = UNHEX(?)", key.Id).Scan(&inserted)
if err != nil {
return err
}
key.Inserted = util.MsToTime(inserted)
return nil
}
func (db *DB) DeleteApiKey(id string, userId string) error {
query := "UPDATE apikey SET deleted = ? WHERE id = UNHEX(?) AND userId = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(time.Now()),
id,
userId,
)
return err
}
func (db *DB) GetApiKeys(userId string) ([]*types.ApiKey, error) {
rows, err := db.Query("SELECT "+apiKeyFields+" from apikey WHERE deleted IS NULL AND userId = UNHEX(?)", userId)
if err != nil {
return nil, err
}
defer rows.Close()
keys := make([]*types.ApiKey, 0)
for rows.Next() {
k := new(types.ApiKey)
var inserted int64
var updated int64
err = rows.Scan(&k.Id, &inserted, &updated, &k.UserId, &k.Label)
if err != nil {
return nil, err
}
k.Inserted = util.MsToTime(inserted)
k.Updated = util.MsToTime(updated)
keys = append(keys, k)
}
err = rows.Err()
if err != nil {
return nil, err
}
return keys, nil
}
func (db *DB) UpdateApiKeyActivity(id string) error {
query := "UPDATE apikey SET updated = ? WHERE id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(time.Now()),
id,
)
return err
}

76
core/model/db/db.go Normal file
View File

@@ -0,0 +1,76 @@
package db
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
type DB struct {
*sql.DB
}
type Datastore interface {
Escape(string) string
UserInterface
OrgInterface
AccountInterface
TransactionInterface
PriceInterface
SessionInterface
ApiKeyInterface
}
func NewDB(dataSourceName string) (*DB, error) {
var err error
db, err := sql.Open("mysql", dataSourceName)
if err != nil {
return nil, err
}
if err = db.Ping(); err != nil {
return nil, err
}
return &DB{db}, nil
}
func (db *DB) Escape(sql string) string {
dest := make([]byte, 0, 2*len(sql))
var escape byte
for i := 0; i < len(sql); i++ {
c := sql[i]
escape = 0
switch c {
case 0: /* Must be escaped for 'mysql' */
escape = '0'
break
case '\n': /* Must be escaped for logs */
escape = 'n'
break
case '\r':
escape = 'r'
break
case '\\':
escape = '\\'
break
case '\'':
escape = '\''
break
case '"': /* Better safe than sorry */
escape = '"'
break
case '\032': /* This gives problems on Win32 */
escape = 'Z'
}
if escape != 0 {
dest = append(dest, '\\', escape)
} else {
dest = append(dest, c)
}
}
return string(dest)
}

370
core/model/db/org.go Normal file
View File

@@ -0,0 +1,370 @@
package db
import (
"database/sql"
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"time"
)
type OrgInterface interface {
CreateOrg(*types.Org, string, []*types.Account) error
UpdateOrg(*types.Org) error
GetOrg(string, string) (*types.Org, error)
GetOrgs(string) ([]*types.Org, error)
GetOrgUserIds(string) ([]string, error)
InsertInvite(*types.Invite) error
AcceptInvite(*types.Invite, string) error
GetInvites(string) ([]*types.Invite, error)
GetInvite(string) (*types.Invite, error)
DeleteInvite(string) error
}
const orgFields = "LOWER(HEX(o.id)),o.inserted,o.updated,o.name,o.currency,o.`precision`"
const inviteFields = "i.id,LOWER(HEX(i.orgId)),i.inserted,i.updated,i.email,i.accepted"
func (db *DB) CreateOrg(org *types.Org, userId string, accounts []*types.Account) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
org.Inserted = time.Now()
org.Updated = org.Inserted
// create org
query1 := "INSERT INTO org(id,inserted,updated,name,currency,`precision`) VALUES(UNHEX(?),?,?,?,?,?)"
res, err := tx.Exec(
query1,
org.Id,
util.TimeToMs(org.Inserted),
util.TimeToMs(org.Updated),
org.Name,
org.Currency,
org.Precision,
)
if err != nil {
return
}
// associate user with org
query2 := "INSERT INTO userorg(userId,orgId,admin) VALUES(UNHEX(?),UNHEX(?), 1)"
res, err = tx.Exec(query2, userId, org.Id)
if err != nil {
return
}
_, err = res.LastInsertId()
if err != nil {
return
}
// create Accounts: Root, Assets, Liabilities, Equity, Income, Expenses
for _, account := range accounts {
query := "INSERT INTO account(id,orgId,inserted,updated,name,parent,currency,`precision`,debitBalance) VALUES (UNHEX(?),UNHEX(?),?,?,?,UNHEX(?),?,?,?)"
if _, err = tx.Exec(
query,
account.Id,
org.Id,
util.TimeToMs(org.Inserted),
util.TimeToMs(org.Updated),
account.Name,
account.Parent,
account.Currency,
account.Precision,
account.DebitBalance,
); err != nil {
return
}
}
permissionId, err := util.NewGuid()
if err != nil {
return
}
// Grant root permission to user
query3 := "INSERT INTO permission (id,userId,orgId,accountId,type,inserted,updated) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),UNHEX(?),?,?,?)"
_, err = tx.Exec(
query3,
permissionId,
userId,
org.Id,
accounts[0].Id,
0,
util.TimeToMs(org.Inserted),
util.TimeToMs(org.Updated),
)
return
}
func (db *DB) UpdateOrg(org *types.Org) error {
org.Updated = time.Now()
query := "UPDATE org SET updated = ?, name = ? WHERE id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(org.Updated),
org.Name,
org.Id,
)
return err
}
func (db *DB) GetOrg(orgId string, userId string) (*types.Org, error) {
var o types.Org
var inserted int64
var updated int64
err := db.QueryRow("SELECT "+orgFields+" FROM org o JOIN userorg ON userorg.orgId = o.id WHERE o.id = UNHEX(?) AND userorg.userId = UNHEX(?)", orgId, userId).
Scan(&o.Id, &inserted, &updated, &o.Name, &o.Currency, &o.Precision)
switch {
case err == sql.ErrNoRows:
return nil, errors.New("Org not found")
case err != nil:
return nil, err
default:
o.Inserted = util.MsToTime(inserted)
o.Updated = util.MsToTime(updated)
return &o, nil
}
}
func (db *DB) GetOrgs(userId string) ([]*types.Org, error) {
rows, err := db.Query("SELECT "+orgFields+" from org o JOIN userorg ON userorg.orgId = o.id WHERE userorg.userId = UNHEX(?)", userId)
if err != nil {
return nil, err
}
defer rows.Close()
orgs := make([]*types.Org, 0)
for rows.Next() {
o := new(types.Org)
var inserted int64
var updated int64
err = rows.Scan(&o.Id, &inserted, &updated, &o.Name, &o.Currency, &o.Precision)
if err != nil {
return nil, err
}
o.Inserted = util.MsToTime(inserted)
o.Updated = util.MsToTime(updated)
orgs = append(orgs, o)
}
err = rows.Err()
if err != nil {
return nil, err
}
return orgs, nil
}
func (db *DB) GetOrgUserIds(orgId string) ([]string, error) {
rows, err := db.Query("SELECT LOWER(HEX(userId)) FROM userorg WHERE orgId = UNHEX(?)", orgId)
if err != nil {
return nil, err
}
defer rows.Close()
userIds := make([]string, 0)
for rows.Next() {
var userId string
err = rows.Scan(&userId)
if err != nil {
return nil, err
}
userIds = append(userIds, userId)
}
err = rows.Err()
if err != nil {
return nil, err
}
return userIds, nil
}
func (db *DB) InsertInvite(invite *types.Invite) error {
invite.Inserted = time.Now()
invite.Updated = invite.Inserted
query := "INSERT INTO invite(id,orgId,inserted,updated,email,accepted) VALUES(?,UNHEX(?),?,?,?,?)"
_, err := db.Exec(
query,
invite.Id,
invite.OrgId,
util.TimeToMs(invite.Inserted),
util.TimeToMs(invite.Updated),
invite.Email,
false,
)
return err
}
func (db *DB) AcceptInvite(invite *types.Invite, userId string) error {
invite.Updated = time.Now()
// Get root account for permission
rootAccount, err := db.GetRootAccount(invite.OrgId)
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
// associate user with org
query1 := "INSERT INTO userorg(userId,orgId,admin) VALUES(UNHEX(?),UNHEX(?), 0)"
_, err = tx.Exec(query1, userId, invite.OrgId)
if err != nil {
return err
}
query2 := "UPDATE invite SET accepted = 1, updated = ? WHERE id = ?"
_, err = tx.Exec(query2, util.TimeToMs(invite.Updated), invite.Id)
// Grant root permission to user
permissionId, err := util.NewGuid()
if err != nil {
return err
}
query3 := "INSERT INTO permission (id,userId,orgId,accountId,type,inserted,updated) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),UNHEX(?),?,?,?)"
_, err = tx.Exec(
query3,
permissionId,
userId,
invite.OrgId,
rootAccount.Id,
0,
util.TimeToMs(invite.Updated),
util.TimeToMs(invite.Updated),
)
return err
}
func (db *DB) GetInvites(orgId string) ([]*types.Invite, error) {
// don't include expired invoices
cutoff := util.TimeToMs(time.Now()) - 7*24*60*60*1000
rows, err := db.Query("SELECT "+inviteFields+" FROM invite i WHERE orgId = UNHEX(?) AND inserted > ?", orgId, cutoff)
if err != nil {
return nil, err
}
defer rows.Close()
invites := make([]*types.Invite, 0)
for rows.Next() {
i := new(types.Invite)
var inserted int64
var updated int64
err = rows.Scan(&i.Id, &i.OrgId, &inserted, &updated, &i.Email, &i.Accepted)
if err != nil {
return nil, err
}
i.Inserted = util.MsToTime(inserted)
i.Updated = util.MsToTime(updated)
invites = append(invites, i)
}
err = rows.Err()
if err != nil {
return nil, err
}
return invites, nil
}
func (db *DB) GetInvite(id string) (*types.Invite, error) {
var i types.Invite
var inserted int64
var updated int64
err := db.QueryRow("SELECT "+inviteFields+" FROM invite i WHERE i.id = ?", id).
Scan(&i.Id, &i.OrgId, &inserted, &updated, &i.Email, &i.Accepted)
switch {
case err == sql.ErrNoRows:
return nil, errors.New("Invite not found")
case err != nil:
return nil, err
default:
i.Inserted = util.MsToTime(inserted)
i.Updated = util.MsToTime(updated)
return &i, nil
}
}
func (db *DB) DeleteInvite(id string) error {
query := "DELETE FROM invite WHERE id = ?"
_, err := db.Exec(
query,
id,
)
return err
}

156
core/model/db/price.go Normal file
View File

@@ -0,0 +1,156 @@
package db
import (
"database/sql"
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"time"
)
type PriceInterface interface {
InsertPrice(*types.Price) error
GetPriceById(string) (*types.Price, error)
DeletePrice(string) error
GetPricesNearestInTime(string, time.Time) ([]*types.Price, error)
GetPricesByCurrency(string, string) ([]*types.Price, error)
}
const priceFields = "LOWER(HEX(p.id)),LOWER(HEX(p.orgId)),p.currency,p.date,p.inserted,p.updated,p.price"
func (db *DB) InsertPrice(price *types.Price) error {
price.Inserted = time.Now()
price.Updated = price.Inserted
if price.Date.IsZero() {
price.Date = price.Inserted
}
query := "INSERT INTO price(id,orgId,currency,date,inserted,updated,price) VALUES(UNHEX(?),UNHEX(?),?,?,?,?,?)"
_, err := db.Exec(
query,
price.Id,
price.OrgId,
price.Currency,
util.TimeToMs(price.Date),
util.TimeToMs(price.Inserted),
util.TimeToMs(price.Updated),
price.Price,
)
return err
}
func (db *DB) GetPriceById(id string) (*types.Price, error) {
var p types.Price
var date int64
var inserted int64
var updated int64
err := db.QueryRow("SELECT "+priceFields+" FROM price p WHERE id = UNHEX(?)", id).
Scan(&p.Id, &p.OrgId, &p.Currency, &date, &inserted, &updated, &p.Price)
switch {
case err == sql.ErrNoRows:
return nil, errors.New("Price not found")
case err != nil:
return nil, err
default:
p.Date = util.MsToTime(date)
p.Inserted = util.MsToTime(inserted)
p.Updated = util.MsToTime(updated)
return &p, nil
}
}
func (db *DB) DeletePrice(id string) error {
query := "DELETE FROM price WHERE id = UNHEX(?)"
_, err := db.Exec(query, id)
return err
}
func (db *DB) GetPricesNearestInTime(orgId string, date time.Time) ([]*types.Price, error) {
qSelect := "SELECT " + priceFields
qFrom := " FROM price p"
qJoin := " LEFT OUTER JOIN price p2 ON p.currency = p2.currency AND p.orgId = p2.orgId AND ABS(CAST(p.date AS SIGNED) - ?) > ABS(CAST(p2.date AS SIGNED) - ?)"
qWhere := " WHERE p2.id IS NULL AND p.orgId = UNHEX(?)"
query := qSelect + qFrom + qJoin + qWhere
rows, err := db.Query(query, util.TimeToMs(date), util.TimeToMs(date), orgId)
if err != nil {
return nil, err
}
defer rows.Close()
prices := make([]*types.Price, 0)
for rows.Next() {
var date int64
var inserted int64
var updated int64
p := new(types.Price)
err = rows.Scan(&p.Id, &p.OrgId, &p.Currency, &date, &inserted, &updated, &p.Price)
if err != nil {
return nil, err
}
p.Date = util.MsToTime(date)
p.Inserted = util.MsToTime(inserted)
p.Updated = util.MsToTime(updated)
prices = append(prices, p)
}
err = rows.Err()
if err != nil {
return nil, err
}
return prices, nil
}
func (db *DB) GetPricesByCurrency(orgId string, currency string) ([]*types.Price, error) {
qSelect := "SELECT " + priceFields
qFrom := " FROM price p"
qWhere := " WHERE p.orgId = UNHEX(?) AND p.currency = ?"
pOrder := " ORDER BY date ASC"
query := qSelect + qFrom + qWhere + pOrder
rows, err := db.Query(query, orgId, currency)
if err != nil {
return nil, err
}
defer rows.Close()
prices := make([]*types.Price, 0)
for rows.Next() {
var date int64
var inserted int64
var updated int64
p := new(types.Price)
err = rows.Scan(&p.Id, &p.OrgId, &p.Currency, &date, &inserted, &updated, &p.Price)
if err != nil {
return nil, err
}
p.Date = util.MsToTime(date)
p.Inserted = util.MsToTime(inserted)
p.Updated = util.MsToTime(updated)
prices = append(prices, p)
}
err = rows.Err()
if err != nil {
return nil, err
}
return prices, nil
}

65
core/model/db/session.go Normal file
View File

@@ -0,0 +1,65 @@
package db
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"time"
)
type SessionInterface interface {
InsertSession(*types.Session) error
DeleteSession(string, string) error
UpdateSessionActivity(string) error
}
func (db *DB) InsertSession(session *types.Session) error {
session.Inserted = time.Now()
session.Updated = session.Inserted
query := "INSERT INTO session(id,inserted,updated,userId) VALUES(UNHEX(?),?,?,UNHEX(?))"
res, err := db.Exec(
query,
session.Id,
util.TimeToMs(session.Inserted),
util.TimeToMs(session.Updated),
session.UserId,
)
if err != nil {
return err
}
rowCnt, err := res.RowsAffected()
if err != nil {
return err
}
if rowCnt < 1 {
return errors.New("Unable to insert session into db")
}
return nil
}
func (db *DB) DeleteSession(id string, userId string) error {
query := "UPDATE session SET `terminated` = ? WHERE id = UNHEX(?) AND userId = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(time.Now()),
id,
userId,
)
return err
}
func (db *DB) UpdateSessionActivity(id string) error {
query := "UPDATE session SET updated = ? WHERE id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(time.Now()),
id,
)
return err
}

View File

@@ -0,0 +1,558 @@
package db
import (
"database/sql"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"strconv"
"strings"
"time"
)
const txFields = "LOWER(HEX(id)),LOWER(HEX(orgId)),LOWER(HEX(userId)),date,inserted,updated,description,data,deleted"
const splitFields = "id,LOWER(HEX(transactionId)),LOWER(HEX(accountId)),date,inserted,updated,amount,nativeAmount,deleted"
type TransactionInterface interface {
InsertTransaction(*types.Transaction) error
GetTransactionById(string) (*types.Transaction, error)
GetTransactionsByAccount(string, *types.QueryOptions) ([]*types.Transaction, error)
GetTransactionsByOrg(string, *types.QueryOptions, []string) ([]*types.Transaction, error)
DeleteTransaction(string) error
DeleteAndInsertTransaction(string, *types.Transaction) error
}
func (db *DB) InsertTransaction(transaction *types.Transaction) (err error) {
// Save to db
dbTx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
dbTx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
dbTx.Rollback()
} else {
err = dbTx.Commit()
}
}()
// save tx
query1 := "INSERT INTO transaction(id,orgId,userId,date,inserted,updated,description,data) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),?,?,?,?,?)"
_, err = dbTx.Exec(
query1,
transaction.Id,
transaction.OrgId,
transaction.UserId,
util.TimeToMs(transaction.Date),
util.TimeToMs(transaction.Inserted),
util.TimeToMs(transaction.Updated),
transaction.Description,
transaction.Data,
)
if err != nil {
return
}
// save splits
for _, split := range transaction.Splits {
query := "INSERT INTO split(transactionId,accountId,date,inserted,updated,amount,nativeAmount) VALUES (UNHEX(?),UNHEX(?),?,?,?,?,?)"
_, err = dbTx.Exec(
query,
transaction.Id,
split.AccountId,
util.TimeToMs(transaction.Date),
util.TimeToMs(transaction.Inserted),
util.TimeToMs(transaction.Updated),
split.Amount,
split.NativeAmount)
if err != nil {
return
}
}
return
}
func (db *DB) GetTransactionById(id string) (*types.Transaction, error) {
row := db.QueryRow("SELECT "+txFields+" FROM transaction WHERE id = UNHEX(?)", id)
t, err := db.unmarshalTransaction(row)
if err != nil {
return nil, err
}
rows, err := db.Query("SELECT "+splitFields+" FROM split WHERE transactionId = UNHEX(?) ORDER BY id", t.Id)
if err != nil {
return nil, err
}
t.Splits, err = db.unmarshalSplits(rows)
if err != nil {
return nil, err
}
return t, nil
}
func (db *DB) GetTransactionsByAccount(accountId string, options *types.QueryOptions) ([]*types.Transaction, error) {
query := "SELECT LOWER(HEX(s.transactionId)) FROM split s"
if options.DescriptionStartsWith != "" {
query = query + " JOIN transaction t ON t.id = s.transactionId"
}
query = query + " WHERE s.accountId = UNHEX(?)"
query = db.addOptionsToQuery(query, options)
rows, err := db.Query(query, accountId)
if err != nil {
return nil, err
}
defer rows.Close()
var ids []string
for rows.Next() {
var id string
err = rows.Scan(&id)
if err != nil {
return nil, err
}
ids = append(ids, "UNHEX(\""+id+"\")")
}
err = rows.Err()
if err != nil {
return nil, err
}
if len(ids) == 0 {
return make([]*types.Transaction, 0), nil
}
query = "SELECT " + txFields + " FROM transaction WHERE id IN (" + strings.Join(ids, ",") + ")"
query = db.addSortToQuery(query, options)
rows, err = db.Query(query)
if err != nil {
return nil, err
}
transactions, err := db.unmarshalTransactions(rows)
if err != nil {
return nil, err
}
transactionMap := make(map[string]*types.Transaction)
for _, t := range transactions {
transactionMap[t.Id] = t
}
rows, err = db.Query("SELECT " + splitFields + " FROM split WHERE transactionId IN (" + strings.Join(ids, ",") + ") ORDER BY id")
if err != nil {
return nil, err
}
splits, err := db.unmarshalSplits(rows)
if err != nil {
return nil, err
}
for _, s := range splits {
transaction := transactionMap[s.TransactionId]
transaction.Splits = append(transaction.Splits, s)
}
return transactions, nil
}
func (db *DB) GetTransactionsByOrg(orgId string, options *types.QueryOptions, accountIds []string) ([]*types.Transaction, error) {
if len(accountIds) == 0 {
return make([]*types.Transaction, 0), nil
}
for i, accountId := range accountIds {
accountIds[i] = "UNHEX(\"" + accountId + "\")"
}
query := "SELECT DISTINCT LOWER(HEX(s.transactionId)),s.date,s.inserted,s.updated FROM split s"
if options.DescriptionStartsWith != "" {
query = query + " JOIN transaction t ON t.id = s.transactionId"
}
query = query + " WHERE s.accountId IN (" + strings.Join(accountIds, ",") + ")"
query = db.addOptionsToQuery(query, options)
rows, err := db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
ids := []string{}
for rows.Next() {
var id string
var date int64
var inserted int64
var updated int64
err = rows.Scan(&id, &date, &inserted, &updated)
if err != nil {
return nil, err
}
ids = append(ids, "UNHEX(\""+id+"\")")
}
err = rows.Err()
if err != nil {
return nil, err
}
if len(ids) == 0 {
return make([]*types.Transaction, 0), nil
}
query = "SELECT " + txFields + " FROM transaction WHERE id IN (" + strings.Join(ids, ",") + ")"
query = db.addSortToQuery(query, options)
rows, err = db.Query(query)
if err != nil {
return nil, err
}
transactions, err := db.unmarshalTransactions(rows)
if err != nil {
return nil, err
}
transactionMap := make(map[string]*types.Transaction)
for _, t := range transactions {
transactionMap[t.Id] = t
}
rows, err = db.Query("SELECT " + splitFields + " FROM split WHERE transactionId IN (" + strings.Join(ids, ",") + ") ORDER BY id")
if err != nil {
return nil, err
}
splits, err := db.unmarshalSplits(rows)
if err != nil {
return nil, err
}
for _, s := range splits {
transaction := transactionMap[s.TransactionId]
transaction.Splits = append(transaction.Splits, s)
}
return transactions, nil
}
func (db *DB) DeleteTransaction(id string) (err error) {
dbTx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
dbTx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
dbTx.Rollback()
} else {
err = dbTx.Commit()
}
}()
updatedTime := util.TimeToMs(time.Now())
// mark splits as deleted
query1 := "UPDATE split SET updated = ?, deleted = true WHERE transactionId = UNHEX(?)"
_, err = dbTx.Exec(
query1,
updatedTime,
id,
)
if err != nil {
return
}
// mark transaction as deleted
query2 := "UPDATE transaction SET updated = ?, deleted = true WHERE id = UNHEX(?)"
_, err = dbTx.Exec(
query2,
updatedTime,
id,
)
if err != nil {
return
}
return
}
func (db *DB) DeleteAndInsertTransaction(oldId string, transaction *types.Transaction) (err error) {
// Save to db
dbTx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
dbTx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
dbTx.Rollback()
} else {
err = dbTx.Commit()
}
}()
updatedTime := util.TimeToMs(transaction.Updated)
// mark splits as deleted
query1 := "UPDATE split SET updated = ?, deleted = true WHERE transactionId = UNHEX(?)"
_, err = dbTx.Exec(
query1,
updatedTime,
oldId,
)
if err != nil {
return
}
// mark transaction as deleted
query2 := "UPDATE transaction SET updated = ?, deleted = true WHERE id = UNHEX(?)"
_, err = dbTx.Exec(
query2,
updatedTime,
oldId,
)
if err != nil {
return
}
// save new tx
query3 := "INSERT INTO transaction(id,orgId,userId,date,inserted,updated,description,data) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),?,?,?,?,?)"
_, err = dbTx.Exec(
query3,
transaction.Id,
transaction.OrgId,
transaction.UserId,
util.TimeToMs(transaction.Date),
util.TimeToMs(transaction.Inserted),
updatedTime,
transaction.Description,
transaction.Data,
)
if err != nil {
return
}
// save splits
for _, split := range transaction.Splits {
query := "INSERT INTO split(transactionId,accountId,date,inserted,updated,amount,nativeAmount) VALUES (UNHEX(?),UNHEX(?),?,?,?,?,?)"
_, err = dbTx.Exec(
query,
transaction.Id,
split.AccountId,
util.TimeToMs(transaction.Date),
util.TimeToMs(transaction.Inserted),
updatedTime,
split.Amount,
split.NativeAmount)
if err != nil {
return
}
}
return
}
func (db *DB) unmarshalTransaction(row *sql.Row) (*types.Transaction, error) {
t := new(types.Transaction)
var date int64
var inserted int64
var updated int64
err := row.Scan(&t.Id, &t.OrgId, &t.UserId, &date, &inserted, &updated, &t.Description, &t.Data, &t.Deleted)
if err != nil {
return nil, err
}
t.Date = util.MsToTime(date)
t.Inserted = util.MsToTime(inserted)
t.Updated = util.MsToTime(updated)
return t, nil
}
func (db *DB) unmarshalTransactions(rows *sql.Rows) ([]*types.Transaction, error) {
defer rows.Close()
transactions := make([]*types.Transaction, 0)
for rows.Next() {
t := new(types.Transaction)
var date int64
var inserted int64
var updated int64
err := rows.Scan(&t.Id, &t.OrgId, &t.UserId, &date, &inserted, &updated, &t.Description, &t.Data, &t.Deleted)
if err != nil {
return nil, err
}
t.Date = util.MsToTime(date)
t.Inserted = util.MsToTime(inserted)
t.Updated = util.MsToTime(updated)
transactions = append(transactions, t)
}
err := rows.Err()
if err != nil {
return nil, err
}
return transactions, nil
}
func (db *DB) unmarshalSplits(rows *sql.Rows) ([]*types.Split, error) {
defer rows.Close()
splits := make([]*types.Split, 0)
for rows.Next() {
s := new(types.Split)
var id int64
var date int64
var inserted int64
var updated int64
var deleted bool
err := rows.Scan(&id, &s.TransactionId, &s.AccountId, &date, &inserted, &updated, &s.Amount, &s.NativeAmount, &deleted)
if err != nil {
return nil, err
}
splits = append(splits, s)
}
err := rows.Err()
if err != nil {
return nil, err
}
return splits, nil
}
func (db *DB) addOptionsToQuery(query string, options *types.QueryOptions) string {
if options.IncludeDeleted != true {
query += " AND s.deleted = false"
}
if options.SinceInserted != 0 {
query += " AND s.inserted > " + strconv.Itoa(options.SinceInserted)
}
if options.SinceUpdated != 0 {
query += " AND s.updated > " + strconv.Itoa(options.SinceUpdated)
}
if options.BeforeInserted != 0 {
query += " AND s.inserted < " + strconv.Itoa(options.BeforeInserted)
}
if options.BeforeUpdated != 0 {
query += " AND s.updated < " + strconv.Itoa(options.BeforeUpdated)
}
if options.StartDate != 0 {
query += " AND s.date >= " + strconv.Itoa(options.StartDate)
}
if options.EndDate != 0 {
query += " AND s.date < " + strconv.Itoa(options.EndDate)
}
if options.DescriptionStartsWith != "" {
query += " AND t.description LIKE '" + db.Escape(options.DescriptionStartsWith) + "%'"
}
if options.Sort == "updated-asc" {
query += " ORDER BY s.updated ASC"
} else {
query += " ORDER BY s.date DESC, s.inserted DESC"
}
if options.Limit != 0 && options.Skip != 0 {
query += " LIMIT " + strconv.Itoa(options.Skip) + ", " + strconv.Itoa(options.Limit)
} else if options.Limit != 0 {
query += " LIMIT " + strconv.Itoa(options.Limit)
}
return query
}
func (db *DB) addSortToQuery(query string, options *types.QueryOptions) string {
if options.Sort == "updated-asc" {
query += " ORDER BY updated ASC"
} else {
query += " ORDER BY date DESC, inserted DESC"
}
return query
}

264
core/model/db/user.go Normal file
View File

@@ -0,0 +1,264 @@
package db
import (
"database/sql"
"errors"
"fmt"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"time"
)
const userFields = "LOWER(HEX(u.id)),u.inserted,u.updated,u.firstName,u.lastName,u.email,u.passwordHash,u.agreeToTerms,u.passwordReset,u.emailVerified,u.emailVerifyCode"
type UserInterface interface {
InsertUser(*types.User) error
VerifyUser(string) error
UpdateUser(*types.User) error
UpdateUserResetPassword(*types.User) error
GetVerifiedUserByEmail(string) (*types.User, error)
GetUserByActiveSession(string) (*types.User, error)
GetUserByApiKey(string) (*types.User, error)
GetUserByResetCode(string) (*types.User, error)
GetOrgAdmins(string) ([]*types.User, error)
}
func (db *DB) InsertUser(user *types.User) error {
user.Inserted = time.Now()
user.Updated = user.Inserted
user.PasswordReset = ""
query := "INSERT INTO user(id,inserted,updated,firstName,lastName,email,passwordHash,agreeToTerms,passwordReset,emailVerified,emailVerifyCode) VALUES(UNHEX(?),?,?,?,?,?,?,?,?,?,?)"
res, err := db.Exec(
query,
user.Id,
util.TimeToMs(user.Inserted),
util.TimeToMs(user.Updated),
user.FirstName,
user.LastName,
user.Email,
user.PasswordHash,
user.AgreeToTerms,
user.PasswordReset,
user.EmailVerified,
user.EmailVerifyCode,
)
if err != nil {
return err
}
rowCnt, err := res.RowsAffected()
if err != nil {
return err
}
if rowCnt < 1 {
return errors.New("Unable to insert user into db")
}
return nil
}
func (db *DB) VerifyUser(code string) error {
query := "UPDATE user SET updated = ?, emailVerified = 1 WHERE emailVerifyCode = ?"
res, err := db.Exec(
query,
util.TimeToMs(time.Now()),
code,
)
count, err := res.RowsAffected()
if err != nil {
return nil
}
if count == 0 {
return errors.New("Invalid code")
}
return nil
}
func (db *DB) UpdateUser(user *types.User) error {
user.Updated = time.Now()
query := "UPDATE user SET updated = ?, passwordHash = ?, passwordReset = ? WHERE id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(user.Updated),
user.PasswordHash,
"",
user.Id,
)
return err
}
func (db *DB) UpdateUserResetPassword(user *types.User) error {
user.Updated = time.Now()
query := "UPDATE user SET updated = ?, passwordReset = ? WHERE id = UNHEX(?)"
_, err := db.Exec(
query,
util.TimeToMs(user.Updated),
user.PasswordReset,
user.Id,
)
return err
}
func (db *DB) GetVerifiedUserByEmail(email string) (*types.User, error) {
query := "SELECT " + userFields + " FROM user u WHERE email = ? AND emailVerified = 1"
row := db.QueryRow(query, email)
u, err := db.unmarshalUser(row)
if err != nil {
return nil, err
}
return u, nil
}
func (db *DB) GetUserByActiveSession(sessionId string) (*types.User, error) {
qSelect := "SELECT " + userFields
qFrom := " FROM user u"
qJoin := " JOIN session s ON s.userId = u.id"
qWhere := " WHERE s.terminated IS NULL AND s.id = UNHEX(?)"
query := qSelect + qFrom + qJoin + qWhere
row := db.QueryRow(query, sessionId)
u, err := db.unmarshalUser(row)
if err != nil {
return nil, err
}
return u, nil
}
func (db *DB) GetUserByApiKey(keyId string) (*types.User, error) {
qSelect := "SELECT " + userFields
qFrom := " FROM user u"
qJoin := " JOIN apikey a ON a.userId = u.id"
qWhere := " WHERE a.deleted IS NULL AND a.id = UNHEX(?)"
query := qSelect + qFrom + qJoin + qWhere
row := db.QueryRow(query, keyId)
u, err := db.unmarshalUser(row)
if err != nil {
return nil, err
}
return u, nil
}
func (db *DB) GetUserByResetCode(code string) (*types.User, error) {
qSelect := "SELECT " + userFields
qFrom := " FROM user u"
qWhere := " WHERE u.passwordReset = ?"
query := qSelect + qFrom + qWhere
row := db.QueryRow(query, code)
u, err := db.unmarshalUser(row)
if err != nil {
return nil, err
}
fmt.Println(u)
return u, nil
}
func (db *DB) GetOrgAdmins(orgId string) ([]*types.User, error) {
qSelect := "SELECT " + userFields
qFrom := " FROM user u"
qJoin := " JOIN userorg uo ON uo.userId = u.id"
qWhere := " WHERE uo.admin = true AND uo.orgId = UNHEX(?)"
query := qSelect + qFrom + qJoin + qWhere
rows, err := db.Query(query, orgId)
if err != nil {
return nil, err
}
return db.unmarshalUsers(rows)
}
func (db *DB) unmarshalUser(row *sql.Row) (*types.User, error) {
u := new(types.User)
var inserted int64
var updated int64
err := row.Scan(
&u.Id,
&inserted,
&updated,
&u.FirstName,
&u.LastName,
&u.Email,
&u.PasswordHash,
&u.AgreeToTerms,
&u.PasswordReset,
&u.EmailVerified,
&u.EmailVerifyCode,
)
if err != nil {
return nil, err
}
u.Inserted = util.MsToTime(inserted)
u.Updated = util.MsToTime(updated)
return u, nil
}
func (db *DB) unmarshalUsers(rows *sql.Rows) ([]*types.User, error) {
defer rows.Close()
users := make([]*types.User, 0)
for rows.Next() {
u := new(types.User)
var inserted int64
var updated int64
err := rows.Scan(
&u.Id,
&inserted,
&updated,
&u.FirstName,
&u.LastName,
&u.Email,
&u.PasswordHash,
&u.AgreeToTerms,
&u.PasswordReset,
&u.EmailVerified,
&u.EmailVerifyCode,
)
if err != nil {
return nil, err
}
u.Inserted = util.MsToTime(inserted)
u.Updated = util.MsToTime(updated)
users = append(users, u)
}
err := rows.Err()
return users, err
}

31
core/model/model.go Normal file
View File

@@ -0,0 +1,31 @@
package model
import (
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
)
var Instance Interface
type Model struct {
db db.Datastore
bcrypt util.Bcrypt
config types.Config
}
type Interface interface {
UserInterface
OrgInterface
AccountInterface
TransactionInterface
PriceInterface
SessionInterface
ApiKeyInterface
}
func NewModel(db db.Datastore, bcrypt util.Bcrypt, config types.Config) *Model {
model := &Model{db: db, bcrypt: bcrypt, config: config}
Instance = model
return model
}

294
core/model/org.go Normal file
View File

@@ -0,0 +1,294 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"time"
)
type OrgInterface interface {
CreateOrg(*types.Org, string) error
UpdateOrg(*types.Org, string) error
GetOrg(string, string) (*types.Org, error)
GetOrgs(string) ([]*types.Org, error)
CreateInvite(*types.Invite, string) error
AcceptInvite(*types.Invite, string) error
GetInvites(string, string) ([]*types.Invite, error)
DeleteInvite(string, string) error
}
func (model *Model) CreateOrg(org *types.Org, userId string) error {
if org.Name == "" {
return errors.New("name required")
}
if org.Currency == "" {
return errors.New("currency required")
}
accounts := make([]*types.Account, 6)
id, err := util.NewGuid()
if err != nil {
return err
}
accounts[0] = &types.Account{
Id: id,
Name: "Root",
Parent: "",
Currency: org.Currency,
Precision: org.Precision,
DebitBalance: true,
}
id, err = util.NewGuid()
if err != nil {
return err
}
accounts[1] = &types.Account{
Id: id,
Name: "Assets",
Parent: accounts[0].Id,
Currency: org.Currency,
Precision: org.Precision,
DebitBalance: true,
}
id, err = util.NewGuid()
if err != nil {
return err
}
accounts[2] = &types.Account{
Id: id,
Name: "Liabilities",
Parent: accounts[0].Id,
Currency: org.Currency,
Precision: org.Precision,
DebitBalance: false,
}
id, err = util.NewGuid()
if err != nil {
return err
}
accounts[3] = &types.Account{
Id: id,
Name: "Equity",
Parent: accounts[0].Id,
Currency: org.Currency,
Precision: org.Precision,
DebitBalance: false,
}
id, err = util.NewGuid()
if err != nil {
return err
}
accounts[4] = &types.Account{
Id: id,
Name: "Income",
Parent: accounts[0].Id,
Currency: org.Currency,
Precision: org.Precision,
DebitBalance: false,
}
id, err = util.NewGuid()
if err != nil {
return err
}
accounts[5] = &types.Account{
Id: id,
Name: "Expenses",
Parent: accounts[0].Id,
Currency: org.Currency,
Precision: org.Precision,
DebitBalance: true,
}
return model.db.CreateOrg(org, userId, accounts)
}
func (model *Model) UpdateOrg(org *types.Org, userId string) error {
_, err := model.GetOrg(org.Id, userId)
if err != nil {
// user doesn't have access to org
return errors.New("access denied")
}
if org.Name == "" {
return errors.New("name required")
}
return model.db.UpdateOrg(org)
}
func (model *Model) GetOrg(orgId string, userId string) (*types.Org, error) {
return model.db.GetOrg(orgId, userId)
}
func (model *Model) GetOrgs(userId string) ([]*types.Org, error) {
return model.db.GetOrgs(userId)
}
func (model *Model) UserBelongsToOrg(userId string, orgId string) (bool, error) {
orgs, err := model.GetOrgs(userId)
if err != nil {
return false, err
}
belongs := false
for _, org := range orgs {
if org.Id == orgId {
belongs = true
break
}
}
return belongs, nil
}
func (model *Model) CreateInvite(invite *types.Invite, userId string) error {
admins, err := model.db.GetOrgAdmins(invite.OrgId)
if err != nil {
return err
}
isAdmin := false
for _, admin := range admins {
if admin.Id == userId {
isAdmin = true
break
}
}
if isAdmin == false {
return errors.New("Must be org admin to invite users")
}
inviteId, err := util.NewInviteId()
if err != nil {
return err
}
invite.Id = inviteId
err = model.db.InsertInvite(invite)
if err != nil {
return err
}
if invite.Email != "" {
// TODO send email
}
return nil
}
func (model *Model) AcceptInvite(invite *types.Invite, userId string) error {
if invite.Accepted != true {
return errors.New("accepted must be true")
}
if invite.Id == "" {
return errors.New("missing invite id")
}
// Get original invite
original, err := model.db.GetInvite(invite.Id)
if err != nil {
return err
}
if original.Accepted == true {
return errors.New("invite already accepted")
}
oneWeekAfter := original.Inserted.Add(time.Hour * 24 * 7)
if time.Now().After(oneWeekAfter) == true {
return errors.New("invite has expired")
}
invite.OrgId = original.OrgId
invite.Email = original.Email
invite.Inserted = original.Inserted
return model.db.AcceptInvite(invite, userId)
}
func (model *Model) GetInvites(orgId string, userId string) ([]*types.Invite, error) {
admins, err := model.db.GetOrgAdmins(orgId)
if err != nil {
return nil, err
}
isAdmin := false
for _, admin := range admins {
if admin.Id == userId {
isAdmin = true
break
}
}
if isAdmin == false {
return nil, errors.New("Must be org admin to invite users")
}
return model.db.GetInvites(orgId)
}
func (model *Model) DeleteInvite(id string, userId string) error {
// Get original invite
invite, err := model.db.GetInvite(id)
if err != nil {
return err
}
// make sure user has access
admins, err := model.db.GetOrgAdmins(invite.OrgId)
if err != nil {
return nil
}
isAdmin := false
for _, admin := range admins {
if admin.Id == userId {
isAdmin = true
break
}
}
if isAdmin == false {
return errors.New("Must be org admin to delete invite")
}
return model.db.DeleteInvite(id)
}

74
core/model/org_test.go Normal file
View File

@@ -0,0 +1,74 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/stretchr/testify/assert"
"testing"
)
type TdOrg struct {
db.Datastore
}
func (td *TdOrg) GetOrg(orgId string, userId string) (*types.Org, error) {
if userId == "1" {
return &types.Org{
Id: "1",
Name: "MyOrg",
Currency: "USD",
Precision: 2,
}, nil
} else {
return nil, errors.New("not found")
}
}
func (td *TdOrg) UpdateOrg(org *types.Org) error {
return nil
}
func TestUpdateOrg(t *testing.T) {
tests := map[string]struct {
err error
org *types.Org
userId string
}{
"success": {
err: nil,
org: &types.Org{
Id: "1",
Name: "MyOrg2",
},
userId: "1",
},
"access denied": {
err: errors.New("access denied"),
org: &types.Org{
Id: "1",
Name: "MyOrg2",
},
userId: "2",
},
"error": {
err: errors.New("name required"),
org: &types.Org{
Id: "1",
Name: "",
},
userId: "1",
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdOrg{}
model := NewModel(td, nil, types.Config{})
err := model.UpdateOrg(test.org, test.userId)
assert.Equal(t, test.err, err)
}
}

117
core/model/price.go Normal file
View File

@@ -0,0 +1,117 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/ws"
"time"
)
type PriceInterface interface {
CreatePrice(*types.Price, string) error
DeletePrice(string, string) error
GetPricesNearestInTime(string, time.Time, string) ([]*types.Price, error)
GetPricesByCurrency(string, string, string) ([]*types.Price, error)
}
func (model *Model) CreatePrice(price *types.Price, userId string) error {
belongs, err := model.UserBelongsToOrg(userId, price.OrgId)
if err != nil {
return err
}
if belongs == false {
return errors.New("User does not belong to org")
}
if price.Id == "" {
return errors.New("id required")
}
if price.OrgId == "" {
return errors.New("orgId required")
}
if price.Currency == "" {
return errors.New("currency required")
}
err = model.db.InsertPrice(price)
if err != nil {
return err
}
// Notify web socket subscribers
userIds, err2 := model.db.GetOrgUserIds(price.OrgId)
if err2 == nil {
ws.PushPrice(price, userIds, "create")
}
return nil
}
func (model *Model) DeletePrice(id string, userId string) error {
// Get original price
price, err := model.db.GetPriceById(id)
if err != nil {
return err
}
belongs, err := model.UserBelongsToOrg(userId, price.OrgId)
if err != nil {
return err
}
if belongs == false {
return errors.New("User does not belong to org")
}
err = model.db.DeletePrice(id)
if err != nil {
return err
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access account
userIds, err2 := model.db.GetOrgUserIds(price.OrgId)
if err2 == nil {
ws.PushPrice(price, userIds, "delete")
}
return nil
}
func (model *Model) GetPricesNearestInTime(orgId string, date time.Time, userId string) ([]*types.Price, error) {
belongs, err := model.UserBelongsToOrg(userId, orgId)
if err != nil {
return nil, err
}
if belongs == false {
return nil, errors.New("User does not belong to org")
}
return model.db.GetPricesNearestInTime(orgId, date)
}
func (model *Model) GetPricesByCurrency(orgId string, currency string, userId string) ([]*types.Price, error) {
belongs, err := model.UserBelongsToOrg(userId, orgId)
if err != nil {
return nil, err
}
if belongs == false {
return nil, errors.New("User does not belong to org")
}
return model.db.GetPricesByCurrency(orgId, currency)
}

149
core/model/price_test.go Normal file
View File

@@ -0,0 +1,149 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/mocks"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestCreatePrice(t *testing.T) {
price := types.Price{
"1",
"2",
"BTC",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
}
badPrice := types.Price{
"1",
"2",
"",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
}
badOrg := types.Price{
"1",
"1",
"BTC",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
}
tests := map[string]struct {
err error
price types.Price
}{
"successful": {
err: nil,
price: price,
},
"with error": {
err: errors.New("currency required"),
price: badPrice,
},
"with org error": {
err: errors.New("User does not belong to org"),
price: badOrg,
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
price := test.price
userId := "3"
db := &mocks.Datastore{}
db.On("GetOrgs", userId).Return([]*types.Org{
{
Id: "2",
},
}, nil)
db.On("InsertPrice", &test.price).Return(nil)
db.On("GetOrgUserIds", price.OrgId).Return([]string{userId}, nil)
model := NewModel(db, &util.StandardBcrypt{}, types.Config{})
err := model.CreatePrice(&price, userId)
assert.Equal(t, test.err, err)
}
}
func TestDeletePrice(t *testing.T) {
price := types.Price{
"1",
"2",
"BTC",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
}
tests := map[string]struct {
err error
userId string
price types.Price
}{
"successful": {
err: nil,
price: price,
userId: "3",
},
"with org error": {
err: errors.New("User does not belong to org"),
price: price,
userId: "4",
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
price := test.price
db := &mocks.Datastore{}
db.On("GetPriceById", price.Id).Return(&price, nil)
db.On("GetOrgs", "3").Return([]*types.Org{
{
Id: "2",
},
}, nil)
db.On("GetOrgs", "4").Return([]*types.Org{
{
Id: "7",
},
}, nil)
db.On("DeletePrice", price.Id).Return(nil)
db.On("GetOrgUserIds", price.OrgId).Return([]string{test.userId}, nil)
model := NewModel(db, &util.StandardBcrypt{}, types.Config{})
err := model.DeletePrice(price.Id, test.userId)
assert.Equal(t, test.err, err)
}
}

23
core/model/session.go Normal file
View File

@@ -0,0 +1,23 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
)
type SessionInterface interface {
CreateSession(*types.Session) error
DeleteSession(string, string) error
}
func (model *Model) CreateSession(session *types.Session) error {
if session.Id == "" {
return errors.New("id required")
}
return model.db.InsertSession(session)
}
func (model *Model) DeleteSession(id string, userId string) error {
return model.db.DeleteSession(id, userId)
}

213
core/model/transaction.go Normal file
View File

@@ -0,0 +1,213 @@
package model
import (
"errors"
"fmt"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/ws"
"time"
)
type TransactionInterface interface {
CreateTransaction(*types.Transaction) error
UpdateTransaction(string, *types.Transaction) error
GetTransactionsByAccount(string, string, string, *types.QueryOptions) ([]*types.Transaction, error)
GetTransactionsByOrg(string, string, *types.QueryOptions) ([]*types.Transaction, error)
DeleteTransaction(string, string, string) error
}
func (model *Model) CreateTransaction(transaction *types.Transaction) (err error) {
err = model.checkSplits(transaction)
if err != nil {
return
}
if transaction.Id == "" {
return errors.New("id required")
}
transaction.Inserted = time.Now()
transaction.Updated = time.Now()
if transaction.Date.IsZero() {
transaction.Date = transaction.Inserted
}
err = model.db.InsertTransaction(transaction)
if err != nil {
return
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access transaction
userIds, err2 := model.db.GetOrgUserIds(transaction.OrgId)
if err2 == nil {
ws.PushTransaction(transaction, userIds, "create")
}
return
}
func (model *Model) UpdateTransaction(oldId string, transaction *types.Transaction) (err error) {
err = model.checkSplits(transaction)
if err != nil {
return
}
if oldId == "" || transaction.Id == "" {
return errors.New("id required")
}
// Get original transaction
original, err := model.getTransactionById(oldId)
if err != nil {
return
}
transaction.Updated = time.Now()
transaction.Inserted = original.Inserted
// We used to compare splits and if they hadn't changed just do an update
// on the transaction. The problem is then the updated field gets out of sync
// between the tranaction and its splits.
// It needs to be in sync for getTransactionsByOrg() to work correctly with pagination
// Delete old transaction and insert a new one
transaction.Inserted = transaction.Updated
err = model.db.DeleteAndInsertTransaction(oldId, transaction)
if err != nil {
return
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access transaction
userIds, err2 := model.db.GetOrgUserIds(transaction.OrgId)
if err2 == nil {
ws.PushTransaction(original, userIds, "delete")
ws.PushTransaction(transaction, userIds, "create")
}
return
}
func (model *Model) GetTransactionsByAccount(orgId string, userId string, accountId string, options *types.QueryOptions) ([]*types.Transaction, error) {
userAccounts, err := model.GetAccounts(orgId, userId, "")
if err != nil {
return nil, err
}
if !model.accountsContainWriteAccess(userAccounts, accountId) {
return nil, errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", accountId))
}
return model.db.GetTransactionsByAccount(accountId, options)
}
func (model *Model) GetTransactionsByOrg(orgId string, userId string, options *types.QueryOptions) ([]*types.Transaction, error) {
userAccounts, err := model.GetAccounts(orgId, userId, "")
if err != nil {
return nil, err
}
var accountIds []string
for _, account := range userAccounts {
accountIds = append(accountIds, account.Id)
}
return model.db.GetTransactionsByOrg(orgId, options, accountIds)
}
func (model *Model) DeleteTransaction(id string, userId string, orgId string) (err error) {
transaction, err := model.getTransactionById(id)
if err != nil {
return
}
userAccounts, err := model.GetAccounts(orgId, userId, "")
if err != nil {
return
}
for _, split := range transaction.Splits {
if !model.accountsContainWriteAccess(userAccounts, split.AccountId) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", split.AccountId))
}
}
err = model.db.DeleteTransaction(id)
if err != nil {
return
}
// Notify web socket subscribers
// TODO only get user ids that have permission to access transaction
userIds, err2 := model.db.GetOrgUserIds(transaction.OrgId)
if err2 == nil {
ws.PushTransaction(transaction, userIds, "delete")
}
return
}
func (model *Model) getTransactionById(id string) (*types.Transaction, error) {
// TODO if this is made public, make a separate version that checks permission
return model.db.GetTransactionById(id)
}
func (model *Model) checkSplits(transaction *types.Transaction) (err error) {
if len(transaction.Splits) < 2 {
return errors.New("at least 2 splits are required")
}
org, err := model.GetOrg(transaction.OrgId, transaction.UserId)
if err != nil {
return
}
userAccounts, err := model.GetAccounts(transaction.OrgId, transaction.UserId, "")
if err != nil {
return
}
var amount int64 = 0
for _, split := range transaction.Splits {
if !model.accountsContainWriteAccess(userAccounts, split.AccountId) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", split.AccountId))
}
account := model.getAccountFromList(userAccounts, split.AccountId)
if account.HasChildren == true {
return errors.New("Cannot use parent account for split")
}
if account.Currency == org.Currency && split.NativeAmount != split.Amount {
return errors.New("nativeAmount must equal amount for native currency splits")
}
amount += split.NativeAmount
}
if amount != 0 {
return errors.New("splits must add up to 0")
}
return
}

View File

@@ -0,0 +1,141 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
"time"
)
type TdTransaction struct {
db.Datastore
mock.Mock
}
func (td *TdTransaction) GetOrg(orgId string, userId string) (*types.Org, error) {
org := &types.Org{
Currency: "USD",
}
return org, nil
}
func (td *TdTransaction) GetPermissionedAccountIds(userId string, orgId string, tokenId string) ([]string, error) {
return []string{"1", "2"}, nil
}
func (td *TdTransaction) GetAccountsByOrgId(orgId string) ([]*types.Account, error) {
return []*types.Account{&types.Account{Id: "1", Currency: "USD"}, &types.Account{Id: "2"}}, nil
}
func (td *TdTransaction) InsertTransaction(transaction *types.Transaction) (err error) {
return nil
}
func (td *TdTransaction) GetTransactionById(id string) (*types.Transaction, error) {
args := td.Called(id)
return args.Get(0).(*types.Transaction), args.Error(1)
}
func (td *TdTransaction) UpdateTransaction(oldId string, transaction *types.Transaction) error {
args := td.Called(oldId, transaction)
return args.Error(0)
}
func (td *TdTransaction) GetOrgUserIds(id string) ([]string, error) {
return []string{"1"}, nil
}
func TestCreateTransaction(t *testing.T) {
tests := map[string]struct {
err error
tx *types.Transaction
}{
"successful": {
err: nil,
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 1000},
&types.Split{"1", "2", -1000, -1000},
},
},
},
"bad split amounts": {
err: errors.New("splits must add up to 0"),
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 1000},
&types.Split{"1", "2", -500, -500},
},
},
},
"lacking permission": {
err: errors.New("user does not have permission to access account 3"),
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 1000},
&types.Split{"1", "3", -1000, -1000},
},
},
},
"nativeAmount mismatch": {
err: errors.New("nativeAmount must equal amount for native currency splits"),
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 500},
&types.Split{"1", "2", -1000, -500},
},
},
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
td := &TdTransaction{}
model := NewModel(td, nil, types.Config{})
err := model.CreateTransaction(test.tx)
assert.Equal(t, err, test.err)
}
}

View File

@@ -0,0 +1,31 @@
package types
import (
"time"
)
type Account struct {
Id string `json:"id"`
OrgId string `json:"orgId"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
Name string `json:"name"`
Parent string `json:"parent"`
Currency string `json:"currency"`
Precision int `json:"precision"`
DebitBalance bool `json:"debitBalance"`
Balance *int64 `json:"balance"`
NativeBalance *int64 `json:"nativeBalance"`
ReadOnly bool `json:"readOnly"`
HasChildren bool `json:"-"`
}
type AccountNode struct {
Account *Account
Parent *AccountNode
Children []*AccountNode
}
func NewAccount() *Account {
return &Account{Precision: 2}
}

View File

@@ -0,0 +1,15 @@
package types
import (
"github.com/go-sql-driver/mysql"
"time"
)
type ApiKey struct {
Id string `json:"id"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
UserId string `json:"userId"`
Label string `json:"label"`
Deleted mysql.NullTime `json:"-"` // Can we marshal this correctly?
}

View File

@@ -0,0 +1,14 @@
package types
type Config struct {
WebUrl string
Port int
KeyFile string
CertFile string
Database string
User string
Password string
SendgridKey string
SendgridEmail string
SendgridSender string
}

View File

@@ -0,0 +1,14 @@
package types
import (
"time"
)
type Invite struct {
Id string `json:"id"`
OrgId string `json:"orgId"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
Email string `json:"email"`
Accepted bool `json:"accepted"`
}

14
core/model/types/org.go Normal file
View File

@@ -0,0 +1,14 @@
package types
import (
"time"
)
type Org struct {
Id string `json:"id"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
Name string `json:"name"`
Currency string `json:"currency"`
Precision int `json:"precision"`
}

15
core/model/types/price.go Normal file
View File

@@ -0,0 +1,15 @@
package types
import (
"time"
)
type Price struct {
Id string `json:"id"`
OrgId string `json:"orgId"`
Currency string `json:"currency"`
Date time.Time `json:"date"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
Price float64 `json:"price"`
}

View File

@@ -0,0 +1,104 @@
package types
import (
"net/url"
"strconv"
)
type QueryOptions struct {
Limit int `json:"limit"`
Skip int `json:"skip"`
SinceInserted int `json:"sinceInserted"`
SinceUpdated int `json:"sinceUpdated"`
BeforeInserted int `json:"beforeInserted"`
BeforeUpdated int `json:"beforeUpdated"`
StartDate int `json:"startDate"`
EndDate int `json:"endDate"`
DescriptionStartsWith string `json:"descriptionStartsWith"`
IncludeDeleted bool `json:"includeDeleted"`
Sort string `json:"string"`
}
func QueryOptionsFromURLQuery(urlQuery url.Values) (*QueryOptions, error) {
qo := &QueryOptions{}
var err error
if urlQuery.Get("limit") != "" {
qo.Limit, err = strconv.Atoi(urlQuery.Get("limit"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("skip") != "" {
qo.Skip, err = strconv.Atoi(urlQuery.Get("skip"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("sinceInserted") != "" {
qo.SinceInserted, err = strconv.Atoi(urlQuery.Get("sinceInserted"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("sinceUpdated") != "" {
qo.SinceUpdated, err = strconv.Atoi(urlQuery.Get("sinceUpdated"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("beforeInserted") != "" {
qo.BeforeInserted, err = strconv.Atoi(urlQuery.Get("beforeInserted"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("beforeUpdated") != "" {
qo.BeforeUpdated, err = strconv.Atoi(urlQuery.Get("beforeUpdated"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("startDate") != "" {
qo.StartDate, err = strconv.Atoi(urlQuery.Get("startDate"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("endDate") != "" {
qo.EndDate, err = strconv.Atoi(urlQuery.Get("endDate"))
if err != nil {
return nil, err
}
}
if urlQuery.Get("descriptionStartsWith") != "" {
qo.DescriptionStartsWith = urlQuery.Get("descriptionStartsWith")
}
if urlQuery.Get("includeDeleted") == "true" {
qo.IncludeDeleted = true
}
if urlQuery.Get("sort") != "" {
qo.Sort = urlQuery.Get("sort")
}
return qo, nil
}

View File

@@ -0,0 +1,14 @@
package types
import (
"github.com/go-sql-driver/mysql"
"time"
)
type Session struct {
Id string `json:"id"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
UserId string `json:"userId"`
Terminated mysql.NullTime `json:"-"` // Can we marshal this correctly?
}

View File

@@ -0,0 +1,25 @@
package types
import (
"time"
)
type Transaction struct {
Id string `json:"id"`
OrgId string `json:"orgId"`
UserId string `json:"userId"`
Date time.Time `json:"date"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
Description string `json:"description"`
Data string `json:"data"`
Deleted bool `json:"deleted"`
Splits []*Split `json:"splits"`
}
type Split struct {
TransactionId string `json:"-"`
AccountId string `json:"accountId"`
Amount int64 `json:"amount"`
NativeAmount int64 `json:"nativeAmount"`
}

20
core/model/types/user.go Normal file
View File

@@ -0,0 +1,20 @@
package types
import (
"time"
)
type User struct {
Id string `json:"id"`
Inserted time.Time `json:"inserted"`
Updated time.Time `json:"updated"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
Email string `json:"email"`
Password string `json:"password"`
PasswordHash string `json:"-"`
AgreeToTerms bool `json:"agreeToTerms"`
PasswordReset string `json:"-"`
EmailVerified bool `json:"emailVerified"`
EmailVerifyCode string `json:"-"`
}

228
core/model/user.go Normal file
View File

@@ -0,0 +1,228 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/sendgrid/sendgrid-go"
"github.com/sendgrid/sendgrid-go/helpers/mail"
"log"
)
type UserInterface interface {
CreateUser(user *types.User) error
VerifyUser(string) error
UpdateUser(user *types.User) error
ResetPassword(email string) error
ConfirmResetPassword(string, string) (*types.User, error)
}
func (model *Model) CreateUser(user *types.User) error {
if user.Id == "" {
return errors.New("id required")
}
if user.FirstName == "" {
return errors.New("first name required")
}
if user.LastName == "" {
return errors.New("last name required")
}
if user.Email == "" {
return errors.New("email required")
}
if user.Password == "" {
return errors.New("password required")
}
if user.AgreeToTerms != true {
return errors.New("must agree to terms")
}
// hash password
// bcrypt's function also generates a salt
passwordHash, err := model.bcrypt.GenerateFromPassword([]byte(user.Password), model.bcrypt.GetDefaultCost())
if err != nil {
return err
}
user.PasswordHash = string(passwordHash)
user.Password = ""
user.EmailVerified = false
user.EmailVerifyCode, err = util.NewGuid()
if err != nil {
return err
}
err = model.db.InsertUser(user)
if err != nil {
return err
}
err = model.SendVerificationEmail(user)
if err != nil {
log.Println(err)
}
return nil
}
func (model *Model) VerifyUser(code string) error {
if code == "" {
return errors.New("code required")
}
return model.db.VerifyUser(code)
}
func (model *Model) UpdateUser(user *types.User) error {
if user.Id == "" {
return errors.New("id required")
}
if user.Password == "" {
return errors.New("password required")
}
// hash password
// bcrypt's function also generates a salt
passwordHash, err := model.bcrypt.GenerateFromPassword([]byte(user.Password), model.bcrypt.GetDefaultCost())
if err != nil {
return err
}
user.PasswordHash = string(passwordHash)
user.Password = ""
return model.db.UpdateUser(user)
}
func (model *Model) ResetPassword(email string) error {
if email == "" {
return errors.New("email required")
}
user, err := model.db.GetVerifiedUserByEmail(email)
if err != nil {
// Don't send back error so people can't try to find user accounts
log.Printf("Invalid email for reset password " + email)
return nil
}
user.PasswordReset, err = util.NewGuid()
if err != nil {
return err
}
err = model.db.UpdateUserResetPassword(user)
if err != nil {
return err
}
return model.SendPasswordResetEmail(user)
}
func (model *Model) ConfirmResetPassword(password string, code string) (*types.User, error) {
if password == "" {
return nil, errors.New("password required")
}
if code == "" {
return nil, errors.New("code required")
}
user, err := model.db.GetUserByResetCode(code)
if err != nil {
return nil, errors.New("Invalid code")
}
passwordHash, err := model.bcrypt.GenerateFromPassword([]byte(password), model.bcrypt.GetDefaultCost())
if err != nil {
return nil, err
}
user.PasswordHash = string(passwordHash)
user.Password = ""
err = model.db.UpdateUser(user)
if err != nil {
return nil, err
}
return user, nil
}
func (model *Model) SendVerificationEmail(user *types.User) error {
log.Println("Sending verification email to " + user.Email)
link := model.config.WebUrl + "/user/verify?code=" + user.EmailVerifyCode
from := mail.NewEmail(model.config.SendgridSender, model.config.SendgridEmail)
subject := "Verify your email"
to := mail.NewEmail(user.FirstName+" "+user.LastName, user.Email)
plainTextContent := "Thank you for signing up with Open Accounting! " +
"Please click on the link below to verify your email address:\n\n" + link
htmlContent := "Thank you for signing up with Open Accounting! " +
"Please click on the link below to verify your email address:<br><br>" +
"<a href=\"" + link + "\">" + link + "</a>"
message := mail.NewSingleEmail(from, subject, to, plainTextContent, htmlContent)
client := sendgrid.NewSendClient(model.config.SendgridKey)
response, err := client.Send(message)
if err != nil {
return err
}
log.Println(response.StatusCode)
log.Println(response.Body)
log.Println(response.Headers)
return nil
}
func (model *Model) SendPasswordResetEmail(user *types.User) error {
log.Println("Sending password reset email to " + user.Email)
link := model.config.WebUrl + "/user/reset-password?code=" + user.PasswordReset
from := mail.NewEmail(model.config.SendgridSender, model.config.SendgridEmail)
subject := "Reset password"
to := mail.NewEmail(user.FirstName+" "+user.LastName, user.Email)
plainTextContent := "Please click the following link to reset your password:\n\n" + link +
"If you did not request to have your password reset, please ignore this email and " +
"nothing will happen."
htmlContent := "Please click the following link to reset your password:<br><br>\n" +
"<a href=\"" + link + "\">" + link + "</a><br><br>\n" +
"If you did not request to have your password reset, please ignore this email and " +
"nothing will happen."
message := mail.NewSingleEmail(from, subject, to, plainTextContent, htmlContent)
client := sendgrid.NewSendClient(model.config.SendgridKey)
response, err := client.Send(message)
if err != nil {
return err
}
log.Println(response.StatusCode)
log.Println(response.Body)
log.Println(response.Headers)
return nil
}

177
core/model/user_test.go Normal file
View File

@@ -0,0 +1,177 @@
package model
import (
"errors"
"github.com/openaccounting/oa-server/core/mocks"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type TdUser struct {
db.Datastore
testNum int
}
func (td *TdUser) InsertUser(user *types.User) error {
return nil
}
func (td *TdUser) UpdateUser(user *types.User) error {
return nil
}
func TestCreateUser(t *testing.T) {
// Id string `json:"id"`
// Inserted time.Time `json:"inserted"`
// Updated time.Time `json:"updated"`
// FirstName string `json:"firstName"`
// LastName string `json:"lastName"`
// Email string `json:"email"`
// Password string `json:"password"`
// PasswordHash string `json:"-"`
// AgreeToTerms bool `json:"agreeToTerms"`
// PasswordReset string `json:"-"`
// EmailVerified bool `json:"emailVerified"`
// EmailVerifyCode string `json:"-"`
user := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John",
"Doe",
"johndoe@email.com",
"password",
"",
true,
"",
false,
"",
}
badUser := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John",
"Doe",
"",
"password",
"",
true,
"",
false,
"",
}
tests := map[string]struct {
err error
user types.User
}{
"successful": {
err: nil,
user: user,
},
"with error": {
err: errors.New("email required"),
user: badUser,
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
user := test.user
mockBcrypt := new(mocks.Bcrypt)
mockBcrypt.On("GetDefaultCost").Return(10)
mockBcrypt.On("GenerateFromPassword", []byte(user.Password), 10).
Return(make([]byte, 0), nil)
model := NewModel(&TdUser{}, mockBcrypt, types.Config{})
err := model.CreateUser(&user)
assert.Equal(t, err, test.err)
if err == nil {
mockBcrypt.AssertExpectations(t)
}
}
}
func TestUpdateUser(t *testing.T) {
user := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John2",
"Doe",
"johndoe@email.com",
"password",
"",
true,
"",
false,
"",
}
badUser := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John2",
"Doe",
"johndoe@email.com",
"",
"",
true,
"",
false,
"",
}
tests := map[string]struct {
err error
user types.User
}{
"successful": {
err: nil,
user: user,
},
"with error": {
err: errors.New("password required"),
user: badUser,
},
}
for name, test := range tests {
t.Logf("Running test case: %s", name)
user := test.user
mockBcrypt := new(mocks.Bcrypt)
mockBcrypt.On("GetDefaultCost").Return(10)
mockBcrypt.On("GenerateFromPassword", []byte(user.Password), 10).
Return(make([]byte, 0), nil)
model := NewModel(&TdUser{}, mockBcrypt, types.Config{})
err := model.UpdateUser(&user)
assert.Equal(t, err, test.err)
if err == nil {
mockBcrypt.AssertExpectations(t)
}
}
}