feat: implement unified S3-compatible storage system

Consolidates storage backends into a single S3-compatible driver that supports:
- AWS S3 (native)
- Backblaze B2 (S3-compatible API)
- Cloudflare R2 (S3-compatible API)
- MinIO and other S3-compatible services
- Local filesystem for development

This replaces the previous separate B2 driver with a unified approach,
reducing dependencies and complexity while adding support for more services.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-07-01 23:07:44 +12:00
parent e3152d9f40
commit f99a866e13
14 changed files with 1650 additions and 2 deletions

View File

@@ -0,0 +1,306 @@
package api
import (
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/core/util/id"
"github.com/openaccounting/oa-server/database"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupTestDatabase(t *testing.T) (*gorm.DB, func()) {
// Create temporary database file
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
// Set global DB for database package
database.DB = db
// Run migrations
err = database.AutoMigrate()
if err != nil {
t.Fatalf("Failed to run auto migrations: %v", err)
}
err = database.Migrate()
if err != nil {
t.Fatalf("Failed to run custom migrations: %v", err)
}
// Cleanup function
cleanup := func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
os.RemoveAll(tmpDir)
}
return db, cleanup
}
func setupTestData(t *testing.T, db *gorm.DB) (orgID, userID, transactionID string) {
// Use hardcoded UUIDs without dashes for hex format
orgID = "550e8400e29b41d4a716446655440000"
userID = "550e8400e29b41d4a716446655440001"
transactionID = "550e8400e29b41d4a716446655440002"
accountID := "550e8400e29b41d4a716446655440003"
// Insert test data using raw SQL for reliability
now := time.Now()
// Insert org
err := db.Exec("INSERT INTO orgs (id, inserted, updated, name, currency, `precision`, timezone) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?)",
orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Org", "USD", 2, "UTC").Error
if err != nil {
t.Fatalf("Failed to insert org: %v", err)
}
// Insert user
err = db.Exec("INSERT INTO users (id, inserted, updated, firstName, lastName, email, passwordHash, agreeToTerms, passwordReset, emailVerified, emailVerifyCode, signupSource) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test", "User", "test@example.com", "hashedpassword", true, "", true, "", "test").Error
if err != nil {
t.Fatalf("Failed to insert user: %v", err)
}
// Insert user-org relationship
err = db.Exec("INSERT INTO user_orgs (userId, orgId, admin) VALUES (UNHEX(?), UNHEX(?), ?)",
userID, orgID, false).Error
if err != nil {
t.Fatalf("Failed to insert user-org: %v", err)
}
// Insert account
err = db.Exec("INSERT INTO accounts (id, orgId, inserted, updated, name, parent, currency, `precision`, debitBalance) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?, ?)",
accountID, orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Account", []byte{}, "USD", 2, true).Error
if err != nil {
t.Fatalf("Failed to insert account: %v", err)
}
// Insert transaction
err = db.Exec("INSERT INTO transactions (id, orgId, userId, inserted, updated, date, description, data, deleted) VALUES (UNHEX(?), UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
transactionID, orgID, userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Transaction", "", false).Error
if err != nil {
t.Fatalf("Failed to insert transaction: %v", err)
}
// Insert split
err = db.Exec("INSERT INTO splits (transactionId, accountId, date, inserted, updated, amount, nativeAmount, deleted) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
transactionID, accountID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), 100, 100, false).Error
if err != nil {
t.Fatalf("Failed to insert split: %v", err)
}
return orgID, userID, transactionID
}
func createTestFile(t *testing.T) (string, []byte) {
content := []byte("This is a test file content for attachment testing")
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.txt")
err := os.WriteFile(filePath, content, 0644)
if err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
return filePath, content
}
func TestAttachmentIntegration(t *testing.T) {
db, cleanup := setupTestDatabase(t)
defer cleanup()
orgID, userID, transactionID := setupTestData(t, db)
// Set up the model instance for the API handlers
bc := &util.StandardBcrypt{}
// Use the existing datastore model which has the attachment implementation
// We need to create it with the database connection
datastoreModel := model.NewModel(nil, bc, types.Config{})
model.Instance = datastoreModel
t.Run("Database Integration Test", func(t *testing.T) {
// Test direct database operations first
filePath, originalContent := createTestFile(t)
defer os.Remove(filePath)
// Create attachment record directly
attachmentID := id.String(id.New())
uploadTime := time.Now()
attachment := types.Attachment{
Id: attachmentID,
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
FileName: "stored_test.txt",
OriginalName: "test.txt",
ContentType: "text/plain",
FileSize: int64(len(originalContent)),
FilePath: "uploads/test/" + attachmentID + ".txt",
Description: "Test attachment description",
Uploaded: uploadTime,
Deleted: false,
}
// Insert using the existing model
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.NoError(t, err)
assert.NotNil(t, createdAttachment)
assert.Equal(t, attachmentID, createdAttachment.Id)
// Verify database persistence
var dbAttachment types.Attachment
err = db.Raw("SELECT HEX(id) as id, HEX(transactionId) as transactionId, HEX(orgId) as orgId, HEX(userId) as userId, fileName, originalName, contentType, fileSize, filePath, description, uploaded, deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&dbAttachment).Error
assert.NoError(t, err)
assert.Equal(t, attachmentID, dbAttachment.Id)
assert.Equal(t, transactionID, dbAttachment.TransactionId)
assert.Equal(t, "Test attachment description", dbAttachment.Description)
// Test retrieval
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
assert.NotNil(t, retrievedAttachment)
assert.Equal(t, attachmentID, retrievedAttachment.Id)
// Test listing by transaction
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionID, orgID, userID)
assert.NoError(t, err)
assert.Len(t, attachments, 1)
assert.Equal(t, attachmentID, attachments[0].Id)
// Test soft deletion
err = model.Instance.DeleteAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
// Verify soft deletion in database
var deletedAttachment types.Attachment
err = db.Raw("SELECT deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&deletedAttachment).Error
assert.NoError(t, err)
assert.True(t, deletedAttachment.Deleted)
// Verify attachment is no longer accessible
retrievedAttachment, err = model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.Error(t, err)
assert.Nil(t, retrievedAttachment)
})
t.Run("File Upload Integration Test", func(t *testing.T) {
// Test file upload functionality
filePath, originalContent := createTestFile(t)
defer os.Remove(filePath)
// Create upload directory
uploadDir := "uploads/test"
os.MkdirAll(uploadDir, 0755)
defer os.RemoveAll("uploads")
// Simulate file upload process
attachmentID := id.String(id.New())
storedFilePath := filepath.Join(uploadDir, attachmentID+".txt")
// Copy file to upload location
err := copyFile(filePath, storedFilePath)
assert.NoError(t, err)
// Create attachment record
attachment := types.Attachment{
Id: attachmentID,
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
FileName: filepath.Base(storedFilePath),
OriginalName: "test.txt",
ContentType: "text/plain",
FileSize: int64(len(originalContent)),
FilePath: storedFilePath,
Description: "Uploaded test file",
Uploaded: time.Now(),
Deleted: false,
}
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.NoError(t, err)
assert.NotNil(t, createdAttachment)
// Verify file exists
_, err = os.Stat(storedFilePath)
assert.NoError(t, err)
// Verify database record
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
assert.Equal(t, storedFilePath, retrievedAttachment.FilePath)
assert.Equal(t, int64(len(originalContent)), retrievedAttachment.FileSize)
})
}
// Helper function to copy files
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
func TestAttachmentValidation(t *testing.T) {
db, cleanup := setupTestDatabase(t)
defer cleanup()
orgID, userID, transactionID := setupTestData(t, db)
// Set up the model instance
bc := &util.StandardBcrypt{}
gormModel := model.NewGormModel(db, bc, types.Config{})
model.Instance = gormModel
t.Run("Invalid attachment data", func(t *testing.T) {
// Test with missing required fields
attachment := types.Attachment{
// Missing ID
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
}
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.Error(t, err)
assert.Nil(t, createdAttachment)
})
t.Run("Non-existent attachment retrieval", func(t *testing.T) {
nonExistentID := id.String(id.New())
attachment, err := model.Instance.GetAttachment(nonExistentID, transactionID, orgID, userID)
assert.Error(t, err)
assert.Nil(t, attachment)
})
}

View File

@@ -0,0 +1,289 @@
package api
import (
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
"github.com/ant0ine/go-json-rest/rest"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/storage"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/core/util/id"
)
// AttachmentHandler handles attachment operations with configurable storage
type AttachmentHandler struct {
storage storage.Storage
}
// Global attachment handler instance (will be initialized during server startup)
var attachmentHandler *AttachmentHandler
// InitializeAttachmentHandler initializes the global attachment handler with storage backend
func InitializeAttachmentHandler(storageConfig storage.Config) error {
storageBackend, err := storage.NewStorage(storageConfig)
if err != nil {
return fmt.Errorf("failed to initialize storage backend: %w", err)
}
attachmentHandler = &AttachmentHandler{
storage: storageBackend,
}
return nil
}
// PostAttachmentWithStorage handles file upload using the configured storage backend
func PostAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
transactionId := r.FormValue("transactionId")
if transactionId == "" {
rest.Error(w, "Transaction ID is required", http.StatusBadRequest)
return
}
if !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid transaction ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Parse multipart form
err := r.ParseMultipartForm(MaxFileSize)
if err != nil {
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
return
}
files := r.MultipartForm.File["file"]
if len(files) == 0 {
rest.Error(w, "No file provided", http.StatusBadRequest)
return
}
fileHeader := files[0] // Take the first file
// Verify transaction exists and user has permission
tx, err := model.Instance.GetTransaction(transactionId, "", user.Id)
if err != nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
if tx == nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
// Process the file upload
attachment, err := attachmentHandler.processFileUploadWithStorage(fileHeader, transactionId, tx.OrgId, user.Id, r.FormValue("description"))
if err != nil {
rest.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Save attachment to database
createdAttachment, err := model.Instance.CreateAttachment(attachment)
if err != nil {
// Clean up the stored file on database error
attachmentHandler.storage.Delete(attachment.FilePath)
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
w.WriteJson(createdAttachment)
}
// GetAttachmentWithStorage retrieves an attachment using the configured storage backend
func GetAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Check if this is a download request
if r.URL.Query().Get("download") == "true" {
// Stream the file directly to the client
err := attachmentHandler.streamFile(w, attachment)
if err != nil {
rest.Error(w, "Failed to retrieve file", http.StatusInternalServerError)
return
}
return
}
// Return attachment metadata
w.WriteJson(attachment)
}
// GetAttachmentDownloadURL returns a download URL for an attachment
func GetAttachmentDownloadURL(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Generate download URL (valid for 1 hour)
url, err := attachmentHandler.storage.GetURL(attachment.FilePath, time.Hour)
if err != nil {
rest.Error(w, "Failed to generate download URL", http.StatusInternalServerError)
return
}
response := map[string]string{
"url": url,
"expiresIn": "3600", // 1 hour in seconds
}
w.WriteJson(response)
}
// DeleteAttachmentWithStorage deletes an attachment using the configured storage backend
func DeleteAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database first
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Delete from database (soft delete)
err = model.Instance.DeleteAttachment(attachmentId, attachment.TransactionId, attachment.OrgId, user.Id)
if err != nil {
rest.Error(w, "Failed to delete attachment", http.StatusInternalServerError)
return
}
// Delete from storage backend
// Note: For production, you might want to delay physical deletion
// and run a cleanup job later to handle any issues
err = attachmentHandler.storage.Delete(attachment.FilePath)
if err != nil {
// Log the error but don't fail the request since database deletion succeeded
// The file can be cleaned up later by a maintenance job
fmt.Printf("Warning: Failed to delete file from storage: %v\n", err)
}
w.WriteHeader(http.StatusOK)
w.WriteJson(map[string]string{"status": "deleted"})
}
// processFileUploadWithStorage processes a file upload using the storage backend
func (h *AttachmentHandler) processFileUploadWithStorage(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
// Validate file size
if fileHeader.Size > MaxFileSize {
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
}
// Validate content type
contentType := fileHeader.Header.Get("Content-Type")
if !AllowedMimeTypes[contentType] {
return nil, fmt.Errorf("unsupported file type: %s", contentType)
}
// Open the file
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
}
defer file.Close()
// Store the file using the storage backend
storagePath, err := h.storage.Store(fileHeader.Filename, file, contentType)
if err != nil {
return nil, fmt.Errorf("failed to store file: %w", err)
}
// Create attachment record
attachment := &types.Attachment{
Id: id.String(id.New()),
TransactionId: transactionId,
OrgId: orgId,
UserId: userId,
FileName: storagePath, // Store the storage path/key
OriginalName: fileHeader.Filename,
ContentType: contentType,
FileSize: fileHeader.Size,
FilePath: storagePath, // For backward compatibility
Description: description,
Uploaded: time.Now(),
Deleted: false,
}
return attachment, nil
}
// streamFile streams a file from storage to the HTTP response
func (h *AttachmentHandler) streamFile(w rest.ResponseWriter, attachment *types.Attachment) error {
// Get file from storage
reader, err := h.storage.Retrieve(attachment.FilePath)
if err != nil {
return fmt.Errorf("failed to retrieve file: %w", err)
}
defer reader.Close()
// Set appropriate headers
w.Header().Set("Content-Type", attachment.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
// If we know the file size, set Content-Length
if attachment.FileSize > 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", attachment.FileSize))
}
// Stream the file to the client
_, err = io.Copy(w.(http.ResponseWriter), reader)
return err
}

View File

@@ -36,6 +36,12 @@ func GetRouter(auth *AuthMiddleware, prefix string) (rest.App, error) {
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(GetAttachment)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId/download", auth.RequireAuth(DownloadAttachment)),
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(DeleteAttachment)),
// New storage-based attachment endpoints
rest.Post(prefix+"/attachments", auth.RequireAuth(PostAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id", auth.RequireAuth(GetAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id/url", auth.RequireAuth(GetAttachmentDownloadURL)),
rest.Delete(prefix+"/attachments/:id", auth.RequireAuth(DeleteAttachmentWithStorage)),
rest.Get(prefix+"/orgs/:orgId/prices", auth.RequireAuth(GetPrices)),
rest.Post(prefix+"/orgs/:orgId/prices", auth.RequireAuth(PostPrice)),
rest.Delete(prefix+"/orgs/:orgId/prices/:priceId", auth.RequireAuth(DeletePrice)),

View File

@@ -1,5 +1,7 @@
package types
import "github.com/openaccounting/oa-server/core/storage"
type Config struct {
WebUrl string `mapstructure:"weburl"`
Address string `mapstructure:"address"`
@@ -15,8 +17,11 @@ type Config struct {
Password string `mapstructure:"password"` // Sensitive: use OA_PASSWORD env var
// SQLite specific
DatabaseFile string `mapstructure:"databasefile"`
// Email configuration
MailgunDomain string `mapstructure:"mailgundomain"`
MailgunKey string `mapstructure:"mailgunkey"` // Sensitive: use OA_MAILGUN_KEY env var
MailgunEmail string `mapstructure:"mailgunemail"`
MailgunSender string `mapstructure:"mailgunsender"`
// Storage configuration
Storage storage.Config `mapstructure:"storage"`
}

View File

@@ -5,6 +5,7 @@ import (
"log"
"net/http"
"strconv"
"strings"
"github.com/openaccounting/oa-server/core/api"
"github.com/openaccounting/oa-server/core/auth"
@@ -31,6 +32,22 @@ func main() {
viper.AutomaticEnv()
viper.SetEnvPrefix("OA") // will look for OA_DATABASE_PASSWORD, etc.
// Configure Viper to handle nested config with environment variables
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Bind specific storage environment variables for better support
// Using mapstructure field names (snake_case)
viper.BindEnv("Storage.backend", "OA_STORAGE_BACKEND")
viper.BindEnv("Storage.local.root_dir", "OA_STORAGE_LOCAL_ROOTDIR")
viper.BindEnv("Storage.local.base_url", "OA_STORAGE_LOCAL_BASEURL")
viper.BindEnv("Storage.s3.region", "OA_STORAGE_S3_REGION")
viper.BindEnv("Storage.s3.bucket", "OA_STORAGE_S3_BUCKET")
viper.BindEnv("Storage.s3.prefix", "OA_STORAGE_S3_PREFIX")
viper.BindEnv("Storage.s3.access_key_id", "OA_STORAGE_S3_ACCESSKEYID")
viper.BindEnv("Storage.s3.secret_access_key", "OA_STORAGE_S3_SECRETACCESSKEY")
viper.BindEnv("Storage.s3.endpoint", "OA_STORAGE_S3_ENDPOINT")
viper.BindEnv("Storage.s3.path_style", "OA_STORAGE_S3_PATHSTYLE")
// Set default values
viper.SetDefault("Address", "localhost")
viper.SetDefault("Port", 8080)
@@ -38,6 +55,11 @@ func main() {
viper.SetDefault("DatabaseFile", "./openaccounting.db")
viper.SetDefault("ApiPrefix", "/api/v1")
// Set storage defaults (using mapstructure field names)
viper.SetDefault("Storage.backend", "local")
viper.SetDefault("Storage.local.root_dir", "./uploads")
viper.SetDefault("Storage.local.base_url", "")
// Read configuration
err := viper.ReadInConfig()
if err != nil {
@@ -50,6 +72,14 @@ func main() {
if err != nil {
log.Fatal(fmt.Errorf("failed to unmarshal config: %s", err.Error()))
}
// Set storage defaults if not configured (Viper doesn't handle nested defaults well)
if config.Storage.Backend == "" {
config.Storage.Backend = "local"
}
if config.Storage.Local.RootDir == "" {
config.Storage.Local.RootDir = "./uploads"
}
// Parse database address (assuming format host:port for MySQL)
host := config.DatabaseAddress
@@ -105,6 +135,12 @@ func main() {
// Set the global model instance
model.Instance = gormModel
// Initialize storage backend for attachments
err = api.InitializeAttachmentHandler(config.Storage)
if err != nil {
log.Fatal(fmt.Errorf("failed to initialize storage backend: %s", err.Error()))
}
app, err := api.Init(config.ApiPrefix)
if err != nil {
log.Fatal(fmt.Errorf("failed to create api instance with: %s", err.Error()))

106
core/storage/interface.go Normal file
View File

@@ -0,0 +1,106 @@
package storage
import (
"io"
"time"
)
// Storage defines the interface for file storage backends
type Storage interface {
// Store saves a file and returns the storage path/key
Store(filename string, content io.Reader, contentType string) (string, error)
// Retrieve gets a file by its storage path/key
Retrieve(path string) (io.ReadCloser, error)
// Delete removes a file by its storage path/key
Delete(path string) error
// GetURL returns a URL for accessing the file (may be signed/temporary)
GetURL(path string, expiry time.Duration) (string, error)
// Exists checks if a file exists at the given path
Exists(path string) (bool, error)
// GetMetadata returns file metadata (size, last modified, etc.)
GetMetadata(path string) (*FileMetadata, error)
}
// FileMetadata contains information about a stored file
type FileMetadata struct {
Size int64
LastModified time.Time
ContentType string
ETag string
}
// Config holds configuration for storage backends
type Config struct {
// Storage backend type: "local", "s3"
Backend string `mapstructure:"backend"`
// Local filesystem configuration
Local LocalConfig `mapstructure:"local"`
// S3-compatible storage configuration (S3, B2, R2, etc.)
S3 S3Config `mapstructure:"s3"`
}
// LocalConfig configures local filesystem storage
type LocalConfig struct {
// Root directory for file storage
RootDir string `mapstructure:"root_dir"`
// Base URL for serving files (optional)
BaseURL string `mapstructure:"base_url"`
}
// S3Config configures S3-compatible storage (AWS S3, Backblaze B2, Cloudflare R2, etc.)
type S3Config struct {
// AWS Region (use "auto" for Cloudflare R2)
Region string `mapstructure:"region"`
// S3 Bucket name
Bucket string `mapstructure:"bucket"`
// Optional prefix for all objects
Prefix string `mapstructure:"prefix"`
// Access Key ID
AccessKeyID string `mapstructure:"access_key_id"`
// Secret Access Key
SecretAccessKey string `mapstructure:"secret_access_key"`
// Custom endpoint URL for S3-compatible services:
// - Backblaze B2: https://s3.us-west-004.backblazeb2.com
// - Cloudflare R2: https://<account-id>.r2.cloudflarestorage.com
// - MinIO: http://localhost:9000
// Leave empty for AWS S3
Endpoint string `mapstructure:"endpoint"`
// Use path-style addressing (required for some S3-compatible services)
PathStyle bool `mapstructure:"path_style"`
}
// NewStorage creates a new storage backend based on configuration
func NewStorage(config Config) (Storage, error) {
switch config.Backend {
case "local", "":
return NewLocalStorage(config.Local)
case "s3":
return NewS3Storage(config.S3)
default:
return nil, &UnsupportedBackendError{Backend: config.Backend}
}
}
// UnsupportedBackendError is returned when an unknown storage backend is requested
type UnsupportedBackendError struct {
Backend string
}
func (e *UnsupportedBackendError) Error() string {
return "unsupported storage backend: " + e.Backend
}

View File

@@ -0,0 +1,101 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewStorage(t *testing.T) {
t.Run("Local Storage", func(t *testing.T) {
config := Config{
Backend: "local",
Local: LocalConfig{
RootDir: t.TempDir(),
},
}
storage, err := NewStorage(config)
assert.NoError(t, err)
assert.IsType(t, &LocalStorage{}, storage)
})
t.Run("Default to Local Storage", func(t *testing.T) {
config := Config{
// No backend specified
Local: LocalConfig{
RootDir: t.TempDir(),
},
}
storage, err := NewStorage(config)
assert.NoError(t, err)
assert.IsType(t, &LocalStorage{}, storage)
})
t.Run("S3 Storage", func(t *testing.T) {
config := Config{
Backend: "s3",
S3: S3Config{
Region: "us-east-1",
Bucket: "test-bucket",
},
}
// This might succeed if AWS credentials are available via IAM roles or env vars
// Let's just check that we get an S3Storage instance or an error
storage, err := NewStorage(config)
if err != nil {
// If it fails, that's expected in test environments without AWS access
assert.Nil(t, storage)
} else {
// If it succeeds, we should get an S3Storage instance
assert.IsType(t, &S3Storage{}, storage)
}
})
t.Run("B2 Storage", func(t *testing.T) {
config := Config{
Backend: "b2",
B2: B2Config{
AccountID: "test-account",
ApplicationKey: "test-key",
Bucket: "test-bucket",
},
}
// This will fail because we don't have real B2 credentials
storage, err := NewStorage(config)
assert.Error(t, err) // Expected to fail without credentials
assert.Nil(t, storage)
})
t.Run("Unsupported Backend", func(t *testing.T) {
config := Config{
Backend: "unsupported",
}
storage, err := NewStorage(config)
assert.Error(t, err)
assert.IsType(t, &UnsupportedBackendError{}, err)
assert.Nil(t, storage)
assert.Contains(t, err.Error(), "unsupported")
})
}
func TestStorageErrors(t *testing.T) {
t.Run("UnsupportedBackendError", func(t *testing.T) {
err := &UnsupportedBackendError{Backend: "ftp"}
assert.Equal(t, "unsupported storage backend: ftp", err.Error())
})
t.Run("FileNotFoundError", func(t *testing.T) {
err := &FileNotFoundError{Path: "missing.txt"}
assert.Equal(t, "file not found: missing.txt", err.Error())
})
t.Run("InvalidPathError", func(t *testing.T) {
err := &InvalidPathError{Path: "../../../etc/passwd"}
assert.Equal(t, "invalid path: ../../../etc/passwd", err.Error())
})
}

243
core/storage/local.go Normal file
View File

@@ -0,0 +1,243 @@
package storage
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/openaccounting/oa-server/core/util/id"
)
// LocalStorage implements the Storage interface for local filesystem
type LocalStorage struct {
rootDir string
baseURL string
}
// NewLocalStorage creates a new local filesystem storage backend
func NewLocalStorage(config LocalConfig) (*LocalStorage, error) {
rootDir := config.RootDir
if rootDir == "" {
rootDir = "./uploads"
}
// Ensure the root directory exists
if err := os.MkdirAll(rootDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create storage directory: %w", err)
}
return &LocalStorage{
rootDir: rootDir,
baseURL: config.BaseURL,
}, nil
}
// Store saves a file to the local filesystem
func (l *LocalStorage) Store(filename string, content io.Reader, contentType string) (string, error) {
// Generate a unique storage path
storagePath := l.generateStoragePath(filename)
fullPath := filepath.Join(l.rootDir, storagePath)
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}
// Create and write the file
file, err := os.Create(fullPath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
_, err = io.Copy(file, content)
if err != nil {
// Clean up the file if write failed
os.Remove(fullPath)
return "", fmt.Errorf("failed to write file: %w", err)
}
return storagePath, nil
}
// Retrieve gets a file from the local filesystem
func (l *LocalStorage) Retrieve(path string) (io.ReadCloser, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return nil, err
}
fullPath := filepath.Join(l.rootDir, path)
file, err := os.Open(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, &FileNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to open file: %w", err)
}
return file, nil
}
// Delete removes a file from the local filesystem
func (l *LocalStorage) Delete(path string) error {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return err
}
fullPath := filepath.Join(l.rootDir, path)
err := os.Remove(fullPath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to delete file: %w", err)
}
// Try to remove empty parent directories
l.cleanupEmptyDirs(filepath.Dir(fullPath))
return nil
}
// GetURL returns a URL for accessing the file
func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return "", err
}
// Check if file exists
exists, err := l.Exists(path)
if err != nil {
return "", err
}
if !exists {
return "", &FileNotFoundError{Path: path}
}
if l.baseURL != "" {
// Return a public URL if base URL is configured
return l.baseURL + "/" + path, nil
}
// For local storage without a base URL, return the file path
// In a real application, you might serve these through an endpoint
return "/files/" + path, nil
}
// Exists checks if a file exists at the given path
func (l *LocalStorage) Exists(path string) (bool, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return false, err
}
fullPath := filepath.Join(l.rootDir, path)
_, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check file existence: %w", err)
}
return true, nil
}
// GetMetadata returns file metadata
func (l *LocalStorage) GetMetadata(path string) (*FileMetadata, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return nil, err
}
fullPath := filepath.Join(l.rootDir, path)
info, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, &FileNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to get file metadata: %w", err)
}
return &FileMetadata{
Size: info.Size(),
LastModified: info.ModTime(),
ContentType: "", // Local storage doesn't store content type
ETag: "", // Local storage doesn't have ETags
}, nil
}
// generateStoragePath creates a unique storage path for a file
func (l *LocalStorage) generateStoragePath(filename string) string {
// Generate a unique ID for the file
fileID := id.String(id.New())
// Extract file extension
ext := filepath.Ext(filename)
// Create a path structure: YYYY/MM/DD/uuid.ext
now := time.Now()
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
return filepath.Join(datePath, fileID+ext)
}
// validatePath ensures the path doesn't contain directory traversal attempts
func (l *LocalStorage) validatePath(path string) error {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Reject paths that try to go up directories
if strings.Contains(cleanPath, "..") {
return &InvalidPathError{Path: path}
}
// Reject absolute paths
if filepath.IsAbs(cleanPath) {
return &InvalidPathError{Path: path}
}
return nil
}
// cleanupEmptyDirs removes empty parent directories up to the root
func (l *LocalStorage) cleanupEmptyDirs(dir string) {
// Don't remove the root directory
if dir == l.rootDir {
return
}
// Check if directory is empty
entries, err := os.ReadDir(dir)
if err != nil || len(entries) > 0 {
return
}
// Remove empty directory
if err := os.Remove(dir); err == nil {
// Recursively clean parent directories
l.cleanupEmptyDirs(filepath.Dir(dir))
}
}
// FileNotFoundError is returned when a file doesn't exist
type FileNotFoundError struct {
Path string
}
func (e *FileNotFoundError) Error() string {
return "file not found: " + e.Path
}
// InvalidPathError is returned when a path is invalid or contains traversal attempts
type InvalidPathError struct {
Path string
}
func (e *InvalidPathError) Error() string {
return "invalid path: " + e.Path
}

202
core/storage/local_test.go Normal file
View File

@@ -0,0 +1,202 @@
package storage
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLocalStorage(t *testing.T) {
// Create temporary directory for testing
tmpDir := t.TempDir()
config := LocalConfig{
RootDir: tmpDir,
BaseURL: "http://localhost:8080/files",
}
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.NotNil(t, storage)
t.Run("Store and Retrieve File", func(t *testing.T) {
content := []byte("test file content")
reader := bytes.NewReader(content)
// Store file
path, err := storage.Store("test.txt", reader, "text/plain")
assert.NoError(t, err)
assert.NotEmpty(t, path)
// Verify file exists
exists, err := storage.Exists(path)
assert.NoError(t, err)
assert.True(t, exists)
// Retrieve file
retrievedReader, err := storage.Retrieve(path)
assert.NoError(t, err)
defer retrievedReader.Close()
retrievedContent, err := io.ReadAll(retrievedReader)
assert.NoError(t, err)
assert.Equal(t, content, retrievedContent)
})
t.Run("Get File Metadata", func(t *testing.T) {
content := []byte("metadata test content")
reader := bytes.NewReader(content)
path, err := storage.Store("metadata.txt", reader, "text/plain")
assert.NoError(t, err)
metadata, err := storage.GetMetadata(path)
assert.NoError(t, err)
assert.Equal(t, int64(len(content)), metadata.Size)
assert.False(t, metadata.LastModified.IsZero())
})
t.Run("Get File URL", func(t *testing.T) {
content := []byte("url test content")
reader := bytes.NewReader(content)
path, err := storage.Store("url.txt", reader, "text/plain")
assert.NoError(t, err)
url, err := storage.GetURL(path, time.Hour)
assert.NoError(t, err)
assert.Contains(t, url, path)
assert.Contains(t, url, config.BaseURL)
})
t.Run("Delete File", func(t *testing.T) {
content := []byte("delete test content")
reader := bytes.NewReader(content)
path, err := storage.Store("delete.txt", reader, "text/plain")
assert.NoError(t, err)
// Verify file exists
exists, err := storage.Exists(path)
assert.NoError(t, err)
assert.True(t, exists)
// Delete file
err = storage.Delete(path)
assert.NoError(t, err)
// Verify file no longer exists
exists, err = storage.Exists(path)
assert.NoError(t, err)
assert.False(t, exists)
})
t.Run("Path Validation", func(t *testing.T) {
// Test directory traversal prevention
_, err := storage.Retrieve("../../../etc/passwd")
assert.Error(t, err)
assert.IsType(t, &InvalidPathError{}, err)
// Test absolute path rejection
_, err = storage.Retrieve("/etc/passwd")
assert.Error(t, err)
assert.IsType(t, &InvalidPathError{}, err)
})
t.Run("File Not Found", func(t *testing.T) {
_, err := storage.Retrieve("nonexistent.txt")
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
_, err = storage.GetMetadata("nonexistent.txt")
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
_, err = storage.GetURL("nonexistent.txt", time.Hour)
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
})
t.Run("Storage Path Generation", func(t *testing.T) {
content := []byte("path test content")
reader1 := bytes.NewReader(content)
reader2 := bytes.NewReader(content)
// Store two files with same name
path1, err := storage.Store("same.txt", reader1, "text/plain")
assert.NoError(t, err)
path2, err := storage.Store("same.txt", reader2, "text/plain")
assert.NoError(t, err)
// Paths should be different (unique)
assert.NotEqual(t, path1, path2)
// Both should exist
exists1, err := storage.Exists(path1)
assert.NoError(t, err)
assert.True(t, exists1)
exists2, err := storage.Exists(path2)
assert.NoError(t, err)
assert.True(t, exists2)
// Both should have correct extension
assert.True(t, strings.HasSuffix(path1, ".txt"))
assert.True(t, strings.HasSuffix(path2, ".txt"))
// Should be organized by date
now := time.Now()
expectedPrefix := filepath.Join(
fmt.Sprintf("%04d", now.Year()),
fmt.Sprintf("%02d", now.Month()),
fmt.Sprintf("%02d", now.Day()),
)
assert.True(t, strings.HasPrefix(path1, expectedPrefix))
assert.True(t, strings.HasPrefix(path2, expectedPrefix))
})
}
func TestLocalStorageConfig(t *testing.T) {
t.Run("Default Root Directory", func(t *testing.T) {
config := LocalConfig{} // Empty config
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.NotNil(t, storage)
// Should create default uploads directory
assert.Equal(t, "./uploads", storage.rootDir)
// Verify directory was created
_, err = os.Stat("./uploads")
assert.NoError(t, err)
// Clean up
os.RemoveAll("./uploads")
})
t.Run("Custom Root Directory", func(t *testing.T) {
tmpDir := t.TempDir()
customDir := filepath.Join(tmpDir, "custom", "storage")
config := LocalConfig{
RootDir: customDir,
}
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.Equal(t, customDir, storage.rootDir)
// Verify custom directory was created
_, err = os.Stat(customDir)
assert.NoError(t, err)
})
}

236
core/storage/s3.go Normal file
View File

@@ -0,0 +1,236 @@
package storage
import (
"fmt"
"io"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/openaccounting/oa-server/core/util/id"
)
// S3Storage implements the Storage interface for Amazon S3
type S3Storage struct {
client *s3.S3
uploader *s3manager.Uploader
bucket string
prefix string
}
// NewS3Storage creates a new S3 storage backend
func NewS3Storage(config S3Config) (*S3Storage, error) {
if config.Bucket == "" {
return nil, fmt.Errorf("S3 bucket name is required")
}
// Create AWS config
awsConfig := &aws.Config{
Region: aws.String(config.Region),
}
// Set custom endpoint if provided (for S3-compatible services)
if config.Endpoint != "" {
awsConfig.Endpoint = aws.String(config.Endpoint)
awsConfig.S3ForcePathStyle = aws.Bool(config.PathStyle)
}
// Set credentials if provided
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
awsConfig.Credentials = credentials.NewStaticCredentials(
config.AccessKeyID,
config.SecretAccessKey,
"",
)
}
// Create session
sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create AWS session: %w", err)
}
// Create S3 client
client := s3.New(sess)
uploader := s3manager.NewUploader(sess)
return &S3Storage{
client: client,
uploader: uploader,
bucket: config.Bucket,
prefix: config.Prefix,
}, nil
}
// Store saves a file to S3
func (s *S3Storage) Store(filename string, content io.Reader, contentType string) (string, error) {
// Generate a unique storage key
storageKey := s.generateStorageKey(filename)
// Prepare upload input
input := &s3manager.UploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(storageKey),
Body: content,
}
// Set content type if provided
if contentType != "" {
input.ContentType = aws.String(contentType)
}
// Upload the file
_, err := s.uploader.Upload(input)
if err != nil {
return "", fmt.Errorf("failed to upload file to S3: %w", err)
}
return storageKey, nil
}
// Retrieve gets a file from S3
func (s *S3Storage) Retrieve(path string) (io.ReadCloser, error) {
input := &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
result, err := s.client.GetObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey:
return nil, &FileNotFoundError{Path: path}
}
}
return nil, fmt.Errorf("failed to retrieve file from S3: %w", err)
}
return result.Body, nil
}
// Delete removes a file from S3
func (s *S3Storage) Delete(path string) error {
input := &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
_, err := s.client.DeleteObject(input)
if err != nil {
return fmt.Errorf("failed to delete file from S3: %w", err)
}
return nil
}
// GetURL returns a presigned URL for accessing the file
func (s *S3Storage) GetURL(path string, expiry time.Duration) (string, error) {
// Check if file exists first
exists, err := s.Exists(path)
if err != nil {
return "", err
}
if !exists {
return "", &FileNotFoundError{Path: path}
}
// Generate presigned URL
req, _ := s.client.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
})
url, err := req.Presign(expiry)
if err != nil {
return "", fmt.Errorf("failed to generate presigned URL: %w", err)
}
return url, nil
}
// Exists checks if a file exists in S3
func (s *S3Storage) Exists(path string) (bool, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
_, err := s.client.HeadObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey, "NotFound":
return false, nil
}
}
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
}
return true, nil
}
// GetMetadata returns file metadata from S3
func (s *S3Storage) GetMetadata(path string) (*FileMetadata, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
result, err := s.client.HeadObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey, "NotFound":
return nil, &FileNotFoundError{Path: path}
}
}
return nil, fmt.Errorf("failed to get file metadata from S3: %w", err)
}
metadata := &FileMetadata{
Size: aws.Int64Value(result.ContentLength),
}
if result.LastModified != nil {
metadata.LastModified = *result.LastModified
}
if result.ContentType != nil {
metadata.ContentType = *result.ContentType
}
if result.ETag != nil {
metadata.ETag = strings.Trim(*result.ETag, "\"")
}
return metadata, nil
}
// generateStorageKey creates a unique storage key for a file
func (s *S3Storage) generateStorageKey(filename string) string {
// Generate a unique ID for the file
fileID := id.String(id.New())
// Extract file extension
ext := path.Ext(filename)
// Create a key structure: prefix/YYYY/MM/DD/uuid.ext
now := time.Now()
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
key := path.Join(datePath, fileID+ext)
// Add prefix if configured
if s.prefix != "" {
key = path.Join(s.prefix, key)
}
return key
}