diff --git a/core/auth/auth.go b/core/auth/auth.go index 9a16423..4849a31 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -19,6 +19,7 @@ type Interface interface { AuthenticateUser(email string, password string) (*types.User, error) AuthenticateSession(string) (*types.User, error) AuthenticateApiKey(string) (*types.User, error) + AuthenticateEmailVerifyCode(string) (*types.User, error) } func NewAuthService(db db.Datastore, bcrypt util.Bcrypt) *AuthService { @@ -47,6 +48,12 @@ func (auth *AuthService) Authenticate(emailOrKey string, password string) (*type return user, nil } + user, err = auth.AuthenticateEmailVerifyCode(emailOrKey) + + if err == nil { + return user, nil + } + return nil, errors.New("Unauthorized") } @@ -89,3 +96,13 @@ func (auth *AuthService) AuthenticateApiKey(id string) (*types.User, error) { return u, nil } + +func (auth *AuthService) AuthenticateEmailVerifyCode(code string) (*types.User, error) { + u, err := auth.db.GetUserByEmailVerifyCode(code) + + if err != nil { + return nil, errors.New("Access denied") + } + + return u, nil +} \ No newline at end of file diff --git a/core/model/db/user.go b/core/model/db/user.go index 0fbf3aa..cfbd038 100644 --- a/core/model/db/user.go +++ b/core/model/db/user.go @@ -3,7 +3,6 @@ package db import ( "database/sql" "errors" - "fmt" "github.com/openaccounting/oa-server/core/model/types" "github.com/openaccounting/oa-server/core/util" "time" @@ -20,6 +19,7 @@ type UserInterface interface { GetUserByActiveSession(string) (*types.User, error) GetUserByApiKey(string) (*types.User, error) GetUserByResetCode(string) (*types.User, error) + GetUserByEmailVerifyCode(string) (*types.User, error) GetOrgAdmins(string) ([]*types.User, error) } @@ -172,7 +172,24 @@ func (db *DB) GetUserByResetCode(code string) (*types.User, error) { return nil, err } - fmt.Println(u) + return u, nil +} + +func (db *DB) GetUserByEmailVerifyCode(code string) (*types.User, error) { + // only allow this for 3 days + minInserted := (time.Now().UnixNano() / 1000000) - (3 * 24 * 60 * 60 * 1000) + qSelect := "SELECT " + userFields + qFrom := " FROM user u" + qWhere := " WHERE u.emailVerifyCode = ? AND inserted > ?" + + query := qSelect + qFrom + qWhere + + row := db.QueryRow(query, code, minInserted) + u, err := db.unmarshalUser(row) + + if err != nil { + return nil, err + } return u, nil }