package api

import (
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/ant0ine/go-json-rest/rest"
	"github.com/openaccounting/oa-server/core/model"
	"github.com/openaccounting/oa-server/core/model/types"
	"github.com/openaccounting/oa-server/core/util"
)

// Upload limits and storage location for transaction attachments.
// MaxFileSize is per file, in bytes; MaxFilesPerTx is per upload request;
// AttachmentDir is the root directory (relative to the working directory)
// under which files are stored as <orgId>/<transactionId>/<file>.
const (
	MaxFileSize   = 10 * 1024 * 1024 // 10MB
	MaxFilesPerTx = 10
	AttachmentDir = "attachments"
)

// Media types that need special handling during content sniffing.
const (
	mimeXLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
	mimeXLS  = "application/vnd.ms-excel"
)

// Internal limits and signatures used when validating uploads.
const (
	// maxMultipartOverhead is headroom, on top of the file payload limit,
	// for multipart boundaries, part headers and ordinary form fields.
	maxMultipartOverhead = 1 << 20
	// attachmentSniffLen is the number of leading bytes
	// http.DetectContentType considers.
	attachmentSniffLen = 512
	// maxFilenameLen caps the length, in bytes, of a sanitized filename.
	maxFilenameLen = 255
	// ole2Signature is the magic number at the start of OLE2 compound
	// documents, the container format of legacy .xls workbooks.
	ole2Signature = "\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1"
)

// AllowedMimeTypes is the set of declared media types (without parameters)
// accepted for upload.
var AllowedMimeTypes = map[string]bool{
	"image/jpeg":      true,
	"image/png":       true,
	"image/gif":       true,
	"application/pdf": true,
	"text/plain":      true,
	"text/csv":        true,
	mimeXLSX:          true, // .xlsx
	mimeXLS:           true, // .xls
}

// unsafeFilenameChars deletes path separators, NUL, CR and LF from a
// filename.
var unsafeFilenameChars = strings.NewReplacer("/", "", "\\", "", "\x00", "", "\r", "", "\n", "")

// PostAttachment stores the files sent in the multipart form field "files"
// as attachments of the transaction named by the orgId and transactionId
// path parameters. It first checks, via model.Instance.GetTransaction, that
// the transaction exists for the current user. The optional "description"
// form value is applied to every file. If any file is rejected or any
// record cannot be created, it makes a best-effort attempt to remove the
// files written and the records created by this request before reporting
// the error. On success it responds with the created attachments and their
// count.
func PostAttachment(w rest.ResponseWriter, r *rest.Request) {
	orgId := r.PathParam("orgId")
	transactionId := r.PathParam("transactionId")

	if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
		rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
		return
	}

	user := r.Env["USER"].(*types.User)

	// Verify the transaction exists and the user has permission before
	// reading the (potentially large) request body.
	tx, err := model.Instance.GetTransaction(transactionId, orgId, user.Id)
	if err != nil {
		rest.Error(w, "Transaction not found or access denied", http.StatusNotFound)
		return
	}
	if tx == nil {
		rest.Error(w, "Transaction not found", http.StatusNotFound)
		return
	}

	// ParseMultipartForm's argument only bounds how much is buffered in
	// memory (the remainder spills to temporary files), so the body as a
	// whole is capped here. A nil writer is acceptable to MaxBytesReader.
	hw, _ := w.(http.ResponseWriter)
	r.Body = http.MaxBytesReader(hw, r.Body, MaxFileSize*MaxFilesPerTx+maxMultipartOverhead)

	if err := r.ParseMultipartForm(MaxFileSize); err != nil {
		rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
		return
	}

	files := r.MultipartForm.File["files"]
	if len(files) == 0 {
		rest.Error(w, "No files provided", http.StatusBadRequest)
		return
	}
	if len(files) > MaxFilesPerTx {
		rest.Error(w, fmt.Sprintf("Too many files. Maximum %d files allowed", MaxFilesPerTx), http.StatusBadRequest)
		return
	}

	description := r.FormValue("description")

	// Validate and write every file before creating any record, so a
	// rejected file never leaves records behind.
	pending := make([]*types.Attachment, 0, len(files))
	for _, fileHeader := range files {
		attachment, err := processFileUpload(fileHeader, transactionId, orgId, user.Id, description)
		if err != nil {
			removeAttachmentFiles(pending)
			rest.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		pending = append(pending, attachment)
	}

	attachments := make([]*types.Attachment, 0, len(pending))
	for i, attachment := range pending {
		createdAttachment, err := model.Instance.CreateAttachment(attachment)
		if err != nil {
			// Best-effort rollback of the records already created; their
			// files are about to be removed, so they must not stay visible.
			// Errors are ignored because a failure is already being reported.
			for _, saved := range pending[:i] {
				_ = model.Instance.DeleteAttachment(saved.Id, transactionId, orgId, user.Id)
			}
			removeAttachmentFiles(pending)
			rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
			return
		}
		attachments = append(attachments, createdAttachment)
	}

	w.WriteJson(map[string]interface{}{
		"attachments": attachments,
		"count":       len(attachments),
	})
}

// GetAttachments writes, as JSON, the attachments that
// model.Instance.GetAttachmentsByTransaction returns for the transaction
// named by the orgId and transactionId path parameters and the current user.
func GetAttachments(w rest.ResponseWriter, r *rest.Request) {
	orgId := r.PathParam("orgId")
	transactionId := r.PathParam("transactionId")

	if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
		rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
		return
	}

	user := r.Env["USER"].(*types.User)

	attachments, err := model.Instance.GetAttachmentsByTransaction(transactionId, orgId, user.Id)
	if err != nil {
		rest.Error(w, "Failed to retrieve attachments", http.StatusInternalServerError)
		return
	}

	w.WriteJson(attachments)
}

// GetAttachment writes, as JSON, the metadata of the attachment named by the
// orgId, transactionId and attachmentId path parameters. Any lookup error
// is reported as 404.
func GetAttachment(w rest.ResponseWriter, r *rest.Request) {
	orgId := r.PathParam("orgId")
	transactionId := r.PathParam("transactionId")
	attachmentId := r.PathParam("attachmentId")

	if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
		rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
		return
	}

	user := r.Env["USER"].(*types.User)

	attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
	if err != nil {
		rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
		return
	}

	w.WriteJson(attachment)
}

// DownloadAttachment streams the stored file of the attachment named by the
// orgId, transactionId and attachmentId path parameters, using the stored
// content type and offering the original filename as the download name.
func DownloadAttachment(w rest.ResponseWriter, r *rest.Request) {
	orgId := r.PathParam("orgId")
	transactionId := r.PathParam("transactionId")
	attachmentId := r.PathParam("attachmentId")

	if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
		rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
		return
	}

	user := r.Env["USER"].(*types.User)

	attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
	if err != nil {
		rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
		return
	}

	// Raw bytes are written through the underlying http.ResponseWriter.
	hw, ok := w.(http.ResponseWriter)
	if !ok {
		rest.Error(w, "Failed to stream file", http.StatusInternalServerError)
		return
	}

	// Open before setting any download headers: headers set on w would
	// otherwise also be sent with the error responses below.
	file, err := os.Open(attachment.FilePath)
	if err != nil {
		if os.IsNotExist(err) {
			rest.Error(w, "File not found on disk", http.StatusNotFound)
		} else {
			rest.Error(w, "Failed to open file", http.StatusInternalServerError)
		}
		return
	}
	defer file.Close()

	contentType := attachment.ContentType
	if contentType == "" {
		contentType = "application/octet-stream"
	}

	header := w.Header()
	header.Set("Content-Type", contentType)
	header.Set("Content-Disposition", attachmentDisposition(attachment.OriginalName, attachment.FileName))
	// User-supplied content: stop browsers from overriding the stored type.
	header.Set("X-Content-Type-Options", "nosniff")

	// Once the first bytes are written the status line has been sent, so a
	// copy error can no longer be reported to the client.
	_, _ = io.Copy(hw, file)
}

// DeleteAttachment deletes the attachment named by the orgId, transactionId
// and attachmentId path parameters via model.Instance.DeleteAttachment. Any
// error from the model is reported as 500.
func DeleteAttachment(w rest.ResponseWriter, r *rest.Request) {
	orgId := r.PathParam("orgId")
	transactionId := r.PathParam("transactionId")
	attachmentId := r.PathParam("attachmentId")

	if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
		rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
		return
	}

	user := r.Env["USER"].(*types.User)

	err := model.Instance.DeleteAttachment(attachmentId, transactionId, orgId, user.Id)
	if err != nil {
		rest.Error(w, "Failed to delete attachment or access denied", http.StatusInternalServerError)
		return
	}

	w.WriteJson(map[string]string{"status": "deleted"})
}

// processFileUpload validates one uploaded file and writes it to
// AttachmentDir/<orgId>/<transactionId>/<new id><ext>. It returns the
// attachment record to persist; it does not persist it.
//
// A file is accepted only if it is at most MaxFileSize bytes, its declared
// Content-Type (parameters ignored) is in AllowedMimeTypes, and its sniffed
// content agrees with that declared type. The returned record carries the
// normalized declared type and the number of bytes actually written.
func processFileUpload(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
	if fileHeader.Size > MaxFileSize {
		return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
	}

	// Normalize so parameters (e.g. "; charset=utf-8") or letter case do
	// not cause an otherwise allowed type to be rejected.
	declaredType := fileHeader.Header.Get("Content-Type")
	contentType, _, err := mime.ParseMediaType(declaredType)
	if err != nil || !AllowedMimeTypes[contentType] {
		return nil, fmt.Errorf("file type %s not allowed", declaredType)
	}

	file, err := fileHeader.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open uploaded file: %v", err)
	}
	defer file.Close()

	// Sniff the content rather than trusting the client-supplied header.
	// ReadFull tolerates files shorter than the window, including empty ones.
	head := make([]byte, attachmentSniffLen)
	n, err := io.ReadFull(file, head)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("failed to read file for content detection: %v", err)
	}
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to reset file pointer: %v", err)
	}
	if detectedType, ok := attachmentContentMatches(contentType, head[:n]); !ok {
		return nil, fmt.Errorf("detected file type %s does not match declared type %s", detectedType, declaredType)
	}

	// The client filename is kept only (sanitized) for display and for its
	// extension; the stored name is a freshly generated id plus that extension.
	originalName := sanitizeFilename(fileHeader.Filename)
	attachmentId := util.NewUUID()
	fileName := attachmentId + filepath.Ext(originalName)

	uploadDir := filepath.Join(AttachmentDir, orgId, transactionId)
	if err := os.MkdirAll(uploadDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create upload directory: %v", err)
	}
	filePath := filepath.Join(uploadDir, fileName)

	size, err := saveAttachmentFile(filePath, file)
	if err != nil {
		return nil, err
	}

	return &types.Attachment{
		Id:            attachmentId,
		TransactionId: transactionId,
		OrgId:         orgId,
		UserId:        userId,
		FileName:      fileName,
		OriginalName:  originalName,
		ContentType:   contentType,
		FileSize:      size,
		FilePath:      filePath,
		Description:   description,
		Uploaded:      time.Now(),
		Deleted:       false,
	}, nil
}

// attachmentContentMatches reports whether head, the leading bytes of a
// file, is consistent with declaredType (a normalized media type). It also
// returns the type reported by http.DetectContentType, without parameters,
// for use in error messages.
func attachmentContentMatches(declaredType string, head []byte) (string, bool) {
	detected := http.DetectContentType(head)
	// DetectContentType may append parameters, e.g. "text/plain; charset=utf-8".
	if i := strings.IndexByte(detected, ';'); i >= 0 {
		detected = strings.TrimSpace(detected[:i])
	}

	switch declaredType {
	case "text/csv":
		// There is no CSV signature; CSV content sniffs as plain text.
		return detected, detected == "text/plain"
	case mimeXLSX:
		// .xlsx workbooks are ZIP containers.
		return detected, detected == "application/zip"
	case mimeXLS:
		// Legacy .xls workbooks are OLE2 compound documents, which
		// DetectContentType does not identify; check the signature directly.
		return detected, strings.HasPrefix(string(head), ole2Signature)
	default:
		return detected, detected == declaredType
	}
}

// saveAttachmentFile copies src into a newly created file at path and
// returns the number of bytes written. If writing or closing the file
// fails, the partial file is removed.
func saveAttachmentFile(path string, src io.Reader) (int64, error) {
	dst, err := os.Create(path)
	if err != nil {
		return 0, fmt.Errorf("failed to create destination file: %v", err)
	}

	written, err := io.Copy(dst, src)
	// A failed Close can mean written data never reached the disk.
	if closeErr := dst.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(path)
		return 0, fmt.Errorf("failed to save file: %v", err)
	}
	return written, nil
}

// removeAttachmentFiles deletes the on-disk files of the given attachments.
// Removal errors are ignored: it only runs while another failure is being
// reported and there is no further recovery to attempt.
func removeAttachmentFiles(attachments []*types.Attachment) {
	for _, att := range attachments {
		_ = os.Remove(att.FilePath)
	}
}

// attachmentDisposition builds an "attachment" Content-Disposition value
// using the first non-empty candidate filename that mime.FormatMediaType can
// encode (it handles quoting and escaping), falling back to a bare
// "attachment" when none can be encoded.
func attachmentDisposition(candidates ...string) string {
	for _, name := range candidates {
		if name == "" {
			continue
		}
		if v := mime.FormatMediaType("attachment", map[string]string{"filename": name}); v != "" {
			return v
		}
	}
	return "attachment"
}

// sanitizeFilename removes path separators, NUL, CR, LF and ".." sequences
// from filename and limits it to maxFilenameLen bytes, keeping the extension
// when possible and never splitting a UTF-8 character. The result may be
// empty.
func sanitizeFilename(filename string) string {
	// Separators must go first: removing them can join dots ("./." -> "..").
	filename = unsafeFilenameChars.Replace(filename)
	filename = strings.ReplaceAll(filename, "..", "")

	if len(filename) > maxFilenameLen {
		ext := filepath.Ext(filename)
		if len(ext) > maxFilenameLen {
			// Pathological extension: truncate the whole name instead.
			ext = ""
		}
		base := filename[:len(filename)-len(ext)]
		filename = truncateFilenameUTF8(base, maxFilenameLen-len(ext)) + ext
	}
	return filename
}

// truncateFilenameUTF8 returns the longest prefix of s that is at most n
// bytes long and does not end in the middle of a multi-byte UTF-8 sequence.
func truncateFilenameUTF8(s string, n int) string {
	if len(s) <= n {
		return s
	}
	for n > 0 && !utf8.RuneStart(s[n]) {
		n--
	}
	return s[:n]
}