feat: implement secure file upload system with JWT authentication

- Add JWT-based secure file access for local storage with 1-hour expiry
- Implement GORM repository methods for attachment CRUD operations
- Add secure file serving endpoint with token validation
- Update storage interface to support user context in URL generation
- Add comprehensive security features including path traversal protection
- Update documentation with security model and configuration examples
- Add utility functions for hex/byte conversion and UUID validation
- Configure secure file permissions (0600) for uploaded files

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-07-03 15:45:25 +12:00
parent b2b77eb4da
commit 8b6ba74ce9
19 changed files with 546 additions and 43 deletions

View File

@@ -131,10 +131,10 @@ func TestAttachmentIntegration(t *testing.T) {
// Set up the model instance for the API handlers
bc := &util.StandardBcrypt{}
// Use the existing datastore model which has the attachment implementation
// Use the GORM model which has the attachment implementation
// We need to create it with the database connection
datastoreModel := model.NewModel(nil, bc, types.Config{})
model.Instance = datastoreModel
gormModel := model.NewGormModel(db, bc, types.Config{})
model.Instance = gormModel
t.Run("Database Integration Test", func(t *testing.T) {
// Test direct database operations first

View File

@@ -162,8 +162,8 @@ func GetAttachmentDownloadURL(w rest.ResponseWriter, r *rest.Request) {
return
}
// Generate download URL (valid for 1 hour)
url, err := attachmentHandler.storage.GetURL(attachment.FilePath, time.Hour)
// Generate download URL (valid for 1 hour) with user context for JWT tokens
url, err := attachmentHandler.storage.GetURLWithContext(attachment.FilePath, time.Hour, user.Id, attachment.OrgId)
if err != nil {
rest.Error(w, "Failed to generate download URL", http.StatusInternalServerError)
return

View File

@@ -42,6 +42,9 @@ func GetRouter(auth *AuthMiddleware, prefix string) (rest.App, error) {
rest.Get(prefix+"/attachments/:id", auth.RequireAuth(GetAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id/url", auth.RequireAuth(GetAttachmentDownloadURL)),
rest.Delete(prefix+"/attachments/:id", auth.RequireAuth(DeleteAttachmentWithStorage)),
// Secure file serving endpoint (no auth required - token validates access)
rest.Get("/secure-files", GetSecureFile),
rest.Get(prefix+"/orgs/:orgId/prices", auth.RequireAuth(GetPrices)),
rest.Post(prefix+"/orgs/:orgId/prices", auth.RequireAuth(PostPrice)),
rest.Delete(prefix+"/orgs/:orgId/prices/:priceId", auth.RequireAuth(DeletePrice)),

160
core/api/secure_files.go Normal file
View File

@@ -0,0 +1,160 @@
package api
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/ant0ine/go-json-rest/rest"
"github.com/openaccounting/oa-server/core/storage"
)
// TokenService instance for file access
var tokenService *storage.TokenService
// InitSecureFileServer initializes the token service for secure file serving
func InitSecureFileServer(signingKey string) {
tokenService = storage.NewTokenService(signingKey)
}
// GetSecureFile serves files with JWT token validation
func GetSecureFile(w rest.ResponseWriter, r *rest.Request) {
// Extract token from query parameter
token := r.URL.Query().Get("token")
if token == "" {
rest.Error(w, "Missing access token", http.StatusUnauthorized)
return
}
// Validate the token
claims, err := tokenService.ValidateFileToken(token)
if err != nil {
rest.Error(w, "Invalid or expired token", http.StatusUnauthorized)
return
}
// Get the file path from the token claims
filePath := claims.FilePath
// Validate the file path (additional security check)
if err := validateSecureFilePath(filePath); err != nil {
rest.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
// Serve the file
if err := serveFile(w, r, filePath); err != nil {
if os.IsNotExist(err) {
rest.Error(w, "File not found", http.StatusNotFound)
} else {
rest.Error(w, "Failed to serve file", http.StatusInternalServerError)
}
return
}
}
// serveFile serves a file with proper headers and security measures
func serveFile(w rest.ResponseWriter, r *rest.Request, filePath string) error {
// Get the full path relative to the uploads directory
// This assumes the local storage root directory is "./uploads"
fullPath := filepath.Join("./uploads", filePath)
// Open the file
file, err := os.Open(fullPath)
if err != nil {
return err
}
defer file.Close()
// Get file info for headers
info, err := file.Stat()
if err != nil {
return err
}
// Set security headers
responseWriter := w.(http.ResponseWriter)
responseWriter.Header().Set("X-Content-Type-Options", "nosniff")
responseWriter.Header().Set("X-Frame-Options", "DENY")
responseWriter.Header().Set("Content-Security-Policy", "default-src 'none'")
// Set content headers
responseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", info.Size()))
responseWriter.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
// Detect content type based on file extension
contentType := getContentType(filePath)
responseWriter.Header().Set("Content-Type", contentType)
// Set cache headers for temporary access
responseWriter.Header().Set("Cache-Control", "private, max-age=300") // 5 minutes
responseWriter.Header().Set("Expires", time.Now().Add(5*time.Minute).UTC().Format(http.TimeFormat))
// Copy file content to response
_, err = io.Copy(responseWriter, file)
return err
}
// validateSecureFilePath validates that the file path is safe to serve
func validateSecureFilePath(path string) error {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Reject paths that try to go up directories
if strings.Contains(cleanPath, "..") {
return fmt.Errorf("path traversal attempt detected")
}
// Reject absolute paths
if filepath.IsAbs(cleanPath) {
return fmt.Errorf("absolute paths not allowed")
}
// Additional validation: ensure path starts with expected date format
parts := strings.Split(cleanPath, string(filepath.Separator))
if len(parts) < 4 {
return fmt.Errorf("invalid path structure")
}
return nil
}
// getContentType returns the MIME type based on file extension
func getContentType(filePath string) string {
ext := strings.ToLower(filepath.Ext(filePath))
switch ext {
case ".pdf":
return "application/pdf"
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".gif":
return "image/gif"
case ".webp":
return "image/webp"
case ".doc":
return "application/msword"
case ".docx":
return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
case ".xls":
return "application/vnd.ms-excel"
case ".xlsx":
return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
case ".ppt":
return "application/vnd.ms-powerpoint"
case ".pptx":
return "application/vnd.openxmlformats-officedocument.presentationml.presentation"
case ".txt":
return "text/plain"
case ".csv":
return "text/csv"
default:
return "application/octet-stream"
}
}

View File

@@ -262,21 +262,25 @@ func (m *GormModel) CreateAttachment(attachment *types.Attachment) (*types.Attac
attachment.Uploaded = time.Now()
attachment.Deleted = false
// For GORM implementation, we'd need to implement repository methods
// For now, return an error indicating not implemented
return nil, errors.New("attachment operations not yet implemented for GORM model")
// Use repository to insert attachment
err := m.repository.InsertAttachment(attachment)
if err != nil {
return nil, err
}
return attachment, nil
}
func (m *GormModel) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
return nil, errors.New("attachment operations not yet implemented for GORM model")
return m.repository.GetAttachmentsByTransaction(transactionId, orgId, userId)
}
func (m *GormModel) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
return nil, errors.New("attachment operations not yet implemented for GORM model")
return m.repository.GetAttachment(attachmentId, transactionId, orgId, userId)
}
func (m *GormModel) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
return errors.New("attachment operations not yet implemented for GORM model")
return m.repository.DeleteAttachment(attachmentId, transactionId, orgId, userId)
}
func (m *GormModel) GetTransactionById(id string) (*types.Transaction, error) {

View File

@@ -372,4 +372,167 @@ func (r *GormRepository) Escape(sql string) string {
// GORM handles SQL injection protection automatically
// This method is kept for interface compatibility
return sql
}
// Attachment repository methods
func (r *GormRepository) InsertAttachment(attachment *types.Attachment) error {
// Convert UUID strings to bytes (remove dashes if present)
idBytes, err := stringToIDBytes(attachment.Id)
if err != nil {
return err
}
transactionIdBytes, err := stringToIDBytes(attachment.TransactionId)
if err != nil {
return err
}
orgIdBytes, err := stringToIDBytes(attachment.OrgId)
if err != nil {
return err
}
userIdBytes, err := stringToIDBytes(attachment.UserId)
if err != nil {
return err
}
// Convert types.Attachment to models.Attachment
gormAttachment := &models.Attachment{
ID: idBytes,
TransactionID: transactionIdBytes,
OrgID: orgIdBytes,
UserID: userIdBytes,
FileName: attachment.FileName,
OriginalName: attachment.OriginalName,
ContentType: attachment.ContentType,
FileSize: attachment.FileSize,
FilePath: attachment.FilePath,
Description: attachment.Description,
Uploaded: attachment.Uploaded,
Deleted: attachment.Deleted,
}
result := r.db.Create(gormAttachment)
return result.Error
}
func (r *GormRepository) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
transactionIdBytes, err := stringToIDBytes(transactionId)
if err != nil {
return nil, err
}
orgIdBytes, err := stringToIDBytes(orgId)
if err != nil {
return nil, err
}
var gormAttachments []models.Attachment
result := r.db.Where("transactionId = ? AND orgId = ? AND deleted = ?",
transactionIdBytes, orgIdBytes, false).Find(&gormAttachments)
if result.Error != nil {
return nil, result.Error
}
attachments := make([]*types.Attachment, len(gormAttachments))
for i, gormAttachment := range gormAttachments {
attachments[i] = convertGormToTypesAttachment(&gormAttachment)
}
return attachments, nil
}
func (r *GormRepository) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
attachmentIdBytes, err := stringToIDBytes(attachmentId)
if err != nil {
return nil, err
}
var gormAttachment models.Attachment
query := r.db.Where("id = ? AND deleted = ?", attachmentIdBytes, false)
// Add additional filters if provided
if transactionId != "" {
transactionIdBytes, err := stringToIDBytes(transactionId)
if err != nil {
return nil, err
}
query = query.Where("transactionId = ?", transactionIdBytes)
}
if orgId != "" {
orgIdBytes, err := stringToIDBytes(orgId)
if err != nil {
return nil, err
}
query = query.Where("orgId = ?", orgIdBytes)
}
result := query.First(&gormAttachment)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, result.Error
}
return convertGormToTypesAttachment(&gormAttachment), nil
}
func (r *GormRepository) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
attachmentIdBytes, err := stringToIDBytes(attachmentId)
if err != nil {
return err
}
query := r.db.Model(&models.Attachment{}).Where("id = ?", attachmentIdBytes)
// Add additional filters if provided
if transactionId != "" {
transactionIdBytes, err := stringToIDBytes(transactionId)
if err != nil {
return err
}
query = query.Where("transactionId = ?", transactionIdBytes)
}
if orgId != "" {
orgIdBytes, err := stringToIDBytes(orgId)
if err != nil {
return err
}
query = query.Where("orgId = ?", orgIdBytes)
}
// Soft delete by setting deleted = true
result := query.Update("deleted", true)
return result.Error
}
// Helper function to convert UUID string (with or without dashes) to bytes
func stringToIDBytes(id string) ([]byte, error) {
// Remove dashes if present
cleanId := strings.ReplaceAll(id, "-", "")
return util.HexToBytes(cleanId)
}
// Helper function to convert bytes to UUID string (without dashes, for compatibility)
func idBytesToString(bytes []byte) string {
return util.BytesToHex(bytes)
}
// Helper function to convert GORM attachment to types attachment
func convertGormToTypesAttachment(gormAttachment *models.Attachment) *types.Attachment {
return &types.Attachment{
Id: idBytesToString(gormAttachment.ID),
TransactionId: idBytesToString(gormAttachment.TransactionID),
OrgId: idBytesToString(gormAttachment.OrgID),
UserId: idBytesToString(gormAttachment.UserID),
FileName: gormAttachment.FileName,
OriginalName: gormAttachment.OriginalName,
ContentType: gormAttachment.ContentType,
FileSize: gormAttachment.FileSize,
FilePath: gormAttachment.FilePath,
Description: gormAttachment.Description,
Uploaded: gormAttachment.Uploaded,
Deleted: gormAttachment.Deleted,
}
}

View File

@@ -40,6 +40,7 @@ func main() {
viper.BindEnv("Storage.backend", "OA_STORAGE_BACKEND")
viper.BindEnv("Storage.local.root_dir", "OA_STORAGE_LOCAL_ROOTDIR")
viper.BindEnv("Storage.local.base_url", "OA_STORAGE_LOCAL_BASEURL")
viper.BindEnv("Storage.local.signing_key", "OA_STORAGE_LOCAL_SIGNINGKEY")
viper.BindEnv("Storage.s3.region", "OA_STORAGE_S3_REGION")
viper.BindEnv("Storage.s3.bucket", "OA_STORAGE_S3_BUCKET")
viper.BindEnv("Storage.s3.prefix", "OA_STORAGE_S3_PREFIX")
@@ -59,6 +60,7 @@ func main() {
viper.SetDefault("Storage.backend", "local")
viper.SetDefault("Storage.local.root_dir", "./uploads")
viper.SetDefault("Storage.local.base_url", "")
viper.SetDefault("Storage.local.signing_key", "") // Will auto-generate if empty
// Read configuration
err := viper.ReadInConfig()
@@ -141,6 +143,9 @@ func main() {
log.Fatal(fmt.Errorf("failed to initialize storage backend: %s", err.Error()))
}
// Initialize secure file server with signing key for local storage
api.InitSecureFileServer(config.Storage.Local.SigningKey)
app, err := api.Init(config.ApiPrefix)
if err != nil {
log.Fatal(fmt.Errorf("failed to create api instance with: %s", err.Error()))

View File

@@ -19,6 +19,9 @@ type Storage interface {
// GetURL returns a URL for accessing the file (may be signed/temporary)
GetURL(path string, expiry time.Duration) (string, error)
// GetURLWithContext returns a URL for accessing the file with user context (for JWT tokens)
GetURLWithContext(path string, expiry time.Duration, userID, orgID string) (string, error)
// Exists checks if a file exists at the given path
Exists(path string) (bool, error)
@@ -53,6 +56,9 @@ type LocalConfig struct {
// Base URL for serving files (optional)
BaseURL string `mapstructure:"base_url"`
// Signing key for JWT tokens (optional, will be auto-generated if empty)
SigningKey string `mapstructure:"signing_key"`
}
// S3Config configures S3-compatible storage (AWS S3, Backblaze B2, Cloudflare R2, etc.)

View File

@@ -54,19 +54,13 @@ func TestNewStorage(t *testing.T) {
}
})
t.Run("B2 Storage", func(t *testing.T) {
t.Run("Invalid Backend", func(t *testing.T) {
config := Config{
Backend: "b2",
B2: B2Config{
AccountID: "test-account",
ApplicationKey: "test-key",
Bucket: "test-bucket",
},
Backend: "invalid",
}
// This will fail because we don't have real B2 credentials
storage, err := NewStorage(config)
assert.Error(t, err) // Expected to fail without credentials
assert.Error(t, err)
assert.Nil(t, storage)
})

View File

@@ -3,6 +3,7 @@ package storage
import (
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strings"
@@ -13,8 +14,9 @@ import (
// LocalStorage implements the Storage interface for local filesystem
type LocalStorage struct {
rootDir string
baseURL string
rootDir string
baseURL string
tokenService *TokenService
}
// NewLocalStorage creates a new local filesystem storage backend
@@ -24,14 +26,18 @@ func NewLocalStorage(config LocalConfig) (*LocalStorage, error) {
rootDir = "./uploads"
}
// Ensure the root directory exists
if err := os.MkdirAll(rootDir, 0755); err != nil {
// Ensure the root directory exists with secure permissions
if err := os.MkdirAll(rootDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create storage directory: %w", err)
}
// Initialize token service for secure URL generation
tokenService := NewTokenService(config.SigningKey)
return &LocalStorage{
rootDir: rootDir,
baseURL: config.BaseURL,
rootDir: rootDir,
baseURL: config.BaseURL,
tokenService: tokenService,
}, nil
}
@@ -41,14 +47,14 @@ func (l *LocalStorage) Store(filename string, content io.Reader, contentType str
storagePath := l.generateStoragePath(filename)
fullPath := filepath.Join(l.rootDir, storagePath)
// Ensure the directory exists
// Ensure the directory exists with secure permissions
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
if err := os.MkdirAll(dir, 0700); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}
// Create and write the file
file, err := os.Create(fullPath)
// Create and write the file with secure permissions
file, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
@@ -102,8 +108,13 @@ func (l *LocalStorage) Delete(path string) error {
return nil
}
// GetURL returns a URL for accessing the file
// GetURL returns a secure URL for accessing the file with JWT token
func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error) {
return l.GetURLWithContext(path, expiry, "", "")
}
// GetURLWithContext returns a secure URL for accessing the file with JWT token and user context
func (l *LocalStorage) GetURLWithContext(path string, expiry time.Duration, userID, orgID string) (string, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return "", err
@@ -118,14 +129,16 @@ func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error)
return "", &FileNotFoundError{Path: path}
}
if l.baseURL != "" {
// Return a public URL if base URL is configured
return l.baseURL + "/" + path, nil
// Generate secure JWT token for file access
token, err := l.tokenService.GenerateFileToken(path, userID, orgID, expiry)
if err != nil {
return "", fmt.Errorf("failed to generate access token: %w", err)
}
// For local storage without a base URL, return the file path
// In a real application, you might serve these through an endpoint
return "/files/" + path, nil
// Return secure URL with token parameter
params := url.Values{}
params.Set("token", token)
return "/secure-files?" + params.Encode(), nil
}
// Exists checks if a file exists at the given path

View File

@@ -72,8 +72,11 @@ func TestLocalStorage(t *testing.T) {
url, err := storage.GetURL(path, time.Hour)
assert.NoError(t, err)
assert.Contains(t, url, path)
assert.Contains(t, url, config.BaseURL)
// New JWT token-based URLs should start with /secure-files and contain a token parameter
assert.Contains(t, url, "/secure-files")
assert.Contains(t, url, "token=")
// The token should be a JWT (contains dots for header.payload.signature)
assert.Contains(t, url, ".")
})
t.Run("Delete File", func(t *testing.T) {

View File

@@ -155,6 +155,12 @@ func (s *S3Storage) GetURL(path string, expiry time.Duration) (string, error) {
return url, nil
}
// GetURLWithContext returns a presigned URL for accessing the file (S3 doesn't use user context)
func (s *S3Storage) GetURLWithContext(path string, expiry time.Duration, userID, orgID string) (string, error) {
// For S3, user context is not needed as presigned URLs are cryptographically secure
return s.GetURL(path, expiry)
}
// Exists checks if a file exists in S3
func (s *S3Storage) Exists(path string) (bool, error) {
input := &s3.HeadObjectInput{

83
core/storage/token.go Normal file
View File

@@ -0,0 +1,83 @@
package storage
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/openaccounting/oa-server/core/util/id"
)
// TokenService handles JWT token generation and validation for file access
type TokenService struct {
signingKey []byte
}
// NewTokenService creates a new token service with a signing key
func NewTokenService(signingKey string) *TokenService {
if signingKey == "" {
// Generate a random signing key if none provided
// In production, this should be a consistent secret from config
signingKey = id.String(id.New())
}
return &TokenService{
signingKey: []byte(signingKey),
}
}
// FileClaims represents the JWT claims for file access
type FileClaims struct {
FilePath string `json:"file_path"`
UserID string `json:"user_id,omitempty"`
OrgID string `json:"org_id,omitempty"`
jwt.RegisteredClaims
}
// GenerateFileToken creates a JWT token for accessing a specific file
func (ts *TokenService) GenerateFileToken(filePath string, userID, orgID string, expiry time.Duration) (string, error) {
now := time.Now()
claims := FileClaims{
FilePath: filePath,
UserID: userID,
OrgID: orgID,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
Issuer: "openaccounting-server",
Subject: "file-access",
ID: id.String(id.New()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(ts.signingKey)
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return tokenString, nil
}
// ValidateFileToken validates a JWT token and returns the file claims
func (ts *TokenService) ValidateFileToken(tokenString string) (*FileClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &FileClaims{}, func(token *jwt.Token) (interface{}, error) {
// Verify the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return ts.signingKey, nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*FileClaims)
if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
}
return claims, nil
}

View File

@@ -65,3 +65,13 @@ func IsValidUUID(uuid string) bool {
matched, _ := regexp.MatchString("^[0-9a-f]{32}$", uuid)
return matched
}
// HexToBytes converts a hex string to bytes
func HexToBytes(hexString string) ([]byte, error) {
return hex.DecodeString(hexString)
}
// BytesToHex converts bytes to a hex string
func BytesToHex(bytes []byte) string {
return hex.EncodeToString(bytes)
}