You've already forked openaccounting-server
forked from cybercinch/openaccounting-server
feat: implement secure file upload system with JWT authentication
- Add JWT-based secure file access for local storage with 1-hour expiry - Implement GORM repository methods for attachment CRUD operations - Add secure file serving endpoint with token validation - Update storage interface to support user context in URL generation - Add comprehensive security features including path traversal protection - Update documentation with security model and configuration examples - Add utility functions for hex/byte conversion and UUID validation - Configure secure file permissions (0600) for uploaded files 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -131,10 +131,10 @@ func TestAttachmentIntegration(t *testing.T) {
|
||||
// Set up the model instance for the API handlers
|
||||
bc := &util.StandardBcrypt{}
|
||||
|
||||
// Use the existing datastore model which has the attachment implementation
|
||||
// Use the GORM model which has the attachment implementation
|
||||
// We need to create it with the database connection
|
||||
datastoreModel := model.NewModel(nil, bc, types.Config{})
|
||||
model.Instance = datastoreModel
|
||||
gormModel := model.NewGormModel(db, bc, types.Config{})
|
||||
model.Instance = gormModel
|
||||
|
||||
t.Run("Database Integration Test", func(t *testing.T) {
|
||||
// Test direct database operations first
|
||||
|
||||
@@ -162,8 +162,8 @@ func GetAttachmentDownloadURL(w rest.ResponseWriter, r *rest.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate download URL (valid for 1 hour)
|
||||
url, err := attachmentHandler.storage.GetURL(attachment.FilePath, time.Hour)
|
||||
// Generate download URL (valid for 1 hour) with user context for JWT tokens
|
||||
url, err := attachmentHandler.storage.GetURLWithContext(attachment.FilePath, time.Hour, user.Id, attachment.OrgId)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to generate download URL", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -42,6 +42,9 @@ func GetRouter(auth *AuthMiddleware, prefix string) (rest.App, error) {
|
||||
rest.Get(prefix+"/attachments/:id", auth.RequireAuth(GetAttachmentWithStorage)),
|
||||
rest.Get(prefix+"/attachments/:id/url", auth.RequireAuth(GetAttachmentDownloadURL)),
|
||||
rest.Delete(prefix+"/attachments/:id", auth.RequireAuth(DeleteAttachmentWithStorage)),
|
||||
|
||||
// Secure file serving endpoint (no auth required - token validates access)
|
||||
rest.Get("/secure-files", GetSecureFile),
|
||||
rest.Get(prefix+"/orgs/:orgId/prices", auth.RequireAuth(GetPrices)),
|
||||
rest.Post(prefix+"/orgs/:orgId/prices", auth.RequireAuth(PostPrice)),
|
||||
rest.Delete(prefix+"/orgs/:orgId/prices/:priceId", auth.RequireAuth(DeletePrice)),
|
||||
|
||||
160
core/api/secure_files.go
Normal file
160
core/api/secure_files.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ant0ine/go-json-rest/rest"
|
||||
"github.com/openaccounting/oa-server/core/storage"
|
||||
)
|
||||
|
||||
// TokenService instance for file access
|
||||
var tokenService *storage.TokenService
|
||||
|
||||
// InitSecureFileServer initializes the token service for secure file serving
|
||||
func InitSecureFileServer(signingKey string) {
|
||||
tokenService = storage.NewTokenService(signingKey)
|
||||
}
|
||||
|
||||
// GetSecureFile serves files with JWT token validation
|
||||
func GetSecureFile(w rest.ResponseWriter, r *rest.Request) {
|
||||
// Extract token from query parameter
|
||||
token := r.URL.Query().Get("token")
|
||||
if token == "" {
|
||||
rest.Error(w, "Missing access token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the token
|
||||
claims, err := tokenService.ValidateFileToken(token)
|
||||
if err != nil {
|
||||
rest.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the file path from the token claims
|
||||
filePath := claims.FilePath
|
||||
|
||||
// Validate the file path (additional security check)
|
||||
if err := validateSecureFilePath(filePath); err != nil {
|
||||
rest.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Serve the file
|
||||
if err := serveFile(w, r, filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
rest.Error(w, "File not found", http.StatusNotFound)
|
||||
} else {
|
||||
rest.Error(w, "Failed to serve file", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// serveFile serves a file with proper headers and security measures
|
||||
func serveFile(w rest.ResponseWriter, r *rest.Request, filePath string) error {
|
||||
// Get the full path relative to the uploads directory
|
||||
// This assumes the local storage root directory is "./uploads"
|
||||
fullPath := filepath.Join("./uploads", filePath)
|
||||
|
||||
// Open the file
|
||||
file, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Get file info for headers
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set security headers
|
||||
responseWriter := w.(http.ResponseWriter)
|
||||
responseWriter.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
responseWriter.Header().Set("X-Frame-Options", "DENY")
|
||||
responseWriter.Header().Set("Content-Security-Policy", "default-src 'none'")
|
||||
|
||||
// Set content headers
|
||||
responseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", info.Size()))
|
||||
responseWriter.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
|
||||
|
||||
// Detect content type based on file extension
|
||||
contentType := getContentType(filePath)
|
||||
responseWriter.Header().Set("Content-Type", contentType)
|
||||
|
||||
// Set cache headers for temporary access
|
||||
responseWriter.Header().Set("Cache-Control", "private, max-age=300") // 5 minutes
|
||||
responseWriter.Header().Set("Expires", time.Now().Add(5*time.Minute).UTC().Format(http.TimeFormat))
|
||||
|
||||
// Copy file content to response
|
||||
_, err = io.Copy(responseWriter, file)
|
||||
return err
|
||||
}
|
||||
|
||||
// validateSecureFilePath validates that the file path is safe to serve
|
||||
func validateSecureFilePath(path string) error {
|
||||
// Clean the path and check for traversal attempts
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Reject paths that try to go up directories
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return fmt.Errorf("path traversal attempt detected")
|
||||
}
|
||||
|
||||
// Reject absolute paths
|
||||
if filepath.IsAbs(cleanPath) {
|
||||
return fmt.Errorf("absolute paths not allowed")
|
||||
}
|
||||
|
||||
// Additional validation: ensure path starts with expected date format
|
||||
parts := strings.Split(cleanPath, string(filepath.Separator))
|
||||
if len(parts) < 4 {
|
||||
return fmt.Errorf("invalid path structure")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getContentType returns the MIME type based on file extension
|
||||
func getContentType(filePath string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filePath))
|
||||
|
||||
switch ext {
|
||||
case ".pdf":
|
||||
return "application/pdf"
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
case ".webp":
|
||||
return "image/webp"
|
||||
case ".doc":
|
||||
return "application/msword"
|
||||
case ".docx":
|
||||
return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
case ".xls":
|
||||
return "application/vnd.ms-excel"
|
||||
case ".xlsx":
|
||||
return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
case ".ppt":
|
||||
return "application/vnd.ms-powerpoint"
|
||||
case ".pptx":
|
||||
return "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
case ".txt":
|
||||
return "text/plain"
|
||||
case ".csv":
|
||||
return "text/csv"
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
@@ -262,21 +262,25 @@ func (m *GormModel) CreateAttachment(attachment *types.Attachment) (*types.Attac
|
||||
attachment.Uploaded = time.Now()
|
||||
attachment.Deleted = false
|
||||
|
||||
// For GORM implementation, we'd need to implement repository methods
|
||||
// For now, return an error indicating not implemented
|
||||
return nil, errors.New("attachment operations not yet implemented for GORM model")
|
||||
// Use repository to insert attachment
|
||||
err := m.repository.InsertAttachment(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
|
||||
return nil, errors.New("attachment operations not yet implemented for GORM model")
|
||||
return m.repository.GetAttachmentsByTransaction(transactionId, orgId, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
|
||||
return nil, errors.New("attachment operations not yet implemented for GORM model")
|
||||
return m.repository.GetAttachment(attachmentId, transactionId, orgId, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
|
||||
return errors.New("attachment operations not yet implemented for GORM model")
|
||||
return m.repository.DeleteAttachment(attachmentId, transactionId, orgId, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetTransactionById(id string) (*types.Transaction, error) {
|
||||
|
||||
@@ -372,4 +372,167 @@ func (r *GormRepository) Escape(sql string) string {
|
||||
// GORM handles SQL injection protection automatically
|
||||
// This method is kept for interface compatibility
|
||||
return sql
|
||||
}
|
||||
|
||||
// Attachment repository methods
|
||||
func (r *GormRepository) InsertAttachment(attachment *types.Attachment) error {
|
||||
// Convert UUID strings to bytes (remove dashes if present)
|
||||
idBytes, err := stringToIDBytes(attachment.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transactionIdBytes, err := stringToIDBytes(attachment.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
orgIdBytes, err := stringToIDBytes(attachment.OrgId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userIdBytes, err := stringToIDBytes(attachment.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert types.Attachment to models.Attachment
|
||||
gormAttachment := &models.Attachment{
|
||||
ID: idBytes,
|
||||
TransactionID: transactionIdBytes,
|
||||
OrgID: orgIdBytes,
|
||||
UserID: userIdBytes,
|
||||
FileName: attachment.FileName,
|
||||
OriginalName: attachment.OriginalName,
|
||||
ContentType: attachment.ContentType,
|
||||
FileSize: attachment.FileSize,
|
||||
FilePath: attachment.FilePath,
|
||||
Description: attachment.Description,
|
||||
Uploaded: attachment.Uploaded,
|
||||
Deleted: attachment.Deleted,
|
||||
}
|
||||
|
||||
result := r.db.Create(gormAttachment)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
|
||||
transactionIdBytes, err := stringToIDBytes(transactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
orgIdBytes, err := stringToIDBytes(orgId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var gormAttachments []models.Attachment
|
||||
result := r.db.Where("transactionId = ? AND orgId = ? AND deleted = ?",
|
||||
transactionIdBytes, orgIdBytes, false).Find(&gormAttachments)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
attachments := make([]*types.Attachment, len(gormAttachments))
|
||||
for i, gormAttachment := range gormAttachments {
|
||||
attachments[i] = convertGormToTypesAttachment(&gormAttachment)
|
||||
}
|
||||
|
||||
return attachments, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
|
||||
attachmentIdBytes, err := stringToIDBytes(attachmentId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var gormAttachment models.Attachment
|
||||
query := r.db.Where("id = ? AND deleted = ?", attachmentIdBytes, false)
|
||||
|
||||
// Add additional filters if provided
|
||||
if transactionId != "" {
|
||||
transactionIdBytes, err := stringToIDBytes(transactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = query.Where("transactionId = ?", transactionIdBytes)
|
||||
}
|
||||
|
||||
if orgId != "" {
|
||||
orgIdBytes, err := stringToIDBytes(orgId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = query.Where("orgId = ?", orgIdBytes)
|
||||
}
|
||||
|
||||
result := query.First(&gormAttachment)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return convertGormToTypesAttachment(&gormAttachment), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
|
||||
attachmentIdBytes, err := stringToIDBytes(attachmentId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := r.db.Model(&models.Attachment{}).Where("id = ?", attachmentIdBytes)
|
||||
|
||||
// Add additional filters if provided
|
||||
if transactionId != "" {
|
||||
transactionIdBytes, err := stringToIDBytes(transactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query = query.Where("transactionId = ?", transactionIdBytes)
|
||||
}
|
||||
|
||||
if orgId != "" {
|
||||
orgIdBytes, err := stringToIDBytes(orgId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query = query.Where("orgId = ?", orgIdBytes)
|
||||
}
|
||||
|
||||
// Soft delete by setting deleted = true
|
||||
result := query.Update("deleted", true)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// Helper function to convert UUID string (with or without dashes) to bytes
|
||||
func stringToIDBytes(id string) ([]byte, error) {
|
||||
// Remove dashes if present
|
||||
cleanId := strings.ReplaceAll(id, "-", "")
|
||||
return util.HexToBytes(cleanId)
|
||||
}
|
||||
|
||||
// Helper function to convert bytes to UUID string (without dashes, for compatibility)
|
||||
func idBytesToString(bytes []byte) string {
|
||||
return util.BytesToHex(bytes)
|
||||
}
|
||||
|
||||
// Helper function to convert GORM attachment to types attachment
|
||||
func convertGormToTypesAttachment(gormAttachment *models.Attachment) *types.Attachment {
|
||||
return &types.Attachment{
|
||||
Id: idBytesToString(gormAttachment.ID),
|
||||
TransactionId: idBytesToString(gormAttachment.TransactionID),
|
||||
OrgId: idBytesToString(gormAttachment.OrgID),
|
||||
UserId: idBytesToString(gormAttachment.UserID),
|
||||
FileName: gormAttachment.FileName,
|
||||
OriginalName: gormAttachment.OriginalName,
|
||||
ContentType: gormAttachment.ContentType,
|
||||
FileSize: gormAttachment.FileSize,
|
||||
FilePath: gormAttachment.FilePath,
|
||||
Description: gormAttachment.Description,
|
||||
Uploaded: gormAttachment.Uploaded,
|
||||
Deleted: gormAttachment.Deleted,
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,7 @@ func main() {
|
||||
viper.BindEnv("Storage.backend", "OA_STORAGE_BACKEND")
|
||||
viper.BindEnv("Storage.local.root_dir", "OA_STORAGE_LOCAL_ROOTDIR")
|
||||
viper.BindEnv("Storage.local.base_url", "OA_STORAGE_LOCAL_BASEURL")
|
||||
viper.BindEnv("Storage.local.signing_key", "OA_STORAGE_LOCAL_SIGNINGKEY")
|
||||
viper.BindEnv("Storage.s3.region", "OA_STORAGE_S3_REGION")
|
||||
viper.BindEnv("Storage.s3.bucket", "OA_STORAGE_S3_BUCKET")
|
||||
viper.BindEnv("Storage.s3.prefix", "OA_STORAGE_S3_PREFIX")
|
||||
@@ -59,6 +60,7 @@ func main() {
|
||||
viper.SetDefault("Storage.backend", "local")
|
||||
viper.SetDefault("Storage.local.root_dir", "./uploads")
|
||||
viper.SetDefault("Storage.local.base_url", "")
|
||||
viper.SetDefault("Storage.local.signing_key", "") // Will auto-generate if empty
|
||||
|
||||
// Read configuration
|
||||
err := viper.ReadInConfig()
|
||||
@@ -141,6 +143,9 @@ func main() {
|
||||
log.Fatal(fmt.Errorf("failed to initialize storage backend: %s", err.Error()))
|
||||
}
|
||||
|
||||
// Initialize secure file server with signing key for local storage
|
||||
api.InitSecureFileServer(config.Storage.Local.SigningKey)
|
||||
|
||||
app, err := api.Init(config.ApiPrefix)
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to create api instance with: %s", err.Error()))
|
||||
|
||||
@@ -19,6 +19,9 @@ type Storage interface {
|
||||
// GetURL returns a URL for accessing the file (may be signed/temporary)
|
||||
GetURL(path string, expiry time.Duration) (string, error)
|
||||
|
||||
// GetURLWithContext returns a URL for accessing the file with user context (for JWT tokens)
|
||||
GetURLWithContext(path string, expiry time.Duration, userID, orgID string) (string, error)
|
||||
|
||||
// Exists checks if a file exists at the given path
|
||||
Exists(path string) (bool, error)
|
||||
|
||||
@@ -53,6 +56,9 @@ type LocalConfig struct {
|
||||
|
||||
// Base URL for serving files (optional)
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
|
||||
// Signing key for JWT tokens (optional, will be auto-generated if empty)
|
||||
SigningKey string `mapstructure:"signing_key"`
|
||||
}
|
||||
|
||||
// S3Config configures S3-compatible storage (AWS S3, Backblaze B2, Cloudflare R2, etc.)
|
||||
|
||||
@@ -54,19 +54,13 @@ func TestNewStorage(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("B2 Storage", func(t *testing.T) {
|
||||
t.Run("Invalid Backend", func(t *testing.T) {
|
||||
config := Config{
|
||||
Backend: "b2",
|
||||
B2: B2Config{
|
||||
AccountID: "test-account",
|
||||
ApplicationKey: "test-key",
|
||||
Bucket: "test-bucket",
|
||||
},
|
||||
Backend: "invalid",
|
||||
}
|
||||
|
||||
// This will fail because we don't have real B2 credentials
|
||||
storage, err := NewStorage(config)
|
||||
assert.Error(t, err) // Expected to fail without credentials
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, storage)
|
||||
})
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package storage
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -13,8 +14,9 @@ import (
|
||||
|
||||
// LocalStorage implements the Storage interface for local filesystem
|
||||
type LocalStorage struct {
|
||||
rootDir string
|
||||
baseURL string
|
||||
rootDir string
|
||||
baseURL string
|
||||
tokenService *TokenService
|
||||
}
|
||||
|
||||
// NewLocalStorage creates a new local filesystem storage backend
|
||||
@@ -24,14 +26,18 @@ func NewLocalStorage(config LocalConfig) (*LocalStorage, error) {
|
||||
rootDir = "./uploads"
|
||||
}
|
||||
|
||||
// Ensure the root directory exists
|
||||
if err := os.MkdirAll(rootDir, 0755); err != nil {
|
||||
// Ensure the root directory exists with secure permissions
|
||||
if err := os.MkdirAll(rootDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage directory: %w", err)
|
||||
}
|
||||
|
||||
// Initialize token service for secure URL generation
|
||||
tokenService := NewTokenService(config.SigningKey)
|
||||
|
||||
return &LocalStorage{
|
||||
rootDir: rootDir,
|
||||
baseURL: config.BaseURL,
|
||||
rootDir: rootDir,
|
||||
baseURL: config.BaseURL,
|
||||
tokenService: tokenService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -41,14 +47,14 @@ func (l *LocalStorage) Store(filename string, content io.Reader, contentType str
|
||||
storagePath := l.generateStoragePath(filename)
|
||||
fullPath := filepath.Join(l.rootDir, storagePath)
|
||||
|
||||
// Ensure the directory exists
|
||||
// Ensure the directory exists with secure permissions
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Create and write the file
|
||||
file, err := os.Create(fullPath)
|
||||
// Create and write the file with secure permissions
|
||||
file, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
@@ -102,8 +108,13 @@ func (l *LocalStorage) Delete(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetURL returns a URL for accessing the file
|
||||
// GetURL returns a secure URL for accessing the file with JWT token
|
||||
func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error) {
|
||||
return l.GetURLWithContext(path, expiry, "", "")
|
||||
}
|
||||
|
||||
// GetURLWithContext returns a secure URL for accessing the file with JWT token and user context
|
||||
func (l *LocalStorage) GetURLWithContext(path string, expiry time.Duration, userID, orgID string) (string, error) {
|
||||
// Validate path to prevent directory traversal
|
||||
if err := l.validatePath(path); err != nil {
|
||||
return "", err
|
||||
@@ -118,14 +129,16 @@ func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error)
|
||||
return "", &FileNotFoundError{Path: path}
|
||||
}
|
||||
|
||||
if l.baseURL != "" {
|
||||
// Return a public URL if base URL is configured
|
||||
return l.baseURL + "/" + path, nil
|
||||
// Generate secure JWT token for file access
|
||||
token, err := l.tokenService.GenerateFileToken(path, userID, orgID, expiry)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
// For local storage without a base URL, return the file path
|
||||
// In a real application, you might serve these through an endpoint
|
||||
return "/files/" + path, nil
|
||||
// Return secure URL with token parameter
|
||||
params := url.Values{}
|
||||
params.Set("token", token)
|
||||
return "/secure-files?" + params.Encode(), nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists at the given path
|
||||
|
||||
@@ -72,8 +72,11 @@ func TestLocalStorage(t *testing.T) {
|
||||
|
||||
url, err := storage.GetURL(path, time.Hour)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, url, path)
|
||||
assert.Contains(t, url, config.BaseURL)
|
||||
// New JWT token-based URLs should start with /secure-files and contain a token parameter
|
||||
assert.Contains(t, url, "/secure-files")
|
||||
assert.Contains(t, url, "token=")
|
||||
// The token should be a JWT (contains dots for header.payload.signature)
|
||||
assert.Contains(t, url, ".")
|
||||
})
|
||||
|
||||
t.Run("Delete File", func(t *testing.T) {
|
||||
|
||||
@@ -155,6 +155,12 @@ func (s *S3Storage) GetURL(path string, expiry time.Duration) (string, error) {
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// GetURLWithContext returns a presigned URL for accessing the file (S3 doesn't use user context)
|
||||
func (s *S3Storage) GetURLWithContext(path string, expiry time.Duration, userID, orgID string) (string, error) {
|
||||
// For S3, user context is not needed as presigned URLs are cryptographically secure
|
||||
return s.GetURL(path, expiry)
|
||||
}
|
||||
|
||||
// Exists checks if a file exists in S3
|
||||
func (s *S3Storage) Exists(path string) (bool, error) {
|
||||
input := &s3.HeadObjectInput{
|
||||
|
||||
83
core/storage/token.go
Normal file
83
core/storage/token.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
)
|
||||
|
||||
// TokenService handles JWT token generation and validation for file access
|
||||
type TokenService struct {
|
||||
signingKey []byte
|
||||
}
|
||||
|
||||
// NewTokenService creates a new token service with a signing key
|
||||
func NewTokenService(signingKey string) *TokenService {
|
||||
if signingKey == "" {
|
||||
// Generate a random signing key if none provided
|
||||
// In production, this should be a consistent secret from config
|
||||
signingKey = id.String(id.New())
|
||||
}
|
||||
|
||||
return &TokenService{
|
||||
signingKey: []byte(signingKey),
|
||||
}
|
||||
}
|
||||
|
||||
// FileClaims represents the JWT claims for file access
|
||||
type FileClaims struct {
|
||||
FilePath string `json:"file_path"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OrgID string `json:"org_id,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateFileToken creates a JWT token for accessing a specific file
|
||||
func (ts *TokenService) GenerateFileToken(filePath string, userID, orgID string, expiry time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
claims := FileClaims{
|
||||
FilePath: filePath,
|
||||
UserID: userID,
|
||||
OrgID: orgID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: "openaccounting-server",
|
||||
Subject: "file-access",
|
||||
ID: id.String(id.New()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(ts.signingKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateFileToken validates a JWT token and returns the file claims
|
||||
func (ts *TokenService) ValidateFileToken(tokenString string) (*FileClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &FileClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return ts.signingKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*FileClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
@@ -65,3 +65,13 @@ func IsValidUUID(uuid string) bool {
|
||||
matched, _ := regexp.MatchString("^[0-9a-f]{32}$", uuid)
|
||||
return matched
|
||||
}
|
||||
|
||||
// HexToBytes converts a hex string to bytes
|
||||
func HexToBytes(hexString string) ([]byte, error) {
|
||||
return hex.DecodeString(hexString)
|
||||
}
|
||||
|
||||
// BytesToHex converts bytes to a hex string
|
||||
func BytesToHex(bytes []byte) string {
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user