26 Commits

Author SHA1 Message Date
6558a09258 deps: update vendor dependencies for S3-compatible storage
Updates AWS SDK and removes Blazer B2 dependency in favor of unified
S3-compatible approach. Includes configuration examples and documentation.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 23:07:58 +12:00
f99a866e13 feat: implement unified S3-compatible storage system
Consolidates storage backends into a single S3-compatible driver that supports:
- AWS S3 (native)
- Backblaze B2 (S3-compatible API)
- Cloudflare R2 (S3-compatible API)
- MinIO and other S3-compatible services
- Local filesystem for development

This replaces the previous separate B2 driver with a unified approach,
reducing dependencies and complexity while adding support for more services.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 23:07:44 +12:00
e3152d9f40 fix: correct environment variables in justfile run-dev command
- Fix OA_DATABASE_DRIVER to OA_DATABASEDRIVER (no underscore)
- Fix OA_DATABASE_FILE to OA_DATABASEFILE (no underscore)
- Change database filename from test2.db to dev.db for better naming
- Ensure SQLite database is created in project root as expected

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:22:28 +12:00
e78098ad45 feat: update gitignore for attachment system
- Add .vscode/ to ignore IDE-specific files
- Add server to ignore build artifacts
- Add attachments/ to ignore uploaded attachment files
- Maintain clean repository without development artifacts

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:06:29 +12:00
7c43726abf fix: correct WebSocket message logging format
- Change format specifier from %s to %+v for struct logging
- Resolve compilation error in WebSocket message handling
- Maintain proper logging functionality for debugging

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:05:37 +12:00
b7ac4b0152 fix: add missing mock expectations in account tests
- Add GetSplitCountByAccountId mock expectations for CreateAccount tests
- Add GetSplitCountByAccountId mock expectations for UpdateAccount tests
- Resolve "unexpected method call" errors in account test suite
- Maintain existing test logic while fixing mock setup

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:05:21 +12:00
1b115fe0ff feat: add attachment methods to mock datastore
- Add InsertAttachment mock method for testing
- Add GetAttachment and GetAttachmentsByTransaction mock methods
- Add DeleteAttachment mock method for testing
- Maintain consistency with existing mock patterns

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:05:05 +12:00
a87df47231 feat: register attachment API routes
- Add 5 RESTful endpoints for transaction attachment management
- Include proper authentication middleware for all attachment operations
- Follow existing URL pattern: /orgs/:orgId/transactions/:transactionId/attachments
- Support nested resource access with proper authorization

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:50 +12:00
8b0a72c81f feat: implement attachment REST API endpoints
- Add POST /attachments for secure multi-file upload with validation
- Add GET /attachments for listing transaction attachments
- Add GET /attachments/:id for attachment metadata retrieval
- Add GET /attachments/:id/download for secure file download
- Add DELETE /attachments/:id for soft deletion
- Include comprehensive security validation: file type, size, content detection
- Implement proper error handling and cleanup on failures

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:32 +12:00
f64f83e66f feat: add attachment support to GORM model
- Implement AttachmentInterface methods in GormModel
- Add GetTransaction method for interface compliance
- Include placeholder implementation for future GORM repository development
- Maintain backward compatibility with existing GORM usage

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:15 +12:00
f5f0853040 feat: implement attachment business logic layer
- Add AttachmentInterface to main model interface
- Implement attachment CRUD operations with permission checking
- Add GetTransaction method for secure attachment access validation
- Add accountsContainReadAccess for permission verification
- Ensure users can only access attachments for authorized transactions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:01 +12:00
04653f2f02 feat: implement attachment database layer
- Add AttachmentInterface to main Datastore interface
- Implement CRUD operations for attachments following existing patterns
- Add proper SQL marshalling/unmarshalling with HEX/UNHEX for binary IDs
- Include soft deletion and proper indexing support

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:03:44 +12:00
3b89d8137e feat: add UUID utility functions for attachment system
- Add NewUUID() function for generating unique attachment identifiers
- Add IsValidUUID() function for validating UUID format in API requests
- Support 32-character hex string validation for secure file handling

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:03:30 +12:00
d10686e70f feat: add attachment model type definitions
- Add core attachment type with metadata fields for transaction files
- Add GORM model for attachment with proper relationships
- Include file information, upload timestamps, and soft deletion support

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:03:13 +12:00
c335c834ba feat: add database schema for transaction attachments
- Add attachment table with fields for file metadata and relationships
- Include indexes for optimal query performance on transactionId, orgId, userId, and uploaded fields
- Support for file storage with path tracking and soft deletion

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:02:58 +12:00
b5ea2095e4 fix: update authentication layer for GORM repository integration
- Update auth.go to use new repository interfaces
- Fix test compilation errors in auth_test.go
- Maintain compatibility with existing authentication flows
- Update mock implementations for repository pattern

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:09:36 +12:00
88c996a383 deps: update dependencies for GORM, Viper, and SQLite support
- Add GORM v1.25.12 with MySQL and SQLite drivers
- Add Viper v1.19.0 for configuration management
- Add UUID package for GORM model IDs
- Update vendor directory with new dependencies
- Update Go module requirements and checksums

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:09:22 +12:00
8c7088040d docs: comprehensive README update with Docker and Viper documentation
- Add complete Viper configuration documentation
- Include Docker deployment examples and best practices
- Document environment variable configuration options
- Add Just build automation usage examples
- Create troubleshooting and migration guides
- Update prerequisites and setup instructions
- Add security guidelines for production deployments

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:09:10 +12:00
77ab4b0e1d feat: add Just build automation with comprehensive recipes
- Create justfile with development and production recipes
- Add Docker build and run commands with proper configuration
- Include database management utilities and migration helpers
- Add development setup and dependency management
- Create configuration help and documentation commands
- Support both SQLite and MySQL deployment scenarios

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:09:00 +12:00
62dea0e53c feat: add Docker containerization with multi-stage build
- Create production-ready Dockerfile with multi-stage build
- Add CGO support for SQLite driver compilation
- Implement security best practices with non-root user
- Add health checks with proper API version headers
- Create .dockerignore for optimized build context
- Support both SQLite and MySQL in containerized environment
- Include volume mounting for data persistence

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:08:49 +12:00
d2ea9960bf feat: update config samples with SQLite and Viper support
- Add SQLite configuration options to config.json.sample
- Create config.mysql.json.sample for MySQL deployments
- Add security comments for sensitive configuration
- Include environment variable examples and documentation
- Add Viper configuration comments and usage examples

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:08:34 +12:00
f547d8d75b feat: integrate Viper for advanced configuration management
- Replace basic config loading with Viper framework
- Add support for multiple config sources (files, env vars, defaults)
- Add mapstructure tags for proper config binding
- Support JSON, YAML, and TOML config formats
- Add environment variable support with OA_ prefix
- Implement secure config loading with multiple search paths
- Maintain backward compatibility with existing config.json files

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:08:19 +12:00
0d1cb22044 refactor: update data access layer to use GORM repositories
- Replace SQL-based queries with GORM repository calls
- Update all model interfaces to use repository pattern
- Fix compilation errors in core/model/ files
- Update mocks to match new repository interfaces
- Modify API handlers to use new repository layer
- Maintain backward compatibility with existing interfaces

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:08:08 +12:00
bd3f101fb4 feat: add GORM integration with repository pattern
- Add GORM models in models/ directory with proper column tags
- Create repository interfaces and implementations in core/repository/
- Add database package with MySQL and SQLite support
- Add UUID ID utility for GORM models
- Implement complete repository layer replacing SQL-based data access
- Add database migrations and index creation
- Support both MySQL and SQLite drivers with auto-migration

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-30 22:07:51 +12:00
e865c4c1a2 fix: Add gorm and driver
Updated existing vendored dependencies
2025-06-09 22:56:57 +12:00
51deace1da refactor: Removed standalone migration apps. Causes build fail 🔥 2025-06-09 22:53:25 +12:00
1274 changed files with 777268 additions and 3004 deletions

39
.dockerignore Normal file
View File

@@ -0,0 +1,39 @@
# Git
.git
.gitignore
# Documentation
README.md
*.md
# Docker
Dockerfile
.dockerignore
# Build artifacts
server
*.exe
# Development files
.vscode/
.idea/
# Local config and data
config.json
*.db
data/
# Test files
*_test.go
test*
# Temporary files
*.tmp
*.log
# OS files
.DS_Store
Thumbs.db
# Dependencies (will be downloaded)
vendor/

38
.env.storage.example Normal file
View File

@@ -0,0 +1,38 @@
# OpenAccounting Storage Configuration
# Copy this file to .env and modify as needed
# Database Configuration
OA_DATABASEDRIVER=sqlite
OA_DATABASEFILE=./openaccounting.db
OA_ADDRESS=localhost
OA_PORT=8080
OA_APIPREFIX=/api/v1
# Storage Backend Configuration
# Options: local, s3, b2
OA_STORAGE_BACKEND=local
# Local Storage Configuration
OA_STORAGE_LOCAL_ROOTDIR=./uploads
OA_STORAGE_LOCAL_BASEURL=
# Amazon S3 Storage Configuration (uncomment if using S3)
# OA_STORAGE_S3_REGION=us-east-1
# OA_STORAGE_S3_BUCKET=my-openaccounting-attachments
# OA_STORAGE_S3_PREFIX=attachments
# OA_STORAGE_S3_ACCESSKEYID=AKIA...
# OA_STORAGE_S3_SECRETACCESSKEY=...
# OA_STORAGE_S3_ENDPOINT=
# OA_STORAGE_S3_PATHSTYLE=false
# Backblaze B2 Storage Configuration (uncomment if using B2)
# OA_STORAGE_B2_ACCOUNTID=your-b2-account-id
# OA_STORAGE_B2_APPLICATIONKEY=your-b2-application-key
# OA_STORAGE_B2_BUCKET=my-openaccounting-attachments
# OA_STORAGE_B2_PREFIX=attachments
# Email Configuration (optional)
# OA_MAILGUNDOMAIN=
# OA_MAILGUNKEY=
# OA_MAILGUNEMAIL=
# OA_MAILGUNSENDER=

3
.gitignore vendored
View File

@@ -97,3 +97,6 @@ config.json
*.csr
*.sublime-project
*.sublime-workspace
.vscode/
server
attachments/

64
Dockerfile Normal file
View File

@@ -0,0 +1,64 @@
# Build stage
FROM golang:1.24-alpine AS builder
# Install build dependencies for CGO (needed for SQLite)
RUN apk add --no-cache git gcc musl-dev
# Set working directory
WORKDIR /app
# Copy go mod files
COPY go.mod go.sum ./
# Download dependencies
RUN go mod download
# Copy source code
COPY . .
# Build the application
RUN CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -o server ./core/
# Final stage
FROM alpine:latest
# Install ca-certificates for HTTPS and sqlite for database
RUN apk --no-cache add ca-certificates sqlite
# Create app user for security
RUN adduser -D -s /bin/sh appuser
# Set working directory
WORKDIR /app
# Copy binary from builder stage
COPY --from=builder /app/server .
# Create data directory for SQLite
RUN mkdir -p /app/data && chown appuser:appuser /app/data
# Copy config sample (optional)
COPY config.json.sample .
# Change ownership to app user
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Expose port (default 8080, can be overridden with OA_PORT)
EXPOSE 8080
# Health check - requires Accept-Version header
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider --header="Accept-Version: v1" http://localhost:8080/ || exit 1
# Set default environment variables
ENV OA_DATABASE_DRIVER=sqlite \
OA_DATABASE_FILE=/app/data/openaccounting.db \
OA_ADDRESS=0.0.0.0 \
OA_PORT=8080 \
OA_API_PREFIX=/api/v1
# Run the application
CMD ["./server"]

485
README.md
View File

@@ -1,30 +1,491 @@
# Open Accounting Server
Open Accounting Server is a modern financial accounting system built with Go, featuring GORM integration, Viper configuration management, and Docker support.
## Features
- **GORM Integration**: Modern ORM with SQLite and MySQL support
- **Viper Configuration**: Flexible config management with environment variables
- **Modular Storage**: S3-compatible attachment storage (Local, AWS S3, Backblaze B2, Cloudflare R2, MinIO)
- **Docker Ready**: Containerized deployment with multi-stage builds
- **SQLite Support**: Easy local development and testing
- **Security**: Environment variable support for sensitive data
## Prerequisites
1. Go 1.8+
2. MySQL 5.7+
- **Go 1.24+** (updated from 1.8+)
- **SQLite** (for development) or **MySQL 5.7+** (for production)
- **Docker** (optional, for containerized deployment)
- **Just** (optional, for build automation)
## Database setup
## Quick Start
Use schema.sql and indexes.sql to create a MySQL database to store Open Accounting data.
### Using Just (Recommended)
```bash
# Setup development environment
just dev-setup
# Run in development mode
just run-dev
# Build and run with Docker
just docker-run
```
### Manual Setup
```bash
# Install dependencies
go mod download
# Run with SQLite (development)
OA_DATABASE_DRIVER=sqlite ./server
# Run with MySQL (production)
OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret ./server
```
## Configuration
Copy config.json.sample to config.json and edit to match your information.
The server now uses **Viper** for advanced configuration management with multiple sources:
## Run
### Configuration Sources (in order of precedence)
`go run core/server.go`
1. **Environment Variables** (highest priority)
2. **Config Files**: `config.json`, `config.yaml`, `config.toml`
3. **Default Values** (lowest priority)
## Build
### Config File Locations
`go build core/server.go`
- `./config.json` (current directory)
- `/etc/openaccounting/config.json`
- `~/.openaccounting/config.json`
### Environment Variables
All configuration can be overridden with environment variables using the `OA_` prefix:
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_ADDRESS` | Address | `localhost` | Server bind address |
| `OA_PORT` | Port | `8080` | Server port |
| `OA_API_PREFIX` | ApiPrefix | `/api/v1` | API route prefix |
| `OA_DATABASE_DRIVER` | DatabaseDriver | `sqlite` | Database type: `sqlite` or `mysql` |
| `OA_DATABASE_FILE` | DatabaseFile | `./openaccounting.db` | SQLite database file |
| `OA_DATABASE_ADDRESS` | DatabaseAddress | `localhost:3306` | MySQL server address |
| `OA_DATABASE` | Database | | MySQL database name |
| `OA_USER` | User | | Database username |
| `OA_PASSWORD` | Password | | Database password ⚠️ |
| `OA_MAILGUN_DOMAIN` | MailgunDomain | | Mailgun domain |
| `OA_MAILGUN_KEY` | MailgunKey | | Mailgun API key ⚠️ |
| `OA_MAILGUN_EMAIL` | MailgunEmail | | Mailgun email |
| `OA_MAILGUN_SENDER` | MailgunSender | | Mailgun sender name |
#### Storage Configuration
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_STORAGE_BACKEND` | Storage.Backend | `local` | Storage backend: `local` or `s3` |
**Local Storage**
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_STORAGE_LOCAL_ROOTDIR` | Storage.Local.RootDir | `./uploads` | Root directory for file storage |
| `OA_STORAGE_LOCAL_BASEURL` | Storage.Local.BaseURL | | Base URL for serving files |
**S3-Compatible Storage** (AWS S3, Backblaze B2, Cloudflare R2, MinIO)
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_STORAGE_S3_REGION` | Storage.S3.Region | | Region (use "auto" for Cloudflare R2) |
| `OA_STORAGE_S3_BUCKET` | Storage.S3.Bucket | | Bucket name |
| `OA_STORAGE_S3_PREFIX` | Storage.S3.Prefix | | Optional prefix for all objects |
| `OA_STORAGE_S3_ACCESSKEYID` | Storage.S3.AccessKeyID | | Access Key ID ⚠️ |
| `OA_STORAGE_S3_SECRETACCESSKEY` | Storage.S3.SecretAccessKey | | Secret Access Key ⚠️ |
| `OA_STORAGE_S3_ENDPOINT` | Storage.S3.Endpoint | | Custom endpoint (see examples below) |
| `OA_STORAGE_S3_PATHSTYLE` | Storage.S3.PathStyle | `false` | Use path-style addressing |
**S3-Compatible Service Endpoints:**
- **AWS S3**: Leave endpoint empty, set appropriate region
- **Backblaze B2**: `https://s3.us-west-004.backblazeb2.com` (replace region as needed)
- **Cloudflare R2**: `https://<account-id>.r2.cloudflarestorage.com`
- **MinIO**: `http://localhost:9000` (or your MinIO server URL)
⚠️ **Security**: Always use environment variables for sensitive data like passwords and API keys.
### Configuration Examples
#### Development (SQLite)
```bash
# Minimal - uses defaults
./server
# Custom database file and port
OA_DATABASE_FILE=./dev.db OA_PORT=9090 ./server
```
#### Production (MySQL)
```bash
# With environment variables (recommended)
export OA_DATABASE_DRIVER=mysql
export OA_DATABASE_ADDRESS=db.example.com:3306
export OA_DATABASE=openaccounting_prod
export OA_USER=openaccounting
export OA_PASSWORD=secure_password
export OA_MAILGUN_KEY=key-abc123
./server
# Or inline
OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret OA_MAILGUN_KEY=key-123 ./server
```
#### Storage Configuration Examples
```bash
# Local storage (default)
export OA_STORAGE_BACKEND=local
export OA_STORAGE_LOCAL_ROOTDIR=./uploads
./server
# AWS S3
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=us-west-2
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=your-access-key
export OA_STORAGE_S3_SECRETACCESSKEY=your-secret-key
./server
# Backblaze B2 (S3-compatible)
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=us-west-004
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=your-b2-key-id
export OA_STORAGE_S3_SECRETACCESSKEY=your-b2-application-key
export OA_STORAGE_S3_ENDPOINT=https://s3.us-west-004.backblazeb2.com
export OA_STORAGE_S3_PATHSTYLE=true
./server
# Cloudflare R2
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=auto
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=your-r2-access-key
export OA_STORAGE_S3_SECRETACCESSKEY=your-r2-secret-key
export OA_STORAGE_S3_ENDPOINT=https://your-account-id.r2.cloudflarestorage.com
./server
# MinIO (self-hosted)
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=us-east-1
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=minioadmin
export OA_STORAGE_S3_SECRETACCESSKEY=minioadmin
export OA_STORAGE_S3_ENDPOINT=http://localhost:9000
export OA_STORAGE_S3_PATHSTYLE=true
./server
```
#### Docker
```bash
# SQLite with volume mount
docker run -p 8080:8080 \
-e OA_DATABASE_DRIVER=sqlite \
-v ./data:/app/data \
openaccounting-server:latest
# MySQL with environment variables
docker run -p 8080:8080 \
-e OA_DATABASE_DRIVER=mysql \
-e OA_DATABASE_ADDRESS=mysql:3306 \
-e OA_PASSWORD=secret \
openaccounting-server:latest
# With AWS S3 storage
docker run -p 8080:8080 \
-e OA_STORAGE_BACKEND=s3 \
-e OA_STORAGE_S3_REGION=us-west-2 \
-e OA_STORAGE_S3_BUCKET=my-attachments \
-e OA_STORAGE_S3_ACCESSKEYID=your-key \
-e OA_STORAGE_S3_SECRETACCESSKEY=your-secret \
openaccounting-server:latest
# With Cloudflare R2 storage
docker run -p 8080:8080 \
-e OA_STORAGE_BACKEND=s3 \
-e OA_STORAGE_S3_REGION=auto \
-e OA_STORAGE_S3_BUCKET=my-attachments \
-e OA_STORAGE_S3_ACCESSKEYID=your-r2-key \
-e OA_STORAGE_S3_SECRETACCESSKEY=your-r2-secret \
-e OA_STORAGE_S3_ENDPOINT=https://account-id.r2.cloudflarestorage.com \
openaccounting-server:latest
```
## Database Setup
### SQLite (Development)
SQLite databases are created automatically. No manual setup required.
```bash
# Uses ./openaccounting.db by default
OA_DATABASE_DRIVER=sqlite ./server
# Custom location
OA_DATABASE_DRIVER=sqlite OA_DATABASE_FILE=./data/myapp.db ./server
```
### MySQL (Production)
Use the provided schema files to create your MySQL database:
```sql
-- Create database and user
CREATE DATABASE openaccounting;
CREATE USER 'openaccounting'@'%' IDENTIFIED BY 'secure_password';
GRANT ALL PRIVILEGES ON openaccounting.* TO 'openaccounting'@'%';
```
The server will automatically create tables and run migrations on startup.
## Building
### Local Build
```bash
# Development build
go build -o server ./core/
# Production build (optimized)
CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -ldflags="-w -s" -o server ./core/
```
### Docker Build
```bash
# Build image
docker build -t openaccounting-server:latest .
# Multi-platform build
docker buildx build --platform linux/amd64,linux/arm64 -t openaccounting-server:latest .
```
## Running
### Development
```bash
# Local with SQLite
just run-dev
# Or manually
OA_DATABASE_DRIVER=sqlite OA_PORT=8080 ./server
```
### Production
```bash
# With Docker Compose (recommended)
docker-compose up -d
# Or manually with environment file
export $(cat .env | xargs)
./server
```
## Just Recipes
This project includes a `justfile` with common tasks:
```bash
just --list # Show all available recipes
just build # Build the application
just run-dev # Run in development mode
just docker-build # Build Docker image
just docker-run # Run container
just test # Run tests
just config-help # Show configuration help
just dev-setup # Complete development setup
```
## API
The server provides a REST API at `/api/v1/` (configurable via `OA_API_PREFIX`).
### Health Check
```bash
curl http://localhost:8080/api/v1/health
```
## Development
### Prerequisites
```bash
# Install Go dependencies
go mod download
# Install development tools (optional)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
```
### Running Tests
```bash
just test
# or
go test ./...
```
### Code Quality
```bash
# Format code
just fmt
# Lint code (requires golangci-lint)
just lint
```
## Docker
If you are interested in running Open Accounting via Docker, @alokmenghrajani has created a [repo](https://github.com/alokmenghrajani/openaccounting-docker) for this.
### Official Images
## Help
Docker images are available with multi-stage builds for optimal size and security:
[Join our Slack chatroom](https://join.slack.com/t/openaccounting/shared_invite/zt-23zy988e8-93HP1GfLDB7osoQ6umpfiA) and talk with us!
- Non-root user for security
- Alpine Linux base for minimal attack surface
- Health checks included
- Volume support for data persistence
### Environment Variables in Docker
```dockerfile
ENV OA_DATABASE_DRIVER=sqlite \
OA_DATABASE_FILE=/app/data/openaccounting.db \
OA_ADDRESS=0.0.0.0 \
OA_PORT=8080
```
### Data Persistence
```bash
# Mount volume for SQLite data
docker run -v ./data:/app/data openaccounting-server:latest
# Use named volume
docker volume create openaccounting-data
docker run -v openaccounting-data:/app/data openaccounting-server:latest
```
## Deployment
### Docker Compose
```yaml
version: '3.8'
services:
openaccounting:
image: openaccounting-server:latest
ports:
- "8080:8080"
environment:
OA_DATABASE_DRIVER: mysql
OA_DATABASE_ADDRESS: mysql:3306
OA_DATABASE: openaccounting
OA_USER: openaccounting
OA_PASSWORD: ${DB_PASSWORD}
depends_on:
- mysql
mysql:
image: mysql:8.0
environment:
MYSQL_DATABASE: openaccounting
MYSQL_USER: openaccounting
MYSQL_PASSWORD: ${DB_PASSWORD}
MYSQL_ROOT_PASSWORD: ${DB_ROOT_PASSWORD}
volumes:
- mysql_data:/var/lib/mysql
volumes:
mysql_data:
```
### Kubernetes
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: openaccounting-server
spec:
replicas: 3
selector:
matchLabels:
app: openaccounting-server
template:
metadata:
labels:
app: openaccounting-server
spec:
containers:
- name: openaccounting-server
image: openaccounting-server:latest
ports:
- containerPort: 8080
env:
- name: OA_DATABASE_DRIVER
value: "mysql"
- name: OA_PASSWORD
valueFrom:
secretKeyRef:
name: openaccounting-secrets
key: db-password
```
## Troubleshooting
### Common Issues
1. **Config file not found**: The server will use environment variables and defaults if no config file is found
2. **Database connection failed**: Check your database credentials and connectivity
3. **Permission denied**: Ensure proper file permissions for SQLite database files
### Debug Mode
```bash
# Enable verbose logging
OA_LOG_LEVEL=debug ./server
# Check configuration
just config-help
```
### Health Checks
```bash
# Application health
curl http://localhost:8080/api/v1/health
# Docker health check
docker inspect --format='{{.State.Health.Status}}' container_name
```
## Migration from Legacy Setup
The server maintains backward compatibility with existing `config.json` files while adding Viper features:
1. Existing `config.json` files continue to work
2. Add environment variables for sensitive data
3. Use SQLite for easier local development
4. Leverage Docker for production deployments
## Help & Support
- **Documentation**: This README and inline code comments
- **Issues**: GitHub Issues for bug reports and feature requests
- **Community**: [Join our Slack chatroom](https://join.slack.com/t/openaccounting/shared_invite/zt-23zy988e8-93HP1GfLDB7osoQ6umpfiA)
## License
See LICENSE file for details.

239
STORAGE.md Normal file
View File

@@ -0,0 +1,239 @@
# Modular Storage System
The OpenAccounting server now supports multiple storage backends for file attachments. This allows you to choose between local filesystem storage for simple deployments or cloud storage for production/multi-user environments.
## Supported Storage Backends
### 1. Local Filesystem Storage
Perfect for self-hosted deployments or development environments.
**Configuration:**
```json
{
"storage": {
"backend": "local",
"local": {
"root_dir": "./uploads",
"base_url": "https://yourapp.com/files"
}
}
}
```
**Environment Variables:**
```bash
OA_STORAGE_BACKEND=local
OA_STORAGE_LOCAL_ROOT_DIR=./uploads
OA_STORAGE_LOCAL_BASE_URL=https://yourapp.com/files
```
### 2. Amazon S3 Storage
Reliable cloud storage for production deployments.
**Configuration:**
```json
{
"storage": {
"backend": "s3",
"s3": {
"region": "us-east-1",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments",
"access_key_id": "AKIA...",
"secret_access_key": "...",
"endpoint": "",
"path_style": false
}
}
}
```
**Environment Variables:**
```bash
OA_STORAGE_BACKEND=s3
OA_STORAGE_S3_REGION=us-east-1
OA_STORAGE_S3_BUCKET=my-openaccounting-attachments
OA_STORAGE_S3_PREFIX=attachments
OA_STORAGE_S3_ACCESS_KEY_ID=AKIA...
OA_STORAGE_S3_SECRET_ACCESS_KEY=...
```
**Features:**
- Automatic presigned URL generation
- Configurable expiry times
- Support for S3-compatible services (MinIO, DigitalOcean Spaces)
- IAM role support (leave credentials empty to use IAM)
### 3. Backblaze B2 Storage
Cost-effective cloud storage alternative to S3.
**Configuration:**
```json
{
"storage": {
"backend": "b2",
"b2": {
"account_id": "your-b2-account-id",
"application_key": "your-b2-application-key",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments"
}
}
}
```
**Environment Variables:**
```bash
OA_STORAGE_BACKEND=b2
OA_STORAGE_B2_ACCOUNT_ID=your-b2-account-id
OA_STORAGE_B2_APPLICATION_KEY=your-b2-application-key
OA_STORAGE_B2_BUCKET=my-openaccounting-attachments
OA_STORAGE_B2_PREFIX=attachments
```
## API Endpoints
The storage system provides both legacy and new endpoints:
### New Storage-Agnostic Endpoints
**Upload Attachment:**
```
POST /api/v1/attachments
Content-Type: multipart/form-data
transactionId: uuid
description: string (optional)
file: binary data
```
**Get Attachment Metadata:**
```
GET /api/v1/attachments/{id}
```
**Get Download URL:**
```
GET /api/v1/attachments/{id}/url
```
**Download File:**
```
GET /api/v1/attachments/{id}?download=true
```
**Delete Attachment:**
```
DELETE /api/v1/attachments/{id}
```
### Legacy Endpoints (Still Supported)
The original transaction-scoped endpoints remain available for backward compatibility:
- `GET/POST /api/v1/orgs/{orgId}/transactions/{transactionId}/attachments`
## Security Features
- **File type validation** - Only allowed MIME types are accepted
- **File size limits** - Configurable maximum file size (default 10MB)
- **Path traversal protection** - Prevents directory traversal attacks
- **Access control** - Files are linked to users and organizations
- **Presigned URLs** - Time-limited access for cloud storage
## File Organization
Files are automatically organized by date:
```
uploads/
├── 2025/
│ ├── 01/
│ │ ├── 15/
│ │ │ ├── uuid1.pdf
│ │ │ └── uuid2.png
│ │ └── 16/
│ │ └── uuid3.jpg
```
## Configuration Examples
### Development (Local Storage)
```json
{
"storage": {
"backend": "local",
"local": {
"root_dir": "./dev-uploads"
}
}
}
```
### Production (S3 with IAM)
```json
{
"storage": {
"backend": "s3",
"s3": {
"region": "us-west-2",
"bucket": "prod-openaccounting-files",
"prefix": "attachments"
}
}
}
```
### Cost-Optimized (Backblaze B2)
```json
{
"storage": {
"backend": "b2",
"b2": {
"account_id": "${B2_ACCOUNT_ID}",
"application_key": "${B2_APP_KEY}",
"bucket": "openaccounting-prod"
}
}
}
```
## Migration Between Storage Backends
When changing storage backends, existing attachments will remain in the old storage location. The database records contain the storage path, so files can be accessed until migrated.
To migrate:
1. Update configuration to new backend
2. Restart server
3. New uploads will use the new backend
4. Optional: Run migration script to move existing files
## Environment-Specific Considerations
### Self-Hosted
- Use local storage for simplicity
- Ensure backup strategy includes upload directory
- Consider disk space management
### Cloud Deployment
- Use S3 or B2 for reliability and scalability
- Configure proper IAM policies
- Enable versioning and lifecycle policies
### Multi-Region
- Use cloud storage with appropriate region selection
- Consider CDN integration for better performance
## Troubleshooting
**Storage backend not initialized:**
- Check configuration syntax
- Verify credentials for cloud backends
- Ensure storage directories/buckets exist
**Permission denied:**
- Check file system permissions for local storage
- Verify IAM policies for S3
- Confirm B2 application key permissions
**Large file uploads failing:**
- Check `MaxFileSize` configuration
- Verify network timeouts
- Consider multipart upload for large files

17
config.b2.json.sample Normal file
View File

@@ -0,0 +1,17 @@
{
"weburl": "https://yourapp.com",
"address": "localhost",
"port": 8080,
"apiprefix": "/api/v1",
"databasedriver": "sqlite",
"databasefile": "./openaccounting.db",
"storage": {
"backend": "b2",
"b2": {
"account_id": "your-b2-account-id",
"application_key": "your-b2-application-key",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments"
}
}
}

View File

@@ -1,16 +1,31 @@
{
"_comment_config": "OpenAccounting Server Configuration - now supports Viper for multiple config sources",
"_comment_viper": "You can override any setting with environment variables using OA_ prefix (e.g., OA_PASSWORD, OA_MAILGUN_KEY)",
"WebUrl": "https://domain.com",
"Address": "",
"Address": "localhost",
"Port": 8080,
"ApiPrefix": "",
"ApiPrefix": "/api/v1",
"KeyFile": "",
"CertFile": "",
"DatabaseAddress": "",
"_comment_database": "Database configuration - choose 'sqlite' for local testing or 'mysql' for production",
"DatabaseDriver": "sqlite",
"DatabaseFile": "./data/openaccounting.db",
"DatabaseAddress": "localhost:3306",
"Database": "openaccounting",
"User": "openaccounting",
"Password": "openaccounting",
"Password": "",
"_comment_password": "SECURITY: Set password via OA_PASSWORD environment variable instead of this file",
"_comment_mailgun": "Mailgun configuration for email sending",
"MailgunDomain": "mg.domain.com",
"MailgunKey": "",
"_comment_mailgun_key": "SECURITY: Set Mailgun key via OA_MAILGUN_KEY environment variable",
"MailgunEmail": "noreply@domain.com",
"MailgunSender": "Sender"
"MailgunSender": "Sender",
"_comment_env_examples": "Environment variable examples:",
"_example_development": "OA_DATABASE_DRIVER=sqlite OA_DATABASE_FILE=./dev.db ./server",
"_example_production": "OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret OA_MAILGUN_KEY=key-123 ./server"
}

19
config.mysql.json.sample Normal file
View File

@@ -0,0 +1,19 @@
{
"WebUrl": "https://domain.com",
"Address": "",
"Port": 8080,
"ApiPrefix": "",
"KeyFile": "",
"CertFile": "",
"_comment_database": "MySQL configuration for production use",
"DatabaseDriver": "mysql",
"DatabaseAddress": "localhost:3306",
"Database": "openaccounting",
"User": "openaccounting",
"Password": "openaccounting",
"_comment_mailgun": "Mailgun configuration for email sending",
"MailgunDomain": "mg.domain.com",
"MailgunKey": "",
"MailgunEmail": "noreply@domain.com",
"MailgunSender": "Sender"
}

20
config.s3.json.sample Normal file
View File

@@ -0,0 +1,20 @@
{
"weburl": "https://yourapp.com",
"address": "localhost",
"port": 8080,
"apiprefix": "/api/v1",
"databasedriver": "sqlite",
"databasefile": "./openaccounting.db",
"storage": {
"backend": "s3",
"s3": {
"region": "us-east-1",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments",
"access_key_id": "",
"secret_access_key": "",
"endpoint": "",
"path_style": false
}
}
}

View File

@@ -0,0 +1,15 @@
{
"weburl": "https://yourapp.com",
"address": "localhost",
"port": 8080,
"apiprefix": "/api/v1",
"databasedriver": "sqlite",
"databasefile": "./openaccounting.db",
"storage": {
"backend": "local",
"local": {
"root_dir": "./uploads",
"base_url": "https://yourapp.com/files"
}
}
}

View File

@@ -2,7 +2,7 @@ package api
import (
"encoding/json"
"io/ioutil"
"io"
"net/http"
"strconv"
"time"
@@ -12,48 +12,7 @@ import (
"github.com/openaccounting/oa-server/core/model/types"
)
/**
* @api {get} /orgs/:orgId/accounts Get Accounts by Org id
* @apiVersion 1.4.0
* @apiName GetOrgAccounts
* @apiGroup Account
*
* @apiHeader {String} Authorization HTTP Basic Auth
* @apiHeader {String} Accept-Version ^1.4.0 semver versioning
*
* @apiSuccess {String} id Id of the Account.
* @apiSuccess {String} orgId Id of the Org.
* @apiSuccess {Date} inserted Date Account was created
* @apiSuccess {Date} updated Date Account was updated
* @apiSuccess {String} name Name of the Account.
* @apiSuccess {String} parent Id of the parent Account.
* @apiSuccess {String} currency Three letter currency code.
* @apiSuccess {Number} precision How many digits the currency goes out to.
* @apiSuccess {Boolean} debitBalance True if Account has a debit balance.
* @apiSuccess {Number} balance Current Account balance in this Account's currency
* @apiSuccess {Number} nativeBalance Current Account balance in the Org's currency
*
* @apiSuccessExample Success-Response:
* HTTP/1.1 200 OK
* [
* {
* "id": "22222222222222222222222222222222",
* "orgId": "11111111111111111111111111111111",
* "inserted": "2018-09-11T18:05:04.420Z",
* "updated": "2018-09-11T18:05:04.420Z",
* "name": "Cash",
* "parent": "11111111111111111111111111111111",
* "currency": "USD",
* "precision": 2,
* "debitBalance": true,
* "balance": 10000,
* "nativeBalance": 10000
* }
* ]
*
* @apiUse NotAuthorizedError
* @apiUse InternalServerError
*/
// GetOrgAccounts /**
func GetOrgAccounts(w rest.ResponseWriter, r *rest.Request) {
user := r.Env["USER"].(*types.User)
orgId := r.PathParam("orgId")
@@ -208,7 +167,7 @@ func PostAccount(w rest.ResponseWriter, r *rest.Request) {
user := r.Env["USER"].(*types.User)
orgId := r.PathParam("orgId")
content, err := ioutil.ReadAll(r.Body)
content, err := io.ReadAll(r.Body)
r.Body.Close()
if err != nil {

313
core/api/attachment.go Normal file
View File

@@ -0,0 +1,313 @@
package api
import (
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/ant0ine/go-json-rest/rest"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
)
const (
MaxFileSize = 10 * 1024 * 1024 // 10MB
MaxFilesPerTx = 10
AttachmentDir = "attachments"
)
var AllowedMimeTypes = map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"application/pdf": true,
"text/plain": true,
"text/csv": true,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true, // .xlsx
"application/vnd.ms-excel": true, // .xls
}
func PostAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Parse multipart form
err := r.ParseMultipartForm(MaxFileSize)
if err != nil {
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
return
}
files := r.MultipartForm.File["files"]
if len(files) == 0 {
rest.Error(w, "No files provided", http.StatusBadRequest)
return
}
if len(files) > MaxFilesPerTx {
rest.Error(w, fmt.Sprintf("Too many files. Maximum %d files allowed", MaxFilesPerTx), http.StatusBadRequest)
return
}
// Verify transaction exists and user has permission
tx, err := model.Instance.GetTransaction(transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Transaction not found or access denied", http.StatusNotFound)
return
}
if tx == nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
var attachments []*types.Attachment
var description string
if desc := r.FormValue("description"); desc != "" {
description = desc
}
for _, fileHeader := range files {
attachment, err := processFileUpload(fileHeader, transactionId, orgId, user.Id, description)
if err != nil {
// Clean up any successfully uploaded files
for _, att := range attachments {
os.Remove(att.FilePath)
}
rest.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Save attachment to database
createdAttachment, err := model.Instance.CreateAttachment(attachment)
if err != nil {
// Clean up file and any previously uploaded files
os.Remove(attachment.FilePath)
for _, att := range attachments {
os.Remove(att.FilePath)
}
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
return
}
attachments = append(attachments, createdAttachment)
}
w.WriteJson(map[string]interface{}{
"attachments": attachments,
"count": len(attachments),
})
}
func GetAttachments(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Failed to retrieve attachments", http.StatusInternalServerError)
return
}
w.WriteJson(attachments)
}
func GetAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
attachmentId := r.PathParam("attachmentId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
return
}
w.WriteJson(attachment)
}
func DownloadAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
attachmentId := r.PathParam("attachmentId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
return
}
// Check if file exists
if _, err := os.Stat(attachment.FilePath); os.IsNotExist(err) {
rest.Error(w, "File not found on disk", http.StatusNotFound)
return
}
// Set headers for file download
w.Header().Set("Content-Type", attachment.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
// Open and serve file
file, err := os.Open(attachment.FilePath)
if err != nil {
rest.Error(w, "Failed to open file", http.StatusInternalServerError)
return
}
defer file.Close()
io.Copy(w.(http.ResponseWriter), file)
}
func DeleteAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
attachmentId := r.PathParam("attachmentId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
err := model.Instance.DeleteAttachment(attachmentId, transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Failed to delete attachment or access denied", http.StatusInternalServerError)
return
}
w.WriteJson(map[string]string{"status": "deleted"})
}
func processFileUpload(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
// Validate file size
if fileHeader.Size > MaxFileSize {
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
}
// Open uploaded file
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %v", err)
}
defer file.Close()
// Validate file type from header
contentType := fileHeader.Header.Get("Content-Type")
if !AllowedMimeTypes[contentType] {
return nil, fmt.Errorf("file type %s not allowed", contentType)
}
// Validate file type by detecting content (more secure)
buffer := make([]byte, 512)
n, err := file.Read(buffer)
if err != nil {
return nil, fmt.Errorf("failed to read file for content detection: %v", err)
}
// Reset file pointer to beginning
if _, err := file.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to reset file pointer: %v", err)
}
detectedType := http.DetectContentType(buffer[:n])
if !AllowedMimeTypes[detectedType] {
return nil, fmt.Errorf("detected file type %s not allowed (header claimed %s)", detectedType, contentType)
}
// Generate unique filename
attachmentId := util.NewUUID()
ext := filepath.Ext(fileHeader.Filename)
fileName := attachmentId + ext
// Create attachments directory if it doesn't exist
uploadDir := filepath.Join(AttachmentDir, orgId, transactionId)
if err := os.MkdirAll(uploadDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create upload directory: %v", err)
}
// Create file path
filePath := filepath.Join(uploadDir, fileName)
// Create destination file
dst, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to create destination file: %v", err)
}
defer dst.Close()
// Copy file contents
if _, err := io.Copy(dst, file); err != nil {
return nil, fmt.Errorf("failed to save file: %v", err)
}
// Create attachment object
attachment := &types.Attachment{
Id: attachmentId,
TransactionId: transactionId,
OrgId: orgId,
UserId: userId,
FileName: fileName,
OriginalName: fileHeader.Filename,
ContentType: contentType,
FileSize: fileHeader.Size,
FilePath: filePath,
Description: description,
Uploaded: time.Now(),
Deleted: false,
}
return attachment, nil
}
func sanitizeFilename(filename string) string {
// Remove potentially dangerous characters
filename = strings.ReplaceAll(filename, "..", "")
filename = strings.ReplaceAll(filename, "/", "")
filename = strings.ReplaceAll(filename, "\\", "")
filename = strings.ReplaceAll(filename, "\x00", "") // null bytes
filename = strings.ReplaceAll(filename, "\r", "") // carriage return
filename = strings.ReplaceAll(filename, "\n", "") // newline
// Limit filename length
if len(filename) > 255 {
ext := filepath.Ext(filename)
base := filename[:255-len(ext)]
filename = base + ext
}
return filename
}

View File

@@ -0,0 +1,306 @@
package api
import (
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/core/util/id"
"github.com/openaccounting/oa-server/database"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupTestDatabase(t *testing.T) (*gorm.DB, func()) {
// Create temporary database file
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
// Set global DB for database package
database.DB = db
// Run migrations
err = database.AutoMigrate()
if err != nil {
t.Fatalf("Failed to run auto migrations: %v", err)
}
err = database.Migrate()
if err != nil {
t.Fatalf("Failed to run custom migrations: %v", err)
}
// Cleanup function
cleanup := func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
os.RemoveAll(tmpDir)
}
return db, cleanup
}
func setupTestData(t *testing.T, db *gorm.DB) (orgID, userID, transactionID string) {
// Use hardcoded UUIDs without dashes for hex format
orgID = "550e8400e29b41d4a716446655440000"
userID = "550e8400e29b41d4a716446655440001"
transactionID = "550e8400e29b41d4a716446655440002"
accountID := "550e8400e29b41d4a716446655440003"
// Insert test data using raw SQL for reliability
now := time.Now()
// Insert org
err := db.Exec("INSERT INTO orgs (id, inserted, updated, name, currency, `precision`, timezone) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?)",
orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Org", "USD", 2, "UTC").Error
if err != nil {
t.Fatalf("Failed to insert org: %v", err)
}
// Insert user
err = db.Exec("INSERT INTO users (id, inserted, updated, firstName, lastName, email, passwordHash, agreeToTerms, passwordReset, emailVerified, emailVerifyCode, signupSource) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test", "User", "test@example.com", "hashedpassword", true, "", true, "", "test").Error
if err != nil {
t.Fatalf("Failed to insert user: %v", err)
}
// Insert user-org relationship
err = db.Exec("INSERT INTO user_orgs (userId, orgId, admin) VALUES (UNHEX(?), UNHEX(?), ?)",
userID, orgID, false).Error
if err != nil {
t.Fatalf("Failed to insert user-org: %v", err)
}
// Insert account
err = db.Exec("INSERT INTO accounts (id, orgId, inserted, updated, name, parent, currency, `precision`, debitBalance) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?, ?)",
accountID, orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Account", []byte{}, "USD", 2, true).Error
if err != nil {
t.Fatalf("Failed to insert account: %v", err)
}
// Insert transaction
err = db.Exec("INSERT INTO transactions (id, orgId, userId, inserted, updated, date, description, data, deleted) VALUES (UNHEX(?), UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
transactionID, orgID, userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Transaction", "", false).Error
if err != nil {
t.Fatalf("Failed to insert transaction: %v", err)
}
// Insert split
err = db.Exec("INSERT INTO splits (transactionId, accountId, date, inserted, updated, amount, nativeAmount, deleted) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
transactionID, accountID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), 100, 100, false).Error
if err != nil {
t.Fatalf("Failed to insert split: %v", err)
}
return orgID, userID, transactionID
}
func createTestFile(t *testing.T) (string, []byte) {
content := []byte("This is a test file content for attachment testing")
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.txt")
err := os.WriteFile(filePath, content, 0644)
if err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
return filePath, content
}
func TestAttachmentIntegration(t *testing.T) {
db, cleanup := setupTestDatabase(t)
defer cleanup()
orgID, userID, transactionID := setupTestData(t, db)
// Set up the model instance for the API handlers
bc := &util.StandardBcrypt{}
// Use the existing datastore model which has the attachment implementation
// We need to create it with the database connection
datastoreModel := model.NewModel(nil, bc, types.Config{})
model.Instance = datastoreModel
t.Run("Database Integration Test", func(t *testing.T) {
// Test direct database operations first
filePath, originalContent := createTestFile(t)
defer os.Remove(filePath)
// Create attachment record directly
attachmentID := id.String(id.New())
uploadTime := time.Now()
attachment := types.Attachment{
Id: attachmentID,
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
FileName: "stored_test.txt",
OriginalName: "test.txt",
ContentType: "text/plain",
FileSize: int64(len(originalContent)),
FilePath: "uploads/test/" + attachmentID + ".txt",
Description: "Test attachment description",
Uploaded: uploadTime,
Deleted: false,
}
// Insert using the existing model
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.NoError(t, err)
assert.NotNil(t, createdAttachment)
assert.Equal(t, attachmentID, createdAttachment.Id)
// Verify database persistence
var dbAttachment types.Attachment
err = db.Raw("SELECT HEX(id) as id, HEX(transactionId) as transactionId, HEX(orgId) as orgId, HEX(userId) as userId, fileName, originalName, contentType, fileSize, filePath, description, uploaded, deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&dbAttachment).Error
assert.NoError(t, err)
assert.Equal(t, attachmentID, dbAttachment.Id)
assert.Equal(t, transactionID, dbAttachment.TransactionId)
assert.Equal(t, "Test attachment description", dbAttachment.Description)
// Test retrieval
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
assert.NotNil(t, retrievedAttachment)
assert.Equal(t, attachmentID, retrievedAttachment.Id)
// Test listing by transaction
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionID, orgID, userID)
assert.NoError(t, err)
assert.Len(t, attachments, 1)
assert.Equal(t, attachmentID, attachments[0].Id)
// Test soft deletion
err = model.Instance.DeleteAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
// Verify soft deletion in database
var deletedAttachment types.Attachment
err = db.Raw("SELECT deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&deletedAttachment).Error
assert.NoError(t, err)
assert.True(t, deletedAttachment.Deleted)
// Verify attachment is no longer accessible
retrievedAttachment, err = model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.Error(t, err)
assert.Nil(t, retrievedAttachment)
})
t.Run("File Upload Integration Test", func(t *testing.T) {
// Test file upload functionality
filePath, originalContent := createTestFile(t)
defer os.Remove(filePath)
// Create upload directory
uploadDir := "uploads/test"
os.MkdirAll(uploadDir, 0755)
defer os.RemoveAll("uploads")
// Simulate file upload process
attachmentID := id.String(id.New())
storedFilePath := filepath.Join(uploadDir, attachmentID+".txt")
// Copy file to upload location
err := copyFile(filePath, storedFilePath)
assert.NoError(t, err)
// Create attachment record
attachment := types.Attachment{
Id: attachmentID,
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
FileName: filepath.Base(storedFilePath),
OriginalName: "test.txt",
ContentType: "text/plain",
FileSize: int64(len(originalContent)),
FilePath: storedFilePath,
Description: "Uploaded test file",
Uploaded: time.Now(),
Deleted: false,
}
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.NoError(t, err)
assert.NotNil(t, createdAttachment)
// Verify file exists
_, err = os.Stat(storedFilePath)
assert.NoError(t, err)
// Verify database record
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
assert.Equal(t, storedFilePath, retrievedAttachment.FilePath)
assert.Equal(t, int64(len(originalContent)), retrievedAttachment.FileSize)
})
}
// Helper function to copy files
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
func TestAttachmentValidation(t *testing.T) {
db, cleanup := setupTestDatabase(t)
defer cleanup()
orgID, userID, transactionID := setupTestData(t, db)
// Set up the model instance
bc := &util.StandardBcrypt{}
gormModel := model.NewGormModel(db, bc, types.Config{})
model.Instance = gormModel
t.Run("Invalid attachment data", func(t *testing.T) {
// Test with missing required fields
attachment := types.Attachment{
// Missing ID
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
}
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.Error(t, err)
assert.Nil(t, createdAttachment)
})
t.Run("Non-existent attachment retrieval", func(t *testing.T) {
nonExistentID := id.String(id.New())
attachment, err := model.Instance.GetAttachment(nonExistentID, transactionID, orgID, userID)
assert.Error(t, err)
assert.Nil(t, attachment)
})
}

View File

@@ -0,0 +1,289 @@
package api
import (
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
"github.com/ant0ine/go-json-rest/rest"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/storage"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/core/util/id"
)
// AttachmentHandler handles attachment operations with configurable storage
type AttachmentHandler struct {
storage storage.Storage
}
// Global attachment handler instance (will be initialized during server startup)
var attachmentHandler *AttachmentHandler
// InitializeAttachmentHandler initializes the global attachment handler with storage backend
func InitializeAttachmentHandler(storageConfig storage.Config) error {
storageBackend, err := storage.NewStorage(storageConfig)
if err != nil {
return fmt.Errorf("failed to initialize storage backend: %w", err)
}
attachmentHandler = &AttachmentHandler{
storage: storageBackend,
}
return nil
}
// PostAttachmentWithStorage handles file upload using the configured storage backend
func PostAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
transactionId := r.FormValue("transactionId")
if transactionId == "" {
rest.Error(w, "Transaction ID is required", http.StatusBadRequest)
return
}
if !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid transaction ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Parse multipart form
err := r.ParseMultipartForm(MaxFileSize)
if err != nil {
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
return
}
files := r.MultipartForm.File["file"]
if len(files) == 0 {
rest.Error(w, "No file provided", http.StatusBadRequest)
return
}
fileHeader := files[0] // Take the first file
// Verify transaction exists and user has permission
tx, err := model.Instance.GetTransaction(transactionId, "", user.Id)
if err != nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
if tx == nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
// Process the file upload
attachment, err := attachmentHandler.processFileUploadWithStorage(fileHeader, transactionId, tx.OrgId, user.Id, r.FormValue("description"))
if err != nil {
rest.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Save attachment to database
createdAttachment, err := model.Instance.CreateAttachment(attachment)
if err != nil {
// Clean up the stored file on database error
attachmentHandler.storage.Delete(attachment.FilePath)
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
w.WriteJson(createdAttachment)
}
// GetAttachmentWithStorage retrieves an attachment using the configured storage backend
func GetAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Check if this is a download request
if r.URL.Query().Get("download") == "true" {
// Stream the file directly to the client
err := attachmentHandler.streamFile(w, attachment)
if err != nil {
rest.Error(w, "Failed to retrieve file", http.StatusInternalServerError)
return
}
return
}
// Return attachment metadata
w.WriteJson(attachment)
}
// GetAttachmentDownloadURL returns a download URL for an attachment
func GetAttachmentDownloadURL(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Generate download URL (valid for 1 hour)
url, err := attachmentHandler.storage.GetURL(attachment.FilePath, time.Hour)
if err != nil {
rest.Error(w, "Failed to generate download URL", http.StatusInternalServerError)
return
}
response := map[string]string{
"url": url,
"expiresIn": "3600", // 1 hour in seconds
}
w.WriteJson(response)
}
// DeleteAttachmentWithStorage deletes an attachment using the configured storage backend
func DeleteAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database first
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Delete from database (soft delete)
err = model.Instance.DeleteAttachment(attachmentId, attachment.TransactionId, attachment.OrgId, user.Id)
if err != nil {
rest.Error(w, "Failed to delete attachment", http.StatusInternalServerError)
return
}
// Delete from storage backend
// Note: For production, you might want to delay physical deletion
// and run a cleanup job later to handle any issues
err = attachmentHandler.storage.Delete(attachment.FilePath)
if err != nil {
// Log the error but don't fail the request since database deletion succeeded
// The file can be cleaned up later by a maintenance job
fmt.Printf("Warning: Failed to delete file from storage: %v\n", err)
}
w.WriteHeader(http.StatusOK)
w.WriteJson(map[string]string{"status": "deleted"})
}
// processFileUploadWithStorage processes a file upload using the storage backend
func (h *AttachmentHandler) processFileUploadWithStorage(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
// Validate file size
if fileHeader.Size > MaxFileSize {
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
}
// Validate content type
contentType := fileHeader.Header.Get("Content-Type")
if !AllowedMimeTypes[contentType] {
return nil, fmt.Errorf("unsupported file type: %s", contentType)
}
// Open the file
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
}
defer file.Close()
// Store the file using the storage backend
storagePath, err := h.storage.Store(fileHeader.Filename, file, contentType)
if err != nil {
return nil, fmt.Errorf("failed to store file: %w", err)
}
// Create attachment record
attachment := &types.Attachment{
Id: id.String(id.New()),
TransactionId: transactionId,
OrgId: orgId,
UserId: userId,
FileName: storagePath, // Store the storage path/key
OriginalName: fileHeader.Filename,
ContentType: contentType,
FileSize: fileHeader.Size,
FilePath: storagePath, // For backward compatibility
Description: description,
Uploaded: time.Now(),
Deleted: false,
}
return attachment, nil
}
// streamFile streams a file from storage to the HTTP response
func (h *AttachmentHandler) streamFile(w rest.ResponseWriter, attachment *types.Attachment) error {
// Get file from storage
reader, err := h.storage.Retrieve(attachment.FilePath)
if err != nil {
return fmt.Errorf("failed to retrieve file: %w", err)
}
defer reader.Close()
// Set appropriate headers
w.Header().Set("Content-Type", attachment.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
// If we know the file size, set Content-Length
if attachment.FileSize > 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", attachment.FileSize))
}
// Stream the file to the client
_, err = io.Copy(w.(http.ResponseWriter), reader)
return err
}

View File

@@ -31,6 +31,17 @@ func GetRouter(auth *AuthMiddleware, prefix string) (rest.App, error) {
rest.Post(prefix+"/orgs/:orgId/transactions", auth.RequireAuth(PostTransaction)),
rest.Put(prefix+"/orgs/:orgId/transactions/:transactionId", auth.RequireAuth(PutTransaction)),
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId", auth.RequireAuth(DeleteTransaction)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments", auth.RequireAuth(GetAttachments)),
rest.Post(prefix+"/orgs/:orgId/transactions/:transactionId/attachments", auth.RequireAuth(PostAttachment)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(GetAttachment)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId/download", auth.RequireAuth(DownloadAttachment)),
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(DeleteAttachment)),
// New storage-based attachment endpoints
rest.Post(prefix+"/attachments", auth.RequireAuth(PostAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id", auth.RequireAuth(GetAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id/url", auth.RequireAuth(GetAttachmentDownloadURL)),
rest.Delete(prefix+"/attachments/:id", auth.RequireAuth(DeleteAttachmentWithStorage)),
rest.Get(prefix+"/orgs/:orgId/prices", auth.RequireAuth(GetPrices)),
rest.Post(prefix+"/orgs/:orgId/prices", auth.RequireAuth(PostPrice)),
rest.Delete(prefix+"/orgs/:orgId/prices/:priceId", auth.RequireAuth(DeletePrice)),

View File

@@ -14,6 +14,22 @@ type AuthService struct {
bcrypt util.Bcrypt
}
// AuthRepository interface for dependency injection
type AuthRepository interface {
GetVerifiedUserByEmail(string) (*types.User, error)
GetUserByActiveSession(string) (*types.User, error)
GetUserByApiKey(string) (*types.User, error)
GetUserByEmailVerifyCode(string) (*types.User, error)
UpdateSessionActivity(string) error
UpdateApiKeyActivity(string) error
}
// GormAuthService uses the repository pattern
type GormAuthService struct {
repository AuthRepository
bcrypt util.Bcrypt
}
type Interface interface {
Authenticate(string, string) (*types.User, error)
AuthenticateUser(email string, password string) (*types.User, error)
@@ -28,6 +44,12 @@ func NewAuthService(db db.Datastore, bcrypt util.Bcrypt) *AuthService {
return authService
}
func NewGormAuthService(repository AuthRepository, bcrypt util.Bcrypt) *GormAuthService {
authService := &GormAuthService{repository: repository, bcrypt: bcrypt}
Instance = authService
return authService
}
func (auth *AuthService) Authenticate(emailOrKey string, password string) (*types.User, error) {
// authenticate via session, apikey or user
user, err := auth.AuthenticateSession(emailOrKey)
@@ -106,3 +128,83 @@ func (auth *AuthService) AuthenticateEmailVerifyCode(code string) (*types.User,
return u, nil
}
// GormAuthService implementations
func (auth *GormAuthService) Authenticate(emailOrKey string, password string) (*types.User, error) {
// authenticate via session, apikey or user
user, err := auth.AuthenticateSession(emailOrKey)
if err == nil {
return user, nil
}
user, err = auth.AuthenticateApiKey(emailOrKey)
if err == nil {
return user, nil
}
user, err = auth.AuthenticateUser(emailOrKey, password)
if err == nil {
return user, nil
}
user, err = auth.AuthenticateEmailVerifyCode(emailOrKey)
if err == nil {
return user, nil
}
return nil, errors.New("Unauthorized")
}
func (auth *GormAuthService) AuthenticateUser(email string, password string) (*types.User, error) {
u, err := auth.repository.GetVerifiedUserByEmail(email)
if err != nil {
return nil, errors.New("Invalid email or password")
}
err = auth.bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
if err != nil {
return nil, errors.New("Invalid email or password")
}
return u, nil
}
func (auth *GormAuthService) AuthenticateSession(id string) (*types.User, error) {
u, err := auth.repository.GetUserByActiveSession(id)
if err != nil {
return nil, errors.New("Invalid session")
}
auth.repository.UpdateSessionActivity(id)
return u, nil
}
func (auth *GormAuthService) AuthenticateApiKey(id string) (*types.User, error) {
u, err := auth.repository.GetUserByApiKey(id)
if err != nil {
return nil, errors.New("Access denied")
}
auth.repository.UpdateApiKeyActivity(id)
return u, nil
}
func (auth *GormAuthService) AuthenticateEmailVerifyCode(code string) (*types.User, error) {
u, err := auth.repository.GetUserByEmailVerifyCode(code)
if err != nil {
return nil, errors.New("Access denied")
}
return u, nil
}

View File

@@ -2,12 +2,13 @@ package auth
import (
"errors"
"testing"
"time"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type TdUser struct {
@@ -28,18 +29,19 @@ func (td *TdUser) GetVerifiedUserByEmail(email string) (*types.User, error) {
func (td *TdUser) GetVerifiedUserByEmail_1(email string) (*types.User, error) {
return &types.User{
"1",
time.Unix(0, 0),
time.Unix(0, 0),
"John",
"Doe",
"johndoe@email.com",
"password",
"$2a$10$KrtvADe7jwrmYIe3GXFbNupOQaPIvyOKeng5826g4VGOD47TpAisG",
true,
"",
false,
"",
Id: "1",
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
FirstName: "John",
LastName: "Doe",
Email: "johndoe@email.com",
Password: "password",
PasswordHash: "$2a$10$KrtvADe7jwrmYIe3GXFbNupOQaPIvyOKeng5826g4VGOD47TpAisG",
AgreeToTerms: true,
PasswordReset: "",
EmailVerified: false,
EmailVerifyCode: "",
SignupSource: "",
}, nil
}

View File

@@ -10,6 +10,31 @@ type Datastore struct {
mock.Mock
}
// DeleteBudget implements db.Datastore.
func (_m *Datastore) DeleteBudget(string) error {
panic("unimplemented")
}
// GetBudget implements db.Datastore.
func (_m *Datastore) GetBudget(string) (*types.Budget, error) {
panic("unimplemented")
}
// GetUserByEmailVerifyCode implements db.Datastore.
func (_m *Datastore) GetUserByEmailVerifyCode(string) (*types.User, error) {
panic("unimplemented")
}
// InsertAndReplaceBudget implements db.Datastore.
func (_m *Datastore) InsertAndReplaceBudget(*types.Budget) error {
panic("unimplemented")
}
// Ping implements db.Datastore.
func (_m *Datastore) Ping() error {
panic("unimplemented")
}
// AcceptInvite provides a mock function with given fields: _a0, _a1
func (_m *Datastore) AcceptInvite(_a0 *types.Invite, _a1 string) error {
ret := _m.Called(_a0, _a1)
@@ -968,3 +993,74 @@ func (_m *Datastore) VerifyUser(_a0 string) error {
return r0
}
// Attachment interface mock methods
func (_m *Datastore) InsertAttachment(_a0 *types.Attachment) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(*types.Attachment) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
func (_m *Datastore) GetAttachment(_a0 string, _a1 string, _a2 string) (*types.Attachment, error) {
ret := _m.Called(_a0, _a1, _a2)
var r0 *types.Attachment
if rf, ok := ret.Get(0).(func(string, string, string) *types.Attachment); ok {
r0 = rf(_a0, _a1, _a2)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.Attachment)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string, string) error); ok {
r1 = rf(_a0, _a1, _a2)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (_m *Datastore) GetAttachmentsByTransaction(_a0 string, _a1 string) ([]*types.Attachment, error) {
ret := _m.Called(_a0, _a1)
var r0 []*types.Attachment
if rf, ok := ret.Get(0).(func(string, string) []*types.Attachment); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*types.Attachment)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (_m *Datastore) DeleteAttachment(_a0 string, _a1 string, _a2 string) error {
ret := _m.Called(_a0, _a1, _a2)
var r0 error
if rf, ok := ret.Get(0).(func(string, string, string) error); ok {
r0 = rf(_a0, _a1, _a2)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@@ -408,6 +408,15 @@ func (model *Model) accountsContainWriteAccess(accounts []*types.Account, accoun
return false
}
func (model *Model) accountsContainReadAccess(accounts []*types.Account, accountId string) bool {
for _, account := range accounts {
if account.Id == accountId {
return true
}
}
return false
}
func (model *Model) getAccountFromList(accounts []*types.Account, accountId string) *types.Account {
for _, account := range accounts {
if account.Id == accountId {

View File

@@ -162,6 +162,10 @@ func TestCreateAccount(t *testing.T) {
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
// Mock GetSplitCountByAccountId for parent account check
if test.account.Parent != "" {
td.On("GetSplitCountByAccountId", test.account.Parent).Return(int64(0), nil)
}
model := NewModel(td, nil, types.Config{})
@@ -206,6 +210,10 @@ func TestUpdateAccount(t *testing.T) {
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
// Mock GetSplitCountByAccountId for parent account check
if test.account.Parent != "" {
td.On("GetSplitCountByAccountId", test.account.Parent).Return(int64(0), nil)
}
model := NewModel(td, nil, types.Config{})

163
core/model/attachment.go Normal file
View File

@@ -0,0 +1,163 @@
package model
import (
"errors"
"time"
"github.com/openaccounting/oa-server/core/model/types"
)
type AttachmentInterface interface {
CreateAttachment(*types.Attachment) (*types.Attachment, error)
GetAttachmentsByTransaction(string, string, string) ([]*types.Attachment, error)
GetAttachment(string, string, string, string) (*types.Attachment, error)
DeleteAttachment(string, string, string, string) error
}
func (model *Model) CreateAttachment(attachment *types.Attachment) (*types.Attachment, error) {
if attachment.Id == "" {
return nil, errors.New("attachment ID required")
}
if attachment.TransactionId == "" {
return nil, errors.New("transaction ID required")
}
if attachment.OrgId == "" {
return nil, errors.New("organization ID required")
}
if attachment.UserId == "" {
return nil, errors.New("user ID required")
}
if attachment.FileName == "" {
return nil, errors.New("file name required")
}
if attachment.FilePath == "" {
return nil, errors.New("file path required")
}
// Set upload timestamp
attachment.Uploaded = time.Now()
attachment.Deleted = false
// Save to database
err := model.db.InsertAttachment(attachment)
if err != nil {
return nil, err
}
return attachment, nil
}
func (model *Model) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
if transactionId == "" {
return nil, errors.New("transaction ID required")
}
if orgId == "" {
return nil, errors.New("organization ID required")
}
if userId == "" {
return nil, errors.New("user ID required")
}
// First verify the user has access to the transaction
tx, err := model.GetTransaction(transactionId, orgId, userId)
if err != nil {
return nil, err
}
if tx == nil {
return nil, errors.New("transaction not found or access denied")
}
// Get attachments for the transaction
attachments, err := model.db.GetAttachmentsByTransaction(transactionId, orgId)
if err != nil {
return nil, err
}
return attachments, nil
}
func (model *Model) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
if attachmentId == "" {
return nil, errors.New("attachment ID required")
}
if transactionId == "" {
return nil, errors.New("transaction ID required")
}
if orgId == "" {
return nil, errors.New("organization ID required")
}
if userId == "" {
return nil, errors.New("user ID required")
}
// First verify the user has access to the transaction
tx, err := model.GetTransaction(transactionId, orgId, userId)
if err != nil {
return nil, err
}
if tx == nil {
return nil, errors.New("transaction not found or access denied")
}
// Get the attachment
attachment, err := model.db.GetAttachment(attachmentId, transactionId, orgId)
if err != nil {
return nil, err
}
return attachment, nil
}
func (model *Model) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
if attachmentId == "" {
return errors.New("attachment ID required")
}
if transactionId == "" {
return errors.New("transaction ID required")
}
if orgId == "" {
return errors.New("organization ID required")
}
if userId == "" {
return errors.New("user ID required")
}
// First verify the user has access to the transaction
tx, err := model.GetTransaction(transactionId, orgId, userId)
if err != nil {
return err
}
if tx == nil {
return errors.New("transaction not found or access denied")
}
// Verify the attachment exists and belongs to the transaction
attachment, err := model.db.GetAttachment(attachmentId, transactionId, orgId)
if err != nil {
return err
}
if attachment == nil {
return errors.New("attachment not found")
}
// Soft delete the attachment
err = model.db.DeleteAttachment(attachmentId, transactionId, orgId)
if err != nil {
return err
}
return nil
}

View File

@@ -2,6 +2,7 @@ package model
import (
"errors"
"github.com/openaccounting/oa-server/core/model/types"
)
@@ -18,8 +19,8 @@ func (model *Model) GetBudget(orgId string, userId string) (*types.Budget, error
return nil, err
}
if belongs == false {
return nil, errors.New("User does not belong to org")
if !belongs {
return nil, errors.New("user does not belong to org")
}
return model.db.GetBudget(orgId)
@@ -32,8 +33,8 @@ func (model *Model) CreateBudget(budget *types.Budget, userId string) error {
return err
}
if belongs == false {
return errors.New("User does not belong to org")
if !belongs {
return errors.New("user does not belong to org")
}
if budget.OrgId == "" {
@@ -50,8 +51,8 @@ func (model *Model) DeleteBudget(orgId string, userId string) error {
return err
}
if belongs == false {
return errors.New("User does not belong to org")
if !belongs {
return errors.New("user does not belong to org")
}
return model.db.DeleteBudget(orgId)

126
core/model/db/attachment.go Normal file
View File

@@ -0,0 +1,126 @@
package db
import (
"database/sql"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
)
const attachmentFields = "LOWER(HEX(id)),LOWER(HEX(transactionId)),LOWER(HEX(orgId)),LOWER(HEX(userId)),fileName,originalName,contentType,fileSize,filePath,description,uploaded,deleted"
type AttachmentInterface interface {
InsertAttachment(*types.Attachment) error
GetAttachment(string, string, string) (*types.Attachment, error)
GetAttachmentsByTransaction(string, string) ([]*types.Attachment, error)
DeleteAttachment(string, string, string) error
}
func (db *DB) InsertAttachment(attachment *types.Attachment) error {
query := "INSERT INTO attachment(id,transactionId,orgId,userId,fileName,originalName,contentType,fileSize,filePath,description,uploaded,deleted) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),UNHEX(?),?,?,?,?,?,?,?,?)"
_, err := db.Exec(
query,
attachment.Id,
attachment.TransactionId,
attachment.OrgId,
attachment.UserId,
attachment.FileName,
attachment.OriginalName,
attachment.ContentType,
attachment.FileSize,
attachment.FilePath,
attachment.Description,
util.TimeToMs(attachment.Uploaded),
attachment.Deleted,
)
return err
}
func (db *DB) GetAttachment(attachmentId, transactionId, orgId string) (*types.Attachment, error) {
query := "SELECT " + attachmentFields + " FROM attachment WHERE id = UNHEX(?) AND transactionId = UNHEX(?) AND orgId = UNHEX(?) AND deleted = false"
row := db.QueryRow(query, attachmentId, transactionId, orgId)
return db.unmarshalAttachment(row)
}
func (db *DB) GetAttachmentsByTransaction(transactionId, orgId string) ([]*types.Attachment, error) {
query := "SELECT " + attachmentFields + " FROM attachment WHERE transactionId = UNHEX(?) AND orgId = UNHEX(?) AND deleted = false ORDER BY uploaded DESC"
rows, err := db.Query(query, transactionId, orgId)
if err != nil {
return nil, err
}
return db.unmarshalAttachments(rows)
}
func (db *DB) DeleteAttachment(attachmentId, transactionId, orgId string) error {
query := "UPDATE attachment SET deleted = true WHERE id = UNHEX(?) AND transactionId = UNHEX(?) AND orgId = UNHEX(?)"
_, err := db.Exec(query, attachmentId, transactionId, orgId)
return err
}
func (db *DB) unmarshalAttachment(row *sql.Row) (*types.Attachment, error) {
attachment := &types.Attachment{}
var uploaded int64
err := row.Scan(
&attachment.Id,
&attachment.TransactionId,
&attachment.OrgId,
&attachment.UserId,
&attachment.FileName,
&attachment.OriginalName,
&attachment.ContentType,
&attachment.FileSize,
&attachment.FilePath,
&attachment.Description,
&uploaded,
&attachment.Deleted,
)
if err != nil {
return nil, err
}
attachment.Uploaded = util.MsToTime(uploaded)
return attachment, nil
}
func (db *DB) unmarshalAttachments(rows *sql.Rows) ([]*types.Attachment, error) {
defer rows.Close()
attachments := []*types.Attachment{}
for rows.Next() {
attachment := &types.Attachment{}
var uploaded int64
err := rows.Scan(
&attachment.Id,
&attachment.TransactionId,
&attachment.OrgId,
&attachment.UserId,
&attachment.FileName,
&attachment.OriginalName,
&attachment.ContentType,
&attachment.FileSize,
&attachment.FilePath,
&attachment.Description,
&uploaded,
&attachment.Deleted,
)
if err != nil {
return nil, err
}
attachment.Uploaded = util.MsToTime(uploaded)
attachments = append(attachments, attachment)
}
return attachments, nil
}

View File

@@ -15,6 +15,7 @@ type Datastore interface {
OrgInterface
AccountInterface
TransactionInterface
AttachmentInterface
PriceInterface
SessionInterface
ApiKeyInterface

353
core/model/gorm_model.go Normal file
View File

@@ -0,0 +1,353 @@
package model
import (
"errors"
"time"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/repository"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/database"
"gorm.io/gorm"
)
// GormModel is the GORM-based implementation of the Model
type GormModel struct {
repository *repository.GormRepository
bcrypt util.Bcrypt
config types.Config
}
// NewGormModel creates a new GORM-based model
func NewGormModel(gormDB *gorm.DB, bcrypt util.Bcrypt, config types.Config) *GormModel {
repo := repository.NewGormRepository(gormDB)
return &GormModel{
repository: repo,
bcrypt: bcrypt,
config: config,
}
}
// CreateModel creates a new model using the existing database connection
func CreateGormModel(bcrypt util.Bcrypt, config types.Config) (*GormModel, error) {
// Use the existing database connection
if database.DB == nil {
return nil, errors.New("database connection not initialized")
}
return NewGormModel(database.DB, bcrypt, config), nil
}
// Implement the Interface by delegating to the business logic layer
// The business logic layer (existing model methods) will call the repository
// UserInterface methods - delegate to existing business logic
func (m *GormModel) CreateUser(user *types.User) error {
// The existing business logic in user.go will be updated to use the repository
// For now, delegate directly to repository for basic operations
return m.repository.InsertUser(user)
}
func (m *GormModel) VerifyUser(code string) error {
return m.repository.VerifyUser(code)
}
func (m *GormModel) UpdateUser(user *types.User) error {
return m.repository.UpdateUser(user)
}
func (m *GormModel) ResetPassword(email string) error {
// This would need the full business logic from the original model
// For now, simplified implementation
user, err := m.repository.GetVerifiedUserByEmail(email)
if err != nil {
return err
}
user.PasswordReset, err = util.NewGuid()
if err != nil {
return err
}
return m.repository.UpdateUserResetPassword(user)
}
func (m *GormModel) ConfirmResetPassword(password string, code string) (*types.User, error) {
user, err := m.repository.GetUserByResetCode(code)
if err != nil {
return nil, err
}
passwordHash, err := m.bcrypt.GenerateFromPassword([]byte(password), m.bcrypt.GetDefaultCost())
if err != nil {
return nil, err
}
user.PasswordHash = string(passwordHash)
user.Password = ""
err = m.repository.UpdateUser(user)
if err != nil {
return nil, err
}
return user, nil
}
// AccountInterface methods - delegate to repository
func (m *GormModel) CreateAccount(account *types.Account, userId string) error {
return m.repository.InsertAccount(account)
}
func (m *GormModel) UpdateAccount(account *types.Account, userId string) error {
return m.repository.UpdateAccount(account)
}
func (m *GormModel) DeleteAccount(id string, userId string, orgId string) error {
return m.repository.DeleteAccount(id)
}
func (m *GormModel) GetAccounts(orgId string, userId string, tokenId string) ([]*types.Account, error) {
return m.repository.GetAccountsByOrgId(orgId)
}
func (m *GormModel) GetAccountsWithBalances(orgId string, userId string, tokenId string, date time.Time) ([]*types.Account, error) {
accounts, err := m.repository.GetAccountsByOrgId(orgId)
if err != nil {
return nil, err
}
// Add balance calculations
err = m.repository.AddBalances(accounts, date)
if err != nil {
return nil, err
}
return accounts, nil
}
func (m *GormModel) GetAccount(orgId, accId, userId, tokenId string) (*types.Account, error) {
return m.repository.GetAccount(accId)
}
func (m *GormModel) GetAccountWithBalance(orgId, accId, userId, tokenId string, date time.Time) (*types.Account, error) {
account, err := m.repository.GetAccount(accId)
if err != nil {
return nil, err
}
// Add balance calculation
err = m.repository.AddBalance(account, date)
if err != nil {
return nil, err
}
return account, nil
}
// Complete OrgInterface implementation
func (m *GormModel) CreateOrg(org *types.Org, userId string) error {
// Get default accounts - this needs to be implemented properly
accounts := []*types.Account{} // Empty for now, should create default chart of accounts
return m.repository.CreateOrg(org, userId, accounts)
}
func (m *GormModel) GetOrg(orgId, userId string) (*types.Org, error) {
return m.repository.GetOrg(orgId, userId)
}
func (m *GormModel) GetOrgs(userId string) ([]*types.Org, error) {
return m.repository.GetOrgs(userId)
}
func (m *GormModel) UpdateOrg(org *types.Org, userId string) error {
return m.repository.UpdateOrg(org)
}
func (m *GormModel) CreateInvite(invite *types.Invite, userId string) error {
return m.repository.InsertInvite(invite)
}
func (m *GormModel) AcceptInvite(invite *types.Invite, userId string) error {
return m.repository.AcceptInvite(invite, userId)
}
func (m *GormModel) GetInvites(orgId, userId string) ([]*types.Invite, error) {
return m.repository.GetInvites(orgId)
}
func (m *GormModel) DeleteInvite(inviteId, userId string) error {
return m.repository.DeleteInvite(inviteId)
}
// SessionInterface implementation
func (m *GormModel) CreateSession(session *types.Session) error {
return m.repository.InsertSession(session)
}
func (m *GormModel) InsertSession(session *types.Session) error {
return m.repository.InsertSession(session)
}
func (m *GormModel) DeleteSession(sessionId, userId string) error {
return m.repository.DeleteSession(sessionId, userId)
}
func (m *GormModel) UpdateSessionActivity(sessionId string) error {
return m.repository.UpdateSessionActivity(sessionId)
}
// ApiKeyInterface implementation
func (m *GormModel) CreateApiKey(apiKey *types.ApiKey) error {
return m.repository.InsertApiKey(apiKey)
}
func (m *GormModel) InsertApiKey(apiKey *types.ApiKey) error {
return m.repository.InsertApiKey(apiKey)
}
func (m *GormModel) UpdateApiKey(apiKey *types.ApiKey) error {
return m.repository.UpdateApiKey(apiKey)
}
func (m *GormModel) DeleteApiKey(keyId, userId string) error {
return m.repository.DeleteApiKey(keyId, userId)
}
func (m *GormModel) GetApiKeys(userId string) ([]*types.ApiKey, error) {
return m.repository.GetApiKeys(userId)
}
func (m *GormModel) UpdateApiKeyActivity(keyId string) error {
return m.repository.UpdateApiKeyActivity(keyId)
}
// TransactionInterface implementation
func (m *GormModel) CreateTransaction(transaction *types.Transaction) error {
return m.repository.InsertTransaction(transaction)
}
func (m *GormModel) UpdateTransaction(transactionId string, transaction *types.Transaction) error {
return m.repository.DeleteAndInsertTransaction(transactionId, transaction)
}
func (m *GormModel) GetTransactionsByAccount(accountId, orgId, userId string, options *types.QueryOptions) ([]*types.Transaction, error) {
return m.repository.GetTransactionsByAccount(accountId, options)
}
func (m *GormModel) GetTransactionsByOrg(orgId, userId string, options *types.QueryOptions) ([]*types.Transaction, error) {
return m.repository.GetTransactionsByOrg(orgId, options, []string{})
}
func (m *GormModel) DeleteTransaction(transactionId, orgId, userId string) error {
return m.repository.DeleteTransaction(transactionId)
}
func (m *GormModel) InsertTransaction(transaction *types.Transaction) error {
return m.repository.InsertTransaction(transaction)
}
func (m *GormModel) GetTransaction(transactionId, orgId, userId string) (*types.Transaction, error) {
// For now, delegate to repository - in a full implementation, this would include permission checking
return m.repository.GetTransactionById(transactionId)
}
// AttachmentInterface implementation
func (m *GormModel) CreateAttachment(attachment *types.Attachment) (*types.Attachment, error) {
if attachment.Id == "" {
return nil, errors.New("attachment ID required")
}
// Set upload timestamp
attachment.Uploaded = time.Now()
attachment.Deleted = false
// For GORM implementation, we'd need to implement repository methods
// For now, return an error indicating not implemented
return nil, errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
return nil, errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
return nil, errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
return errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) GetTransactionById(id string) (*types.Transaction, error) {
return m.repository.GetTransactionById(id)
}
func (m *GormModel) DeleteAndInsertTransaction(id string, transaction *types.Transaction) error {
return m.repository.DeleteAndInsertTransaction(id, transaction)
}
// PriceInterface implementation
func (m *GormModel) CreatePrice(price *types.Price, userId string) error {
return m.repository.InsertPrice(price)
}
func (m *GormModel) DeletePrice(priceId, userId string) error {
// Stub implementation - would need proper implementation
return nil
}
func (m *GormModel) GetPricesNearestInTime(orgId string, date time.Time, currency string) ([]*types.Price, error) {
// Stub implementation - would need proper implementation based on specific logic
return m.repository.GetPrices(orgId, date)
}
func (m *GormModel) GetPricesByCurrency(orgId, currency, userId string) ([]*types.Price, error) {
// Stub implementation - would need proper implementation based on specific logic
return m.repository.GetPrices(orgId, time.Now())
}
func (m *GormModel) GetPrices(orgId string, date time.Time) ([]*types.Price, error) {
return m.repository.GetPrices(orgId, date)
}
func (m *GormModel) InsertPrice(price *types.Price) error {
return m.repository.InsertPrice(price)
}
// SystemHealthInteface implementation
func (m *GormModel) PingDatabase() error {
return m.repository.Ping()
}
func (m *GormModel) Ping() error {
return m.repository.Ping()
}
// BudgetInterface implementation
func (m *GormModel) GetBudget(orgId, userId string) (*types.Budget, error) {
// Stub implementation - would need proper implementation
return &types.Budget{}, nil
}
func (m *GormModel) CreateBudget(budget *types.Budget, userId string) error {
return m.repository.InsertBudget(budget)
}
func (m *GormModel) DeleteBudget(budgetId, userId string) error {
// Stub implementation - would need proper implementation
return nil
}
func (m *GormModel) InsertBudget(budget *types.Budget) error {
return m.repository.InsertBudget(budget)
}
func (m *GormModel) GetBudgets(orgId string) ([]*types.Budget, error) {
return m.repository.GetBudgets(orgId)
}
// Helper methods
func (m *GormModel) GetOrgUserIds(orgId string) ([]string, error) {
return m.repository.GetOrgUserIds(orgId)
}

View File

@@ -14,16 +14,19 @@ type Model struct {
config types.Config
}
type Interface interface {
UserInterface
OrgInterface
AccountInterface
TransactionInterface
AttachmentInterface
PriceInterface
SessionInterface
ApiKeyInterface
SystemHealthInteface
BudgetInterface
GetTransaction(string, string, string) (*types.Transaction, error)
}
func NewModel(db db.Datastore, bcrypt util.Bcrypt, config types.Config) *Model {
@@ -31,3 +34,4 @@ func NewModel(db db.Datastore, bcrypt util.Bcrypt, config types.Config) *Model {
Instance = model
return model
}

View File

@@ -2,44 +2,45 @@ package model
import (
"errors"
"testing"
"time"
"github.com/openaccounting/oa-server/core/mocks"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestCreatePrice(t *testing.T) {
price := types.Price{
"1",
"2",
"BTC",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
Id: "1",
OrgId: "2",
Currency: "BTC",
Date: time.Unix(0, 0),
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
Price: 6700,
}
badPrice := types.Price{
"1",
"2",
"",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
Id: "1",
OrgId: "2",
Currency: "",
Date: time.Unix(0, 0),
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
Price: 6700,
}
badOrg := types.Price{
"1",
"1",
"BTC",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
Id: "1",
OrgId: "1",
Currency: "BTC",
Date: time.Unix(0, 0),
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
Price: 6700,
}
tests := map[string]struct {
@@ -89,13 +90,13 @@ func TestCreatePrice(t *testing.T) {
func TestDeletePrice(t *testing.T) {
price := types.Price{
"1",
"2",
"BTC",
time.Unix(0, 0),
time.Unix(0, 0),
time.Unix(0, 0),
6700,
Id: "1",
OrgId: "2",
Currency: "BTC",
Date: time.Unix(0, 0),
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
Price: 6700,
}
tests := map[string]struct {

View File

@@ -3,9 +3,10 @@ package model
import (
"errors"
"fmt"
"time"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/ws"
"time"
)
type TransactionInterface interface {
@@ -105,7 +106,7 @@ func (model *Model) GetTransactionsByAccount(orgId string, userId string, accoun
}
if !model.accountsContainWriteAccess(userAccounts, accountId) {
return nil, errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", accountId))
return nil, fmt.Errorf("%s %s", "user does not have permission to access account", accountId)
}
return model.db.GetTransactionsByAccount(accountId, options)
@@ -142,7 +143,7 @@ func (model *Model) DeleteTransaction(id string, userId string, orgId string) (e
for _, split := range transaction.Splits {
if !model.accountsContainWriteAccess(userAccounts, split.AccountId) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", split.AccountId))
return fmt.Errorf("%s %s", "user does not have permission to access account", split.AccountId)
}
}
@@ -168,6 +169,31 @@ func (model *Model) getTransactionById(id string) (*types.Transaction, error) {
return model.db.GetTransactionById(id)
}
func (model *Model) GetTransaction(transactionId, orgId, userId string) (*types.Transaction, error) {
transaction, err := model.getTransactionById(transactionId)
if err != nil {
return nil, err
}
if transaction == nil || transaction.OrgId != orgId {
return nil, nil
}
// Check if user has access to all accounts in the transaction
userAccounts, err := model.GetAccounts(orgId, userId, "")
if err != nil {
return nil, err
}
for _, split := range transaction.Splits {
if !model.accountsContainReadAccess(userAccounts, split.AccountId) {
return nil, fmt.Errorf("user does not have permission to access account %s", split.AccountId)
}
}
return transaction, nil
}
func (model *Model) checkSplits(transaction *types.Transaction) (err error) {
if len(transaction.Splits) < 2 {
return errors.New("at least 2 splits are required")
@@ -189,13 +215,13 @@ func (model *Model) checkSplits(transaction *types.Transaction) (err error) {
for _, split := range transaction.Splits {
if !model.accountsContainWriteAccess(userAccounts, split.AccountId) {
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", split.AccountId))
return fmt.Errorf("%s %s", "user does not have permission to access account", split.AccountId)
}
account := model.getAccountFromList(userAccounts, split.AccountId)
if account.HasChildren == true {
return errors.New("Cannot use parent account for split")
if !account.HasChildren {
return errors.New("cannot use parent account for split")
}
if account.Currency == org.Currency && split.NativeAmount != split.Amount {

View File

@@ -2,12 +2,13 @@ package model
import (
"errors"
"testing"
"time"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
"time"
)
type TdTransaction struct {
@@ -57,72 +58,72 @@ func TestCreateTransaction(t *testing.T) {
"successful": {
err: nil,
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 1000},
&types.Split{"1", "2", -1000, -1000},
Id: "1",
OrgId: "2",
UserId: "3",
Date: time.Now(),
Inserted: time.Now(),
Updated: time.Now(),
Description: "description",
Data: "",
Deleted: false,
Splits: []*types.Split{
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 1000},
&types.Split{TransactionId: "1", AccountId: "2", Amount: -1000, NativeAmount: -1000},
},
},
},
"bad split amounts": {
err: errors.New("splits must add up to 0"),
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 1000},
&types.Split{"1", "2", -500, -500},
Id: "1",
OrgId: "2",
UserId: "3",
Date: time.Now(),
Inserted: time.Now(),
Updated: time.Now(),
Description: "description",
Data: "",
Deleted: false,
Splits: []*types.Split{
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 1000},
&types.Split{TransactionId: "1", AccountId: "2", Amount: -500, NativeAmount: -500},
},
},
},
"lacking permission": {
err: errors.New("user does not have permission to access account 3"),
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 1000},
&types.Split{"1", "3", -1000, -1000},
Id: "1",
OrgId: "2",
UserId: "3",
Date: time.Now(),
Inserted: time.Now(),
Updated: time.Now(),
Description: "description",
Data: "",
Deleted: false,
Splits: []*types.Split{
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 1000},
&types.Split{TransactionId: "1", AccountId: "3", Amount: -1000, NativeAmount: -1000},
},
},
},
"nativeAmount mismatch": {
err: errors.New("nativeAmount must equal amount for native currency splits"),
tx: &types.Transaction{
"1",
"2",
"3",
time.Now(),
time.Now(),
time.Now(),
"description",
"",
false,
[]*types.Split{
&types.Split{"1", "1", 1000, 500},
&types.Split{"1", "2", -1000, -500},
Id: "1",
OrgId: "2",
UserId: "3",
Date: time.Now(),
Inserted: time.Now(),
Updated: time.Now(),
Description: "description",
Data: "",
Deleted: false,
Splits: []*types.Split{
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 500},
&types.Split{TransactionId: "1", AccountId: "2", Amount: -1000, NativeAmount: -500},
},
},
},

View File

@@ -0,0 +1,20 @@
package types
import (
"time"
)
type Attachment struct {
Id string `json:"id"`
TransactionId string `json:"transactionId"`
OrgId string `json:"orgId"`
UserId string `json:"userId"`
FileName string `json:"fileName"`
OriginalName string `json:"originalName"`
ContentType string `json:"contentType"`
FileSize int64 `json:"fileSize"`
FilePath string `json:"filePath"`
Description string `json:"description"`
Uploaded time.Time `json:"uploaded"`
Deleted bool `json:"deleted"`
}

View File

@@ -1,18 +1,27 @@
package types
import "github.com/openaccounting/oa-server/core/storage"
type Config struct {
WebUrl string
Address string
Port int
ApiPrefix string
KeyFile string
CertFile string
DatabaseAddress string
Database string
User string
Password string
MailgunDomain string
MailgunKey string
MailgunEmail string
MailgunSender string
WebUrl string `mapstructure:"weburl"`
Address string `mapstructure:"address"`
Port int `mapstructure:"port"`
ApiPrefix string `mapstructure:"apiprefix"`
KeyFile string `mapstructure:"keyfile"`
CertFile string `mapstructure:"certfile"`
// Database configuration
DatabaseDriver string `mapstructure:"databasedriver"` // "mysql" or "sqlite"
DatabaseAddress string `mapstructure:"databaseaddress"`
Database string `mapstructure:"database"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"` // Sensitive: use OA_PASSWORD env var
// SQLite specific
DatabaseFile string `mapstructure:"databasefile"`
// Email configuration
MailgunDomain string `mapstructure:"mailgundomain"`
MailgunKey string `mapstructure:"mailgunkey"` // Sensitive: use OA_MAILGUN_KEY env var
MailgunEmail string `mapstructure:"mailgunemail"`
MailgunSender string `mapstructure:"mailgunsender"`
// Storage configuration
Storage storage.Config `mapstructure:"storage"`
}

View File

@@ -37,7 +37,7 @@ func (model *Model) CreateUser(user *types.User) error {
return errors.New("email required")
}
re := regexp.MustCompile(".+@.+\\..+")
re := regexp.MustCompile(`.+@.+\..+`)
if re.FindString(user.Email) == "" {
return errors.New("invalid email address")
@@ -47,7 +47,7 @@ func (model *Model) CreateUser(user *types.User) error {
return errors.New("password required")
}
if user.AgreeToTerms != true {
if !user.AgreeToTerms {
return errors.New("must agree to terms")
}
@@ -123,7 +123,7 @@ func (model *Model) ResetPassword(email string) error {
if err != nil {
// Don't send back error so people can't try to find user accounts
log.Printf("Invalid email for reset password " + email)
log.Printf("Invalid email for reset password %s", email)
return nil
}
@@ -154,7 +154,7 @@ func (model *Model) ConfirmResetPassword(password string, code string) (*types.U
user, err := model.db.GetUserByResetCode(code)
if err != nil {
return nil, errors.New("Invalid code")
return nil, errors.New("invalid code")
}
passwordHash, err := model.bcrypt.GenerateFromPassword([]byte(password), model.bcrypt.GetDefaultCost())

View File

@@ -2,12 +2,13 @@ package model
import (
"errors"
"testing"
"time"
"github.com/openaccounting/oa-server/core/mocks"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type TdUser struct {
@@ -39,33 +40,35 @@ func TestCreateUser(t *testing.T) {
// EmailVerifyCode string `json:"-"`
user := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John",
"Doe",
"johndoe@email.com",
"password",
"",
true,
"",
false,
"",
Id: "0",
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
FirstName: "John",
LastName: "Doe",
Email: "johndoe@email.com",
Password: "password",
PasswordHash: "",
AgreeToTerms: true,
PasswordReset: "",
EmailVerified: false,
EmailVerifyCode: "",
SignupSource: "",
}
badUser := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John",
"Doe",
"",
"password",
"",
true,
"",
false,
"",
Id: "0",
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
FirstName: "John",
LastName: "Doe",
Email: "",
Password: "password",
PasswordHash: "",
AgreeToTerms: true,
PasswordReset: "",
EmailVerified: false,
EmailVerifyCode: "",
SignupSource: "",
}
tests := map[string]struct {
@@ -109,33 +112,35 @@ func TestCreateUser(t *testing.T) {
func TestUpdateUser(t *testing.T) {
user := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John2",
"Doe",
"johndoe@email.com",
"password",
"",
true,
"",
false,
"",
Id: "0",
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
FirstName: "John2",
LastName: "Doe",
Email: "johndoe@email.com",
Password: "password",
PasswordHash: "",
AgreeToTerms: true,
PasswordReset: "",
EmailVerified: false,
EmailVerifyCode: "",
SignupSource: "",
}
badUser := types.User{
"0",
time.Unix(0, 0),
time.Unix(0, 0),
"John2",
"Doe",
"johndoe@email.com",
"",
"",
true,
"",
false,
"",
Id: "0",
Inserted: time.Unix(0, 0),
Updated: time.Unix(0, 0),
FirstName: "John2",
LastName: "Doe",
Email: "johndoe@email.com",
Password: "",
PasswordHash: "",
AgreeToTerms: true,
PasswordReset: "",
EmailVerified: false,
EmailVerifyCode: "",
SignupSource: "",
}
tests := map[string]struct {

View File

@@ -0,0 +1,375 @@
package repository
import (
"errors"
"strings"
"time"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/models"
"gorm.io/gorm"
)
// GormRepository implements the same interfaces as core/model/db but uses GORM
type GormRepository struct {
db *gorm.DB
}
// Note: GormRepository implements most of the Datastore interface
// Some methods like DeleteAndInsertTransaction need to be added for full compatibility
// NewGormRepository creates a new GORM repository
func NewGormRepository(db *gorm.DB) *GormRepository {
return &GormRepository{db: db}
}
// UserInterface implementation
func (r *GormRepository) InsertUser(user *types.User) error {
user.Inserted = time.Now()
user.Updated = user.Inserted
user.PasswordReset = ""
// Convert types.User to models.User
gormUser := &models.User{
ID: []byte(user.Id), // Convert string ID to []byte
Inserted: uint64(util.TimeToMs(user.Inserted)),
Updated: uint64(util.TimeToMs(user.Updated)),
FirstName: user.FirstName,
LastName: user.LastName,
Email: user.Email,
PasswordHash: user.PasswordHash,
AgreeToTerms: user.AgreeToTerms,
PasswordReset: user.PasswordReset,
EmailVerified: user.EmailVerified,
EmailVerifyCode: user.EmailVerifyCode,
SignupSource: user.SignupSource,
}
result := r.db.Create(gormUser)
if result.Error != nil {
return result.Error
}
if result.RowsAffected < 1 {
return errors.New("unable to insert user into db")
}
return nil
}
func (r *GormRepository) VerifyUser(code string) error {
result := r.db.Model(&models.User{}).
Where("email_verify_code = ?", code).
Updates(map[string]interface{}{
"updated": util.TimeToMs(time.Now()),
"email_verified": true,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errors.New("invalid code")
}
return nil
}
func (r *GormRepository) UpdateUser(user *types.User) error {
user.Updated = time.Now()
result := r.db.Model(&models.User{}).
Where("id = ?", []byte(user.Id)).
Updates(map[string]interface{}{
"updated": util.TimeToMs(user.Updated),
"password_hash": user.PasswordHash,
"password_reset": "",
})
return result.Error
}
func (r *GormRepository) UpdateUserResetPassword(user *types.User) error {
user.Updated = time.Now()
result := r.db.Model(&models.User{}).
Where("id = ?", []byte(user.Id)).
Updates(map[string]interface{}{
"updated": util.TimeToMs(user.Updated),
"password_reset": user.PasswordReset,
})
return result.Error
}
func (r *GormRepository) GetVerifiedUserByEmail(email string) (*types.User, error) {
var gormUser models.User
result := r.db.Where("email = ? AND email_verified = ?",
strings.TrimSpace(strings.ToLower(email)), true).
First(&gormUser)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormUserToTypesUser(&gormUser), nil
}
func (r *GormRepository) GetUserByActiveSession(sessionId string) (*types.User, error) {
var gormUser models.User
result := r.db.Table("users").
Select("users.*").
Joins("JOIN sessions ON sessions.user_id = users.id").
Where("sessions.terminated IS NULL AND sessions.id = ?", []byte(sessionId)).
First(&gormUser)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormUserToTypesUser(&gormUser), nil
}
func (r *GormRepository) GetUserByApiKey(keyId string) (*types.User, error) {
var gormUser models.User
result := r.db.Table("users").
Select("users.*").
Joins("JOIN api_keys ON api_keys.user_id = users.id").
Where("api_keys.deleted_at IS NULL AND api_keys.id = ?", []byte(keyId)).
First(&gormUser)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormUserToTypesUser(&gormUser), nil
}
func (r *GormRepository) GetUserByResetCode(code string) (*types.User, error) {
var gormUser models.User
result := r.db.Where("password_reset = ?", code).First(&gormUser)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormUserToTypesUser(&gormUser), nil
}
func (r *GormRepository) GetUserByEmailVerifyCode(code string) (*types.User, error) {
// only allow this for 3 days
minInserted := (time.Now().UnixNano() / 1000000) - (3 * 24 * 60 * 60 * 1000)
var gormUser models.User
result := r.db.Where("email_verify_code = ? AND inserted > ?", code, minInserted).
First(&gormUser)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormUserToTypesUser(&gormUser), nil
}
func (r *GormRepository) GetOrgAdmins(orgId string) ([]*types.User, error) {
var gormUsers []models.User
result := r.db.Table("users").
Select("users.*").
Joins("JOIN user_orgs ON user_orgs.user_id = users.id").
Where("user_orgs.admin = ? AND user_orgs.org_id = ?", true, []byte(orgId)).
Find(&gormUsers)
if result.Error != nil {
return nil, result.Error
}
users := make([]*types.User, len(gormUsers))
for i, gormUser := range gormUsers {
users[i] = r.convertGormUserToTypesUser(&gormUser)
}
return users, nil
}
// Helper function to convert GORM User to types.User
func (r *GormRepository) convertGormUserToTypesUser(gormUser *models.User) *types.User {
return &types.User{
Id: string(gormUser.ID),
Inserted: util.MsToTime(int64(gormUser.Inserted)),
Updated: util.MsToTime(int64(gormUser.Updated)),
FirstName: gormUser.FirstName,
LastName: gormUser.LastName,
Email: gormUser.Email,
PasswordHash: gormUser.PasswordHash,
AgreeToTerms: gormUser.AgreeToTerms,
PasswordReset: gormUser.PasswordReset,
EmailVerified: gormUser.EmailVerified,
EmailVerifyCode: gormUser.EmailVerifyCode,
SignupSource: gormUser.SignupSource,
}
}
// AccountInterface implementation
func (r *GormRepository) InsertAccount(account *types.Account) error {
account.Inserted = time.Now()
account.Updated = account.Inserted
// Convert types.Account to models.Account
gormAccount := &models.Account{
ID: []byte(account.Id),
OrgID: []byte(account.OrgId),
Inserted: uint64(util.TimeToMs(account.Inserted)),
Updated: uint64(util.TimeToMs(account.Updated)),
Name: account.Name,
Parent: []byte(account.Parent),
Currency: account.Currency,
Precision: account.Precision,
DebitBalance: account.DebitBalance,
}
return r.db.Create(gormAccount).Error
}
func (r *GormRepository) UpdateAccount(account *types.Account) error {
account.Updated = time.Now()
result := r.db.Model(&models.Account{}).
Where("id = ?", []byte(account.Id)).
Updates(map[string]interface{}{
"updated": util.TimeToMs(account.Updated),
"name": account.Name,
"parent": []byte(account.Parent),
"currency": account.Currency,
"precision": account.Precision,
"debit_balance": account.DebitBalance,
})
return result.Error
}
func (r *GormRepository) GetAccount(id string) (*types.Account, error) {
var gormAccount models.Account
result := r.db.Where("id = ?", []byte(id)).First(&gormAccount)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormAccountToTypesAccount(&gormAccount), nil
}
func (r *GormRepository) GetAccountsByOrgId(orgId string) ([]*types.Account, error) {
var gormAccounts []models.Account
result := r.db.Where("org_id = ?", []byte(orgId)).Find(&gormAccounts)
if result.Error != nil {
return nil, result.Error
}
accounts := make([]*types.Account, len(gormAccounts))
for i, gormAccount := range gormAccounts {
accounts[i] = r.convertGormAccountToTypesAccount(&gormAccount)
}
return accounts, nil
}
func (r *GormRepository) GetPermissionedAccountIds(orgId, userId, tokenId string) ([]string, error) {
var accountIds []string
result := r.db.Table("permissions").
Select("DISTINCT LOWER(HEX(account_id)) as account_id").
Where("org_id = ? AND user_id = ?", []byte(orgId), []byte(userId)).
Pluck("account_id", &accountIds)
return accountIds, result.Error
}
func (r *GormRepository) GetSplitCountByAccountId(id string) (int64, error) {
var count int64
result := r.db.Model(&models.Split{}).
Where("account_id = ?", []byte(id)).
Count(&count)
return count, result.Error
}
func (r *GormRepository) GetChildCountByAccountId(id string) (int64, error) {
var count int64
result := r.db.Model(&models.Account{}).
Where("parent = ?", []byte(id)).
Count(&count)
return count, result.Error
}
func (r *GormRepository) DeleteAccount(id string) error {
return r.db.Where("id = ?", []byte(id)).Delete(&models.Account{}).Error
}
func (r *GormRepository) GetRootAccount(orgId string) (*types.Account, error) {
var gormAccount models.Account
result := r.db.Where("org_id = ? AND parent = ?", []byte(orgId), []byte{0}).
First(&gormAccount)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormAccountToTypesAccount(&gormAccount), nil
}
// Balance-related methods (simplified implementations)
func (r *GormRepository) AddBalances(accounts []*types.Account, date time.Time) error {
// Implementation would need to be completed based on your balance calculation logic
return nil
}
func (r *GormRepository) AddNativeBalancesCost(accounts []*types.Account, date time.Time) error {
// Implementation would need to be completed based on your balance calculation logic
return nil
}
func (r *GormRepository) AddNativeBalancesNearestInTime(accounts []*types.Account, date time.Time) error {
// Implementation would need to be completed based on your balance calculation logic
return nil
}
func (r *GormRepository) AddBalance(account *types.Account, date time.Time) error {
// Implementation would need to be completed based on your balance calculation logic
return nil
}
func (r *GormRepository) AddNativeBalanceCost(account *types.Account, date time.Time) error {
// Implementation would need to be completed based on your balance calculation logic
return nil
}
func (r *GormRepository) AddNativeBalanceNearestInTime(account *types.Account, date time.Time) error {
// Implementation would need to be completed based on your balance calculation logic
return nil
}
// Helper function to convert GORM Account to types.Account
func (r *GormRepository) convertGormAccountToTypesAccount(gormAccount *models.Account) *types.Account {
return &types.Account{
Id: string(gormAccount.ID),
OrgId: string(gormAccount.OrgID),
Inserted: util.MsToTime(int64(gormAccount.Inserted)),
Updated: util.MsToTime(int64(gormAccount.Updated)),
Name: gormAccount.Name,
Parent: string(gormAccount.Parent),
Currency: gormAccount.Currency,
Precision: gormAccount.Precision,
DebitBalance: gormAccount.DebitBalance,
// Balance fields would be populated by the AddBalance methods
}
}
// Escape method for SQL injection protection (GORM handles this automatically)
func (r *GormRepository) Escape(sql string) string {
// GORM handles SQL injection protection automatically
// This method is kept for interface compatibility
return sql
}

View File

@@ -0,0 +1,462 @@
package repository
import (
"time"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/models"
"gorm.io/gorm"
)
// OrgInterface implementation
func (r *GormRepository) CreateOrg(org *types.Org, userId string, accounts []*types.Account) error {
return r.db.Transaction(func(tx *gorm.DB) error {
org.Inserted = time.Now()
org.Updated = org.Inserted
// Create org
gormOrg := &models.Org{
ID: []byte(org.Id),
Inserted: uint64(util.TimeToMs(org.Inserted)),
Updated: uint64(util.TimeToMs(org.Updated)),
Name: org.Name,
Currency: org.Currency,
Precision: org.Precision,
Timezone: org.Timezone,
}
if err := tx.Create(gormOrg).Error; err != nil {
return err
}
// Create accounts
for _, account := range accounts {
gormAccount := &models.Account{
ID: []byte(account.Id),
OrgID: []byte(account.OrgId),
Inserted: uint64(util.TimeToMs(time.Now())),
Updated: uint64(util.TimeToMs(time.Now())),
Name: account.Name,
Parent: []byte(account.Parent),
Currency: account.Currency,
Precision: account.Precision,
DebitBalance: account.DebitBalance,
}
if err := tx.Create(gormAccount).Error; err != nil {
return err
}
}
// Create userorg association
userOrg := &models.UserOrg{
UserID: []byte(userId),
OrgID: []byte(org.Id),
Admin: true,
}
return tx.Create(userOrg).Error
})
}
func (r *GormRepository) UpdateOrg(org *types.Org) error {
org.Updated = time.Now()
return r.db.Model(&models.Org{}).
Where("id = ?", []byte(org.Id)).
Updates(map[string]interface{}{
"updated": util.TimeToMs(org.Updated),
"name": org.Name,
"currency": org.Currency,
"precision": org.Precision,
"timezone": org.Timezone,
}).Error
}
func (r *GormRepository) GetOrg(orgId, userId string) (*types.Org, error) {
var gormOrg models.Org
result := r.db.Table("orgs").
Select("orgs.*").
Joins("JOIN user_orgs ON user_orgs.org_id = orgs.id").
Where("orgs.id = ? AND user_orgs.user_id = ?", []byte(orgId), []byte(userId)).
First(&gormOrg)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormOrgToTypesOrg(&gormOrg), nil
}
func (r *GormRepository) GetOrgs(userId string) ([]*types.Org, error) {
var gormOrgs []models.Org
result := r.db.Table("orgs").
Select("orgs.*").
Joins("JOIN user_orgs ON user_orgs.org_id = orgs.id").
Where("user_orgs.user_id = ?", []byte(userId)).
Find(&gormOrgs)
if result.Error != nil {
return nil, result.Error
}
orgs := make([]*types.Org, len(gormOrgs))
for i, gormOrg := range gormOrgs {
orgs[i] = r.convertGormOrgToTypesOrg(&gormOrg)
}
return orgs, nil
}
func (r *GormRepository) GetOrgUserIds(orgId string) ([]string, error) {
var userIds []string
result := r.db.Table("user_orgs").
Select("LOWER(HEX(user_id)) as user_id").
Where("org_id = ?", []byte(orgId)).
Pluck("user_id", &userIds)
return userIds, result.Error
}
func (r *GormRepository) InsertInvite(invite *types.Invite) error {
invite.Inserted = time.Now()
invite.Updated = invite.Inserted
gormInvite := &models.Invite{
ID: invite.Id,
OrgID: []byte(invite.OrgId),
Inserted: uint64(util.TimeToMs(invite.Inserted)),
Updated: uint64(util.TimeToMs(invite.Updated)),
Email: invite.Email,
Accepted: invite.Accepted,
}
return r.db.Create(gormInvite).Error
}
func (r *GormRepository) AcceptInvite(invite *types.Invite, userId string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// Update invite
if err := tx.Model(&models.Invite{}).
Where("id = ?", []byte(invite.Id)).
Update("accepted", true).Error; err != nil {
return err
}
// Create userorg association
userOrg := &models.UserOrg{
UserID: []byte(userId),
OrgID: []byte(invite.OrgId),
Admin: false,
}
return tx.Create(userOrg).Error
})
}
func (r *GormRepository) GetInvites(orgId string) ([]*types.Invite, error) {
var gormInvites []models.Invite
result := r.db.Where("org_id = ?", []byte(orgId)).Find(&gormInvites)
if result.Error != nil {
return nil, result.Error
}
invites := make([]*types.Invite, len(gormInvites))
for i, gormInvite := range gormInvites {
invites[i] = r.convertGormInviteToTypesInvite(&gormInvite)
}
return invites, nil
}
func (r *GormRepository) GetInvite(id string) (*types.Invite, error) {
var gormInvite models.Invite
result := r.db.Where("id = ?", id).First(&gormInvite)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormInviteToTypesInvite(&gormInvite), nil
}
func (r *GormRepository) DeleteInvite(id string) error {
return r.db.Where("id = ?", id).Delete(&models.Invite{}).Error
}
// SessionInterface implementation (basic stubs)
func (r *GormRepository) InsertSession(session *types.Session) error {
var terminated uint64
if session.Terminated.Valid {
terminated = uint64(util.TimeToMs(session.Terminated.Time))
}
gormSession := &models.Session{
ID: []byte(session.Id),
UserID: []byte(session.UserId),
Inserted: uint64(util.TimeToMs(time.Now())),
Updated: uint64(util.TimeToMs(time.Now())),
Terminated: terminated,
}
return r.db.Create(gormSession).Error
}
func (r *GormRepository) TerminateSession(sessionId string) error {
return r.db.Model(&models.Session{}).
Where("id = ?", []byte(sessionId)).
Update("terminated", uint64(util.TimeToMs(time.Now()))).Error
}
func (r *GormRepository) DeleteSession(sessionId, userId string) error {
return r.db.Where("id = ? AND user_id = ?", []byte(sessionId), []byte(userId)).
Delete(&models.Session{}).Error
}
func (r *GormRepository) UpdateSessionActivity(sessionId string) error {
return r.db.Model(&models.Session{}).
Where("id = ?", []byte(sessionId)).
Update("updated", uint64(util.TimeToMs(time.Now()))).Error
}
// APIKey interface (basic stubs)
func (r *GormRepository) InsertApiKey(apiKey *types.ApiKey) error {
gormApiKey := &models.APIKey{
ID: []byte(apiKey.Id),
UserID: []byte(apiKey.UserId),
Inserted: uint64(util.TimeToMs(time.Now())),
Updated: uint64(util.TimeToMs(time.Now())),
Label: apiKey.Label,
}
return r.db.Create(gormApiKey).Error
}
func (r *GormRepository) DeleteApiKey(keyId, userId string) error {
return r.db.Where("id = ? AND user_id = ?", []byte(keyId), []byte(userId)).
Delete(&models.APIKey{}).Error
}
func (r *GormRepository) UpdateApiKey(apiKey *types.ApiKey) error {
return r.db.Model(&models.APIKey{}).
Where("id = ?", []byte(apiKey.Id)).
Updates(map[string]interface{}{
"updated": uint64(util.TimeToMs(time.Now())),
"label": apiKey.Label,
}).Error
}
func (r *GormRepository) GetApiKeys(userId string) ([]*types.ApiKey, error) {
var gormApiKeys []models.APIKey
result := r.db.Where("user_id = ? AND deleted_at IS NULL", []byte(userId)).
Find(&gormApiKeys)
if result.Error != nil {
return nil, result.Error
}
apiKeys := make([]*types.ApiKey, len(gormApiKeys))
for i, gormApiKey := range gormApiKeys {
apiKeys[i] = r.convertGormApiKeyToTypesApiKey(&gormApiKey)
}
return apiKeys, nil
}
func (r *GormRepository) UpdateApiKeyActivity(keyId string) error {
return r.db.Model(&models.APIKey{}).
Where("id = ?", []byte(keyId)).
Update("updated", uint64(util.TimeToMs(time.Now()))).Error
}
func (r *GormRepository) convertGormApiKeyToTypesApiKey(gormApiKey *models.APIKey) *types.ApiKey {
return &types.ApiKey{
Id: string(gormApiKey.ID),
Inserted: util.MsToTime(int64(gormApiKey.Inserted)),
Updated: util.MsToTime(int64(gormApiKey.Updated)),
UserId: string(gormApiKey.UserID),
Label: gormApiKey.Label,
}
}
// TransactionInterface implementation
func (r *GormRepository) InsertTransaction(transaction *types.Transaction) error {
gormTransaction := &models.Transaction{
ID: []byte(transaction.Id),
OrgID: []byte(transaction.OrgId),
UserID: []byte(transaction.UserId),
Date: uint64(transaction.Date.Unix()),
Inserted: uint64(util.TimeToMs(time.Now())),
Updated: uint64(util.TimeToMs(time.Now())),
Description: transaction.Description,
Data: transaction.Data,
}
return r.db.Create(gormTransaction).Error
}
func (r *GormRepository) GetTransactionById(id string) (*types.Transaction, error) {
var gormTransaction models.Transaction
result := r.db.Where("id = ?", []byte(id)).First(&gormTransaction)
if result.Error != nil {
return nil, result.Error
}
return r.convertGormTransactionToTypesTransaction(&gormTransaction), nil
}
func (r *GormRepository) GetTransactionsByAccount(accountId string, options *types.QueryOptions) ([]*types.Transaction, error) {
var gormTransactions []models.Transaction
query := r.db.Table("transactions").
Joins("JOIN splits ON splits.transaction_id = transactions.id").
Where("splits.account_id = ?", []byte(accountId))
if options != nil {
// Apply query options like limit, skip, date range, etc.
if options.Limit > 0 {
query = query.Limit(int(options.Limit))
}
if options.Skip > 0 {
query = query.Offset(int(options.Skip))
}
}
result := query.Find(&gormTransactions)
if result.Error != nil {
return nil, result.Error
}
transactions := make([]*types.Transaction, len(gormTransactions))
for i, gormTx := range gormTransactions {
transactions[i] = r.convertGormTransactionToTypesTransaction(&gormTx)
}
return transactions, nil
}
func (r *GormRepository) GetTransactionsByOrg(orgId string, options *types.QueryOptions, accountIds []string) ([]*types.Transaction, error) {
var gormTransactions []models.Transaction
query := r.db.Where("org_id = ?", []byte(orgId))
if len(accountIds) > 0 {
// Convert string IDs to byte arrays
byteAccountIds := make([][]byte, len(accountIds))
for i, id := range accountIds {
byteAccountIds[i] = []byte(id)
}
query = query.Where("id IN (SELECT DISTINCT transaction_id FROM splits WHERE account_id IN ?)", byteAccountIds)
}
if options != nil {
if options.Limit > 0 {
query = query.Limit(int(options.Limit))
}
if options.Skip > 0 {
query = query.Offset(int(options.Skip))
}
}
result := query.Find(&gormTransactions)
if result.Error != nil {
return nil, result.Error
}
transactions := make([]*types.Transaction, len(gormTransactions))
for i, gormTx := range gormTransactions {
transactions[i] = r.convertGormTransactionToTypesTransaction(&gormTx)
}
return transactions, nil
}
func (r *GormRepository) DeleteTransaction(id string) error {
return r.db.Where("id = ?", []byte(id)).Delete(&models.Transaction{}).Error
}
func (r *GormRepository) DeleteAndInsertTransaction(id string, transaction *types.Transaction) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// Delete the old transaction
if err := tx.Where("id = ?", []byte(id)).Delete(&models.Transaction{}).Error; err != nil {
return err
}
// Insert the new transaction
gormTransaction := &models.Transaction{
ID: []byte(transaction.Id),
OrgID: []byte(transaction.OrgId),
UserID: []byte(transaction.UserId),
Date: uint64(transaction.Date.Unix()),
Inserted: uint64(util.TimeToMs(time.Now())),
Updated: uint64(util.TimeToMs(time.Now())),
Description: transaction.Description,
Data: transaction.Data,
}
return tx.Create(gormTransaction).Error
})
}
func (r *GormRepository) convertGormTransactionToTypesTransaction(gormTx *models.Transaction) *types.Transaction {
return &types.Transaction{
Id: string(gormTx.ID),
OrgId: string(gormTx.OrgID),
UserId: string(gormTx.UserID),
Date: time.Unix(int64(gormTx.Date), 0),
Inserted: util.MsToTime(int64(gormTx.Inserted)),
Updated: util.MsToTime(int64(gormTx.Updated)),
Description: gormTx.Description,
Data: gormTx.Data,
Deleted: gormTx.Deleted,
}
}
// Helper conversion functions
func (r *GormRepository) convertGormOrgToTypesOrg(gormOrg *models.Org) *types.Org {
return &types.Org{
Id: string(gormOrg.ID),
Inserted: util.MsToTime(int64(gormOrg.Inserted)),
Updated: util.MsToTime(int64(gormOrg.Updated)),
Name: gormOrg.Name,
Currency: gormOrg.Currency,
Precision: gormOrg.Precision,
Timezone: gormOrg.Timezone,
}
}
func (r *GormRepository) convertGormInviteToTypesInvite(gormInvite *models.Invite) *types.Invite {
return &types.Invite{
Id: string(gormInvite.ID),
OrgId: string(gormInvite.OrgID),
Inserted: util.MsToTime(int64(gormInvite.Inserted)),
Updated: util.MsToTime(int64(gormInvite.Updated)),
Email: gormInvite.Email,
Accepted: gormInvite.Accepted,
}
}
// Stub implementations for remaining interfaces that aren't fully used yet
func (r *GormRepository) GetPrices(orgId string, date time.Time) ([]*types.Price, error) {
// Stub implementation - add as needed
return nil, nil
}
func (r *GormRepository) InsertPrice(price *types.Price) error {
// Stub implementation - add as needed
return nil
}
func (r *GormRepository) Ping() error {
// Check if the database connection is alive
sqlDB, err := r.db.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}
func (r *GormRepository) InsertBudget(budget *types.Budget) error {
// Stub implementation - add as needed
return nil
}
func (r *GormRepository) GetBudgets(orgId string) ([]*types.Budget, error) {
// Stub implementation - add as needed
return nil, nil
}

View File

@@ -1,48 +1,145 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strconv"
"strings"
"github.com/openaccounting/oa-server/core/api"
"github.com/openaccounting/oa-server/core/auth"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/repository"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/database"
"github.com/spf13/viper"
)
func main() {
//filename is the path to the json config file
// Initialize Viper configuration
var config types.Config
file, err := os.Open("./config.json")
// Set config file properties
viper.SetConfigName("config")
viper.SetConfigType("json")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/openaccounting/")
viper.AddConfigPath("$HOME/.openaccounting")
// Enable environment variables
viper.AutomaticEnv()
viper.SetEnvPrefix("OA") // will look for OA_DATABASE_PASSWORD, etc.
// Configure Viper to handle nested config with environment variables
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Bind specific storage environment variables for better support
// Using mapstructure field names (snake_case)
viper.BindEnv("Storage.backend", "OA_STORAGE_BACKEND")
viper.BindEnv("Storage.local.root_dir", "OA_STORAGE_LOCAL_ROOTDIR")
viper.BindEnv("Storage.local.base_url", "OA_STORAGE_LOCAL_BASEURL")
viper.BindEnv("Storage.s3.region", "OA_STORAGE_S3_REGION")
viper.BindEnv("Storage.s3.bucket", "OA_STORAGE_S3_BUCKET")
viper.BindEnv("Storage.s3.prefix", "OA_STORAGE_S3_PREFIX")
viper.BindEnv("Storage.s3.access_key_id", "OA_STORAGE_S3_ACCESSKEYID")
viper.BindEnv("Storage.s3.secret_access_key", "OA_STORAGE_S3_SECRETACCESSKEY")
viper.BindEnv("Storage.s3.endpoint", "OA_STORAGE_S3_ENDPOINT")
viper.BindEnv("Storage.s3.path_style", "OA_STORAGE_S3_PATHSTYLE")
// Set default values
viper.SetDefault("Address", "localhost")
viper.SetDefault("Port", 8080)
viper.SetDefault("DatabaseDriver", "sqlite")
viper.SetDefault("DatabaseFile", "./openaccounting.db")
viper.SetDefault("ApiPrefix", "/api/v1")
// Set storage defaults (using mapstructure field names)
viper.SetDefault("Storage.backend", "local")
viper.SetDefault("Storage.local.root_dir", "./uploads")
viper.SetDefault("Storage.local.base_url", "")
// Read configuration
err := viper.ReadInConfig()
if err != nil {
log.Fatal(fmt.Errorf("failed to open ./config.json with: %s", err.Error()))
log.Printf("Warning: Could not read config file: %v", err)
log.Println("Using environment variables and defaults")
}
// Unmarshal config into struct
err = viper.Unmarshal(&config)
if err != nil {
log.Fatal(fmt.Errorf("failed to unmarshal config: %s", err.Error()))
}
// Set storage defaults if not configured (Viper doesn't handle nested defaults well)
if config.Storage.Backend == "" {
config.Storage.Backend = "local"
}
if config.Storage.Local.RootDir == "" {
config.Storage.Local.RootDir = "./uploads"
}
decoder := json.NewDecoder(file)
err = decoder.Decode(&config)
if err != nil {
log.Fatal(fmt.Errorf("failed to decode ./config.json with: %s", err.Error()))
// Parse database address (assuming format host:port for MySQL)
host := config.DatabaseAddress
port := "3306"
if len(config.DatabaseAddress) > 0 {
// If there's a colon, split host and port
if colonIndex := len(config.DatabaseAddress); colonIndex > 0 {
host = config.DatabaseAddress
}
}
connectionString := config.User + ":" + config.Password + "@" + config.DatabaseAddress + "/" + config.Database
// Default to SQLite if no driver specified
driver := config.DatabaseDriver
if driver == "" {
driver = "sqlite"
}
db, err := db.NewDB(connectionString)
// Initialize GORM database
dbConfig := &database.Config{
Driver: driver,
Host: host,
Port: port,
User: config.User,
Password: config.Password,
DBName: config.Database,
File: config.DatabaseFile,
SSLMode: "disable", // Adjust as needed
}
err = database.Connect(dbConfig)
if err != nil {
log.Fatal(fmt.Errorf("failed to connect to database with: %s", err.Error()))
}
// Run migrations
err = database.AutoMigrate()
if err != nil {
log.Fatal(fmt.Errorf("failed to run migrations: %s", err.Error()))
}
err = database.Migrate()
if err != nil {
log.Fatal(fmt.Errorf("failed to run custom migrations: %s", err.Error()))
}
bc := &util.StandardBcrypt{}
model.NewModel(db, bc, config)
auth.NewAuthService(db, bc)
// Create GORM repository and models
gormRepo := repository.NewGormRepository(database.DB)
gormModel := model.NewGormModel(database.DB, bc, config)
auth.NewGormAuthService(gormRepo, bc)
// Set the global model instance
model.Instance = gormModel
// Initialize storage backend for attachments
err = api.InitializeAttachmentHandler(config.Storage)
if err != nil {
log.Fatal(fmt.Errorf("failed to initialize storage backend: %s", err.Error()))
}
app, err := api.Init(config.ApiPrefix)
if err != nil {

106
core/storage/interface.go Normal file
View File

@@ -0,0 +1,106 @@
package storage
import (
"io"
"time"
)
// Storage defines the interface for file storage backends
type Storage interface {
// Store saves a file and returns the storage path/key
Store(filename string, content io.Reader, contentType string) (string, error)
// Retrieve gets a file by its storage path/key
Retrieve(path string) (io.ReadCloser, error)
// Delete removes a file by its storage path/key
Delete(path string) error
// GetURL returns a URL for accessing the file (may be signed/temporary)
GetURL(path string, expiry time.Duration) (string, error)
// Exists checks if a file exists at the given path
Exists(path string) (bool, error)
// GetMetadata returns file metadata (size, last modified, etc.)
GetMetadata(path string) (*FileMetadata, error)
}
// FileMetadata contains information about a stored file
type FileMetadata struct {
Size int64
LastModified time.Time
ContentType string
ETag string
}
// Config holds configuration for storage backends
type Config struct {
// Storage backend type: "local", "s3"
Backend string `mapstructure:"backend"`
// Local filesystem configuration
Local LocalConfig `mapstructure:"local"`
// S3-compatible storage configuration (S3, B2, R2, etc.)
S3 S3Config `mapstructure:"s3"`
}
// LocalConfig configures local filesystem storage
type LocalConfig struct {
// Root directory for file storage
RootDir string `mapstructure:"root_dir"`
// Base URL for serving files (optional)
BaseURL string `mapstructure:"base_url"`
}
// S3Config configures S3-compatible storage (AWS S3, Backblaze B2, Cloudflare R2, etc.)
type S3Config struct {
// AWS Region (use "auto" for Cloudflare R2)
Region string `mapstructure:"region"`
// S3 Bucket name
Bucket string `mapstructure:"bucket"`
// Optional prefix for all objects
Prefix string `mapstructure:"prefix"`
// Access Key ID
AccessKeyID string `mapstructure:"access_key_id"`
// Secret Access Key
SecretAccessKey string `mapstructure:"secret_access_key"`
// Custom endpoint URL for S3-compatible services:
// - Backblaze B2: https://s3.us-west-004.backblazeb2.com
// - Cloudflare R2: https://<account-id>.r2.cloudflarestorage.com
// - MinIO: http://localhost:9000
// Leave empty for AWS S3
Endpoint string `mapstructure:"endpoint"`
// Use path-style addressing (required for some S3-compatible services)
PathStyle bool `mapstructure:"path_style"`
}
// NewStorage creates a new storage backend based on configuration
func NewStorage(config Config) (Storage, error) {
switch config.Backend {
case "local", "":
return NewLocalStorage(config.Local)
case "s3":
return NewS3Storage(config.S3)
default:
return nil, &UnsupportedBackendError{Backend: config.Backend}
}
}
// UnsupportedBackendError is returned when an unknown storage backend is requested
type UnsupportedBackendError struct {
Backend string
}
func (e *UnsupportedBackendError) Error() string {
return "unsupported storage backend: " + e.Backend
}

View File

@@ -0,0 +1,101 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewStorage(t *testing.T) {
t.Run("Local Storage", func(t *testing.T) {
config := Config{
Backend: "local",
Local: LocalConfig{
RootDir: t.TempDir(),
},
}
storage, err := NewStorage(config)
assert.NoError(t, err)
assert.IsType(t, &LocalStorage{}, storage)
})
t.Run("Default to Local Storage", func(t *testing.T) {
config := Config{
// No backend specified
Local: LocalConfig{
RootDir: t.TempDir(),
},
}
storage, err := NewStorage(config)
assert.NoError(t, err)
assert.IsType(t, &LocalStorage{}, storage)
})
t.Run("S3 Storage", func(t *testing.T) {
config := Config{
Backend: "s3",
S3: S3Config{
Region: "us-east-1",
Bucket: "test-bucket",
},
}
// This might succeed if AWS credentials are available via IAM roles or env vars
// Let's just check that we get an S3Storage instance or an error
storage, err := NewStorage(config)
if err != nil {
// If it fails, that's expected in test environments without AWS access
assert.Nil(t, storage)
} else {
// If it succeeds, we should get an S3Storage instance
assert.IsType(t, &S3Storage{}, storage)
}
})
t.Run("B2 Storage", func(t *testing.T) {
config := Config{
Backend: "b2",
B2: B2Config{
AccountID: "test-account",
ApplicationKey: "test-key",
Bucket: "test-bucket",
},
}
// This will fail because we don't have real B2 credentials
storage, err := NewStorage(config)
assert.Error(t, err) // Expected to fail without credentials
assert.Nil(t, storage)
})
t.Run("Unsupported Backend", func(t *testing.T) {
config := Config{
Backend: "unsupported",
}
storage, err := NewStorage(config)
assert.Error(t, err)
assert.IsType(t, &UnsupportedBackendError{}, err)
assert.Nil(t, storage)
assert.Contains(t, err.Error(), "unsupported")
})
}
func TestStorageErrors(t *testing.T) {
t.Run("UnsupportedBackendError", func(t *testing.T) {
err := &UnsupportedBackendError{Backend: "ftp"}
assert.Equal(t, "unsupported storage backend: ftp", err.Error())
})
t.Run("FileNotFoundError", func(t *testing.T) {
err := &FileNotFoundError{Path: "missing.txt"}
assert.Equal(t, "file not found: missing.txt", err.Error())
})
t.Run("InvalidPathError", func(t *testing.T) {
err := &InvalidPathError{Path: "../../../etc/passwd"}
assert.Equal(t, "invalid path: ../../../etc/passwd", err.Error())
})
}

243
core/storage/local.go Normal file
View File

@@ -0,0 +1,243 @@
package storage
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/openaccounting/oa-server/core/util/id"
)
// LocalStorage implements the Storage interface for local filesystem
type LocalStorage struct {
rootDir string
baseURL string
}
// NewLocalStorage creates a new local filesystem storage backend
func NewLocalStorage(config LocalConfig) (*LocalStorage, error) {
rootDir := config.RootDir
if rootDir == "" {
rootDir = "./uploads"
}
// Ensure the root directory exists
if err := os.MkdirAll(rootDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create storage directory: %w", err)
}
return &LocalStorage{
rootDir: rootDir,
baseURL: config.BaseURL,
}, nil
}
// Store saves a file to the local filesystem
func (l *LocalStorage) Store(filename string, content io.Reader, contentType string) (string, error) {
// Generate a unique storage path
storagePath := l.generateStoragePath(filename)
fullPath := filepath.Join(l.rootDir, storagePath)
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}
// Create and write the file
file, err := os.Create(fullPath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
_, err = io.Copy(file, content)
if err != nil {
// Clean up the file if write failed
os.Remove(fullPath)
return "", fmt.Errorf("failed to write file: %w", err)
}
return storagePath, nil
}
// Retrieve gets a file from the local filesystem
func (l *LocalStorage) Retrieve(path string) (io.ReadCloser, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return nil, err
}
fullPath := filepath.Join(l.rootDir, path)
file, err := os.Open(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, &FileNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to open file: %w", err)
}
return file, nil
}
// Delete removes a file from the local filesystem
func (l *LocalStorage) Delete(path string) error {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return err
}
fullPath := filepath.Join(l.rootDir, path)
err := os.Remove(fullPath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to delete file: %w", err)
}
// Try to remove empty parent directories
l.cleanupEmptyDirs(filepath.Dir(fullPath))
return nil
}
// GetURL returns a URL for accessing the file
func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return "", err
}
// Check if file exists
exists, err := l.Exists(path)
if err != nil {
return "", err
}
if !exists {
return "", &FileNotFoundError{Path: path}
}
if l.baseURL != "" {
// Return a public URL if base URL is configured
return l.baseURL + "/" + path, nil
}
// For local storage without a base URL, return the file path
// In a real application, you might serve these through an endpoint
return "/files/" + path, nil
}
// Exists checks if a file exists at the given path
func (l *LocalStorage) Exists(path string) (bool, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return false, err
}
fullPath := filepath.Join(l.rootDir, path)
_, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check file existence: %w", err)
}
return true, nil
}
// GetMetadata returns file metadata
func (l *LocalStorage) GetMetadata(path string) (*FileMetadata, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return nil, err
}
fullPath := filepath.Join(l.rootDir, path)
info, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, &FileNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to get file metadata: %w", err)
}
return &FileMetadata{
Size: info.Size(),
LastModified: info.ModTime(),
ContentType: "", // Local storage doesn't store content type
ETag: "", // Local storage doesn't have ETags
}, nil
}
// generateStoragePath creates a unique storage path for a file
func (l *LocalStorage) generateStoragePath(filename string) string {
// Generate a unique ID for the file
fileID := id.String(id.New())
// Extract file extension
ext := filepath.Ext(filename)
// Create a path structure: YYYY/MM/DD/uuid.ext
now := time.Now()
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
return filepath.Join(datePath, fileID+ext)
}
// validatePath ensures the path doesn't contain directory traversal attempts
func (l *LocalStorage) validatePath(path string) error {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Reject paths that try to go up directories
if strings.Contains(cleanPath, "..") {
return &InvalidPathError{Path: path}
}
// Reject absolute paths
if filepath.IsAbs(cleanPath) {
return &InvalidPathError{Path: path}
}
return nil
}
// cleanupEmptyDirs removes empty parent directories up to the root
func (l *LocalStorage) cleanupEmptyDirs(dir string) {
// Don't remove the root directory
if dir == l.rootDir {
return
}
// Check if directory is empty
entries, err := os.ReadDir(dir)
if err != nil || len(entries) > 0 {
return
}
// Remove empty directory
if err := os.Remove(dir); err == nil {
// Recursively clean parent directories
l.cleanupEmptyDirs(filepath.Dir(dir))
}
}
// FileNotFoundError is returned when a file doesn't exist
type FileNotFoundError struct {
Path string
}
func (e *FileNotFoundError) Error() string {
return "file not found: " + e.Path
}
// InvalidPathError is returned when a path is invalid or contains traversal attempts
type InvalidPathError struct {
Path string
}
func (e *InvalidPathError) Error() string {
return "invalid path: " + e.Path
}

202
core/storage/local_test.go Normal file
View File

@@ -0,0 +1,202 @@
package storage
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLocalStorage(t *testing.T) {
// Create temporary directory for testing
tmpDir := t.TempDir()
config := LocalConfig{
RootDir: tmpDir,
BaseURL: "http://localhost:8080/files",
}
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.NotNil(t, storage)
t.Run("Store and Retrieve File", func(t *testing.T) {
content := []byte("test file content")
reader := bytes.NewReader(content)
// Store file
path, err := storage.Store("test.txt", reader, "text/plain")
assert.NoError(t, err)
assert.NotEmpty(t, path)
// Verify file exists
exists, err := storage.Exists(path)
assert.NoError(t, err)
assert.True(t, exists)
// Retrieve file
retrievedReader, err := storage.Retrieve(path)
assert.NoError(t, err)
defer retrievedReader.Close()
retrievedContent, err := io.ReadAll(retrievedReader)
assert.NoError(t, err)
assert.Equal(t, content, retrievedContent)
})
t.Run("Get File Metadata", func(t *testing.T) {
content := []byte("metadata test content")
reader := bytes.NewReader(content)
path, err := storage.Store("metadata.txt", reader, "text/plain")
assert.NoError(t, err)
metadata, err := storage.GetMetadata(path)
assert.NoError(t, err)
assert.Equal(t, int64(len(content)), metadata.Size)
assert.False(t, metadata.LastModified.IsZero())
})
t.Run("Get File URL", func(t *testing.T) {
content := []byte("url test content")
reader := bytes.NewReader(content)
path, err := storage.Store("url.txt", reader, "text/plain")
assert.NoError(t, err)
url, err := storage.GetURL(path, time.Hour)
assert.NoError(t, err)
assert.Contains(t, url, path)
assert.Contains(t, url, config.BaseURL)
})
t.Run("Delete File", func(t *testing.T) {
content := []byte("delete test content")
reader := bytes.NewReader(content)
path, err := storage.Store("delete.txt", reader, "text/plain")
assert.NoError(t, err)
// Verify file exists
exists, err := storage.Exists(path)
assert.NoError(t, err)
assert.True(t, exists)
// Delete file
err = storage.Delete(path)
assert.NoError(t, err)
// Verify file no longer exists
exists, err = storage.Exists(path)
assert.NoError(t, err)
assert.False(t, exists)
})
t.Run("Path Validation", func(t *testing.T) {
// Test directory traversal prevention
_, err := storage.Retrieve("../../../etc/passwd")
assert.Error(t, err)
assert.IsType(t, &InvalidPathError{}, err)
// Test absolute path rejection
_, err = storage.Retrieve("/etc/passwd")
assert.Error(t, err)
assert.IsType(t, &InvalidPathError{}, err)
})
t.Run("File Not Found", func(t *testing.T) {
_, err := storage.Retrieve("nonexistent.txt")
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
_, err = storage.GetMetadata("nonexistent.txt")
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
_, err = storage.GetURL("nonexistent.txt", time.Hour)
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
})
t.Run("Storage Path Generation", func(t *testing.T) {
content := []byte("path test content")
reader1 := bytes.NewReader(content)
reader2 := bytes.NewReader(content)
// Store two files with same name
path1, err := storage.Store("same.txt", reader1, "text/plain")
assert.NoError(t, err)
path2, err := storage.Store("same.txt", reader2, "text/plain")
assert.NoError(t, err)
// Paths should be different (unique)
assert.NotEqual(t, path1, path2)
// Both should exist
exists1, err := storage.Exists(path1)
assert.NoError(t, err)
assert.True(t, exists1)
exists2, err := storage.Exists(path2)
assert.NoError(t, err)
assert.True(t, exists2)
// Both should have correct extension
assert.True(t, strings.HasSuffix(path1, ".txt"))
assert.True(t, strings.HasSuffix(path2, ".txt"))
// Should be organized by date
now := time.Now()
expectedPrefix := filepath.Join(
fmt.Sprintf("%04d", now.Year()),
fmt.Sprintf("%02d", now.Month()),
fmt.Sprintf("%02d", now.Day()),
)
assert.True(t, strings.HasPrefix(path1, expectedPrefix))
assert.True(t, strings.HasPrefix(path2, expectedPrefix))
})
}
func TestLocalStorageConfig(t *testing.T) {
t.Run("Default Root Directory", func(t *testing.T) {
config := LocalConfig{} // Empty config
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.NotNil(t, storage)
// Should create default uploads directory
assert.Equal(t, "./uploads", storage.rootDir)
// Verify directory was created
_, err = os.Stat("./uploads")
assert.NoError(t, err)
// Clean up
os.RemoveAll("./uploads")
})
t.Run("Custom Root Directory", func(t *testing.T) {
tmpDir := t.TempDir()
customDir := filepath.Join(tmpDir, "custom", "storage")
config := LocalConfig{
RootDir: customDir,
}
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.Equal(t, customDir, storage.rootDir)
// Verify custom directory was created
_, err = os.Stat(customDir)
assert.NoError(t, err)
})
}

236
core/storage/s3.go Normal file
View File

@@ -0,0 +1,236 @@
package storage
import (
"fmt"
"io"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/openaccounting/oa-server/core/util/id"
)
// S3Storage implements the Storage interface for Amazon S3
type S3Storage struct {
client *s3.S3
uploader *s3manager.Uploader
bucket string
prefix string
}
// NewS3Storage creates a new S3 storage backend
func NewS3Storage(config S3Config) (*S3Storage, error) {
if config.Bucket == "" {
return nil, fmt.Errorf("S3 bucket name is required")
}
// Create AWS config
awsConfig := &aws.Config{
Region: aws.String(config.Region),
}
// Set custom endpoint if provided (for S3-compatible services)
if config.Endpoint != "" {
awsConfig.Endpoint = aws.String(config.Endpoint)
awsConfig.S3ForcePathStyle = aws.Bool(config.PathStyle)
}
// Set credentials if provided
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
awsConfig.Credentials = credentials.NewStaticCredentials(
config.AccessKeyID,
config.SecretAccessKey,
"",
)
}
// Create session
sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create AWS session: %w", err)
}
// Create S3 client
client := s3.New(sess)
uploader := s3manager.NewUploader(sess)
return &S3Storage{
client: client,
uploader: uploader,
bucket: config.Bucket,
prefix: config.Prefix,
}, nil
}
// Store saves a file to S3
func (s *S3Storage) Store(filename string, content io.Reader, contentType string) (string, error) {
// Generate a unique storage key
storageKey := s.generateStorageKey(filename)
// Prepare upload input
input := &s3manager.UploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(storageKey),
Body: content,
}
// Set content type if provided
if contentType != "" {
input.ContentType = aws.String(contentType)
}
// Upload the file
_, err := s.uploader.Upload(input)
if err != nil {
return "", fmt.Errorf("failed to upload file to S3: %w", err)
}
return storageKey, nil
}
// Retrieve gets a file from S3
func (s *S3Storage) Retrieve(path string) (io.ReadCloser, error) {
input := &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
result, err := s.client.GetObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey:
return nil, &FileNotFoundError{Path: path}
}
}
return nil, fmt.Errorf("failed to retrieve file from S3: %w", err)
}
return result.Body, nil
}
// Delete removes a file from S3
func (s *S3Storage) Delete(path string) error {
input := &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
_, err := s.client.DeleteObject(input)
if err != nil {
return fmt.Errorf("failed to delete file from S3: %w", err)
}
return nil
}
// GetURL returns a presigned URL for accessing the file
func (s *S3Storage) GetURL(path string, expiry time.Duration) (string, error) {
// Check if file exists first
exists, err := s.Exists(path)
if err != nil {
return "", err
}
if !exists {
return "", &FileNotFoundError{Path: path}
}
// Generate presigned URL
req, _ := s.client.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
})
url, err := req.Presign(expiry)
if err != nil {
return "", fmt.Errorf("failed to generate presigned URL: %w", err)
}
return url, nil
}
// Exists checks if a file exists in S3
func (s *S3Storage) Exists(path string) (bool, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
_, err := s.client.HeadObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey, "NotFound":
return false, nil
}
}
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
}
return true, nil
}
// GetMetadata returns file metadata from S3
func (s *S3Storage) GetMetadata(path string) (*FileMetadata, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
result, err := s.client.HeadObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey, "NotFound":
return nil, &FileNotFoundError{Path: path}
}
}
return nil, fmt.Errorf("failed to get file metadata from S3: %w", err)
}
metadata := &FileMetadata{
Size: aws.Int64Value(result.ContentLength),
}
if result.LastModified != nil {
metadata.LastModified = *result.LastModified
}
if result.ContentType != nil {
metadata.ContentType = *result.ContentType
}
if result.ETag != nil {
metadata.ETag = strings.Trim(*result.ETag, "\"")
}
return metadata, nil
}
// generateStorageKey creates a unique storage key for a file
func (s *S3Storage) generateStorageKey(filename string) string {
// Generate a unique ID for the file
fileID := id.String(id.New())
// Extract file extension
ext := path.Ext(filename)
// Create a key structure: prefix/YYYY/MM/DD/uuid.ext
now := time.Now()
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
key := path.Join(datePath, fileID+ext)
// Add prefix if configured
if s.prefix != "" {
key = path.Join(s.prefix, key)
}
return key
}

48
core/util/id/id.go Normal file
View File

@@ -0,0 +1,48 @@
package id
import (
"encoding/binary"
"github.com/google/uuid"
)
// New creates a new binary ID (16 bytes) using UUID v4
func New() []byte {
u := uuid.New()
return u[:]
}
// ToUUID converts a binary ID back to UUID
func ToUUID(b []byte) (uuid.UUID, error) {
return uuid.FromBytes(b)
}
// String returns the string representation of a binary ID
func String(b []byte) string {
u, err := uuid.FromBytes(b)
if err != nil {
return ""
}
return u.String()
}
// FromString parses a string UUID into binary format
func FromString(s string) ([]byte, error) {
u, err := uuid.Parse(s)
if err != nil {
return nil, err
}
return u[:], nil
}
// Uint64ToBytes converts a uint64 to 8-byte slice
func Uint64ToBytes(v uint64) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, v)
return b
}
// BytesToUint64 converts 8-byte slice to uint64
func BytesToUint64(b []byte) uint64 {
return binary.BigEndian.Uint64(b)
}

View File

@@ -3,6 +3,7 @@ package util
import (
"crypto/rand"
"encoding/hex"
"regexp"
"time"
)
@@ -44,3 +45,23 @@ func NewInviteId() (string, error) {
return hex.EncodeToString(byteArray), nil
}
func NewUUID() string {
guid, err := NewGuid()
if err != nil {
// Fallback to timestamp-based UUID if random generation fails
return hex.EncodeToString([]byte(time.Now().Format("20060102150405")))
}
return guid
}
func IsValidUUID(uuid string) bool {
// Check if the string is a valid 32-character hex string (16 bytes * 2 hex chars)
if len(uuid) != 32 {
return false
}
// Check if all characters are valid hex characters
matched, _ := regexp.MatchString("^[0-9a-f]{32}$", uuid)
return matched
}

View File

@@ -67,7 +67,7 @@ func Handler(w rest.ResponseWriter, r *rest.Request) {
continue
}
log.Printf("recv: %s", message)
log.Printf("recv: %+v", message)
// check version
err = checkVersion(message.Version)

337
database/database.go Normal file
View File

@@ -0,0 +1,337 @@
package database
import (
"fmt"
"log"
"os"
"time"
"github.com/openaccounting/oa-server/core/util/id"
"github.com/openaccounting/oa-server/models"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var DB *gorm.DB
type Config struct {
Driver string // "mysql" or "sqlite"
Host string
Port string
User string
Password string
DBName string
SSLMode string
// SQLite specific
File string // SQLite database file path
}
func Connect(config *Config) error {
// Configure GORM logger
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: true,
},
)
var db *gorm.DB
var err error
// Choose driver based on config
switch config.Driver {
case "sqlite":
// Use SQLite
dbFile := config.File
if dbFile == "" {
dbFile = "./openaccounting.db" // Default SQLite file
}
db, err = gorm.Open(sqlite.Open(dbFile), &gorm.Config{
Logger: newLogger,
})
case "mysql":
// Use MySQL (existing logic)
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
config.User, config.Password, config.Host, config.Port, config.DBName)
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: newLogger,
})
default:
return fmt.Errorf("unsupported database driver: %s (supported: mysql, sqlite)", config.Driver)
}
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
// Configure connection pool (only for MySQL, SQLite handles this internally)
if config.Driver == "mysql" {
sqlDB, err := db.DB()
if err != nil {
return fmt.Errorf("failed to get database instance: %w", err)
}
sqlDB.SetMaxOpenConns(25)
sqlDB.SetMaxIdleConns(25)
sqlDB.SetConnMaxLifetime(5 * time.Minute)
}
DB = db
return nil
}
// AutoMigrate runs automatic migrations for all models
func AutoMigrate() error {
return DB.AutoMigrate(
&models.Org{},
&models.User{},
&models.UserOrg{},
&models.Token{},
&models.Account{},
&models.Transaction{},
&models.Split{},
&models.Balance{},
&models.Permission{},
&models.Price{},
&models.Session{},
&models.APIKey{},
&models.Invite{},
&models.BudgetItem{},
&models.Attachment{},
)
}
// Migrate runs custom migrations
func Migrate() error {
// Create indexes
if err := createIndexes(); err != nil {
return err
}
// Insert default data - temporarily disabled for testing
// if err := seedDefaultData(); err != nil {
// return err
// }
return nil
}
func createIndexes() error {
// Create custom indexes that GORM doesn't handle automatically
// Based on original indexes.sql file
indexes := []string{
// Original indexes from indexes.sql
"CREATE INDEX IF NOT EXISTS account_orgId_index ON accounts(orgId)",
"CREATE INDEX IF NOT EXISTS split_accountId_index ON splits(accountId)",
"CREATE INDEX IF NOT EXISTS split_transactionId_index ON splits(transactionId)",
"CREATE INDEX IF NOT EXISTS split_date_index ON splits(date)",
"CREATE INDEX IF NOT EXISTS split_updated_index ON splits(updated)",
"CREATE INDEX IF NOT EXISTS budgetitem_orgId_index ON budget_items(orgId)",
"CREATE INDEX IF NOT EXISTS attachment_transactionId_index ON attachment(transactionId)",
"CREATE INDEX IF NOT EXISTS attachment_orgId_index ON attachment(orgId)",
"CREATE INDEX IF NOT EXISTS attachment_userId_index ON attachment(userId)",
"CREATE INDEX IF NOT EXISTS attachment_uploaded_index ON attachment(uploaded)",
// Additional useful indexes for performance
"CREATE INDEX IF NOT EXISTS idx_transaction_date ON transactions(date)",
"CREATE INDEX IF NOT EXISTS idx_transaction_org ON transactions(orgId)",
"CREATE INDEX IF NOT EXISTS idx_account_parent ON accounts(parent)",
"CREATE INDEX IF NOT EXISTS idx_userorg_user ON user_orgs(userId)",
"CREATE INDEX IF NOT EXISTS idx_userorg_org ON user_orgs(orgId)",
"CREATE INDEX IF NOT EXISTS idx_balance_account_date ON balances(accountId, date)",
"CREATE INDEX IF NOT EXISTS idx_permission_org_account ON permissions(orgId, accountId)",
}
for _, idx := range indexes {
if err := DB.Exec(idx).Error; err != nil {
return fmt.Errorf("failed to create index: %w", err)
}
}
return nil
}
func seedDefaultData() error {
// Check if we already have data
var count int64
if err := DB.Model(&models.Org{}).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return nil // Data already exists
}
// Create a default organization
defaultOrg := models.Org{
ID: id.New(), // You'll need to implement this
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Default Organization",
Currency: "USD",
Precision: 2,
Timezone: "UTC",
}
if err := DB.Create(&defaultOrg).Error; err != nil {
return fmt.Errorf("failed to create default organization: %w", err)
}
// Create default accounts for the organization
defaultAccounts := []models.Account{
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Assets",
Parent: []byte{0}, // Root account has zero parent
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Liabilities",
Parent: []byte{0},
Currency: "USD",
Precision: 2,
DebitBalance: false,
},
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Equity",
Parent: []byte{0},
Currency: "USD",
Precision: 2,
DebitBalance: false,
},
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Revenue",
Parent: []byte{0},
Currency: "USD",
Precision: 2,
DebitBalance: false,
},
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Expenses",
Parent: []byte{0},
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
}
// Create accounts and store their IDs for parent-child relationships
accountMap := make(map[string]*models.Account)
for _, acc := range defaultAccounts {
account := acc
if err := DB.Create(&account).Error; err != nil {
return fmt.Errorf("failed to create account %s: %w", acc.Name, err)
}
accountMap[acc.Name] = &account
}
// Create Current Assets first
var assetsParent []byte
if assetsAccount, exists := accountMap["Assets"]; exists {
assetsParent = assetsAccount.ID
} else {
return fmt.Errorf("Assets account not found in accountMap")
}
currentAssets := models.Account{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Current Assets",
Parent: assetsParent,
Currency: "USD",
Precision: 2,
DebitBalance: true,
}
if err := DB.Create(&currentAssets).Error; err != nil {
return fmt.Errorf("failed to create Current Assets: %w", err)
}
accountMap["Current Assets"] = &currentAssets
// Create Accounts Payable
var liabilitiesParent []byte
if liabilitiesAccount, exists := accountMap["Liabilities"]; exists {
liabilitiesParent = liabilitiesAccount.ID
} else {
return fmt.Errorf("Liabilities account not found in accountMap")
}
accountsPayable := models.Account{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Accounts Payable",
Parent: liabilitiesParent,
Currency: "USD",
Precision: 2,
DebitBalance: false,
}
if err := DB.Create(&accountsPayable).Error; err != nil {
return fmt.Errorf("failed to create Accounts Payable: %w", err)
}
// Now create sub-accounts under Current Assets
subAccounts := []models.Account{
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Cash",
Parent: currentAssets.ID,
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
{
ID: id.New(),
OrgID: defaultOrg.ID,
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
Name: "Accounts Receivable",
Parent: currentAssets.ID,
Currency: "USD",
Precision: 2,
DebitBalance: true,
},
}
for _, acc := range subAccounts {
if err := DB.Create(&acc).Error; err != nil {
return fmt.Errorf("failed to create sub-account %s: %w", acc.Name, err)
}
}
return nil
}

51
go.mod
View File

@@ -1,16 +1,53 @@
module github.com/openaccounting/oa-server
go 1.24.2
require (
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01
github.com/go-sql-driver/mysql v1.4.1
github.com/aws/aws-sdk-go v1.44.0
github.com/go-sql-driver/mysql v1.8.1
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da
github.com/mailgun/mailgun-go/v4 v4.3.0
github.com/mitchellh/mapstructure v0.0.0-20180511142126-bb74f1db0675
github.com/sendgrid/rest v0.0.0-20180905234047-875828e14d98 // indirect
github.com/sendgrid/sendgrid-go v0.0.0-20180905233524-8cb43f4ca4f5 // indirect
github.com/stretchr/objx v0.3.0 // indirect
github.com/stretchr/testify v1.3.0
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
google.golang.org/appengine v1.6.7 // indirect
github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.32.0
gorm.io/driver/sqlite v1.6.0
)
require (
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/sys v0.29.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-chi/chi v4.0.0+incompatible // indirect
github.com/google/uuid v1.6.0
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.10 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/text v0.21.0 // indirect
gorm.io/driver/mysql v1.6.0
gorm.io/gorm v1.30.0
)

113
go.sum
View File

@@ -1,7 +1,11 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf h1:BMUJnVJI5J506LOcyGHEvbCocMHAmKTRcG6CMAwGFYU=
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01 h1:oYAjCHMjyRaNBo3nUEepDce4LC+Kuh+6jU6y+AllvnU=
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01/go.mod h1:q6aCt0GfU6LhpBsnZ/2U+mwe+0XB5WStbmwyoPfc+sk=
github.com/aws/aws-sdk-go v1.44.0 h1:jwtHuNqfnJxL4DKHBUVUmQlfueQqBW7oXP6yebZR/R0=
github.com/aws/aws-sdk-go v1.44.0/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -11,53 +15,104 @@ github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 h1:JWuenKqqX8nojt
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052/go.mod h1:UbMTZqLaRiH3MsBH8va0n7s1pQYcu3uTb8G4tygF4Zg=
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQDg5gKsWoLBOB0n+ZW8s599zru8FJ2/Y=
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-chi/chi v4.0.0+incompatible h1:SiLLEDyAkqNnw+T/uDTf3aFB9T4FTrwMpuYrgaRcnW4=
github.com/go-chi/chi v4.0.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da h1:b5fma7aUP2fn6+tdKKCJ0TxXYzY/5wDiqUxNdyi5VF4=
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailgun/mailgun-go/v4 v4.3.0 h1:9nAF7LI3k6bfDPbMZQMMl63Q8/vs+dr1FUN8eR1XMhk=
github.com/mailgun/mailgun-go/v4 v4.3.0/go.mod h1:fWuBI2iaS/pSSyo6+EBpHjatQO3lV8onwqcRy7joSJI=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v0.0.0-20180511142126-bb74f1db0675 h1:/rdJjIiKG5rRdwG5yxHmSE/7ZREjpyC0kL7GxGT/qJw=
github.com/mitchellh/mapstructure v0.0.0-20180511142126-bb74f1db0675/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sendgrid/rest v0.0.0-20180905234047-875828e14d98 h1:wpBZ5DAYLNl+2v4E4WP8k/y8tM5OjIf1FezJS1qX8sU=
github.com/sendgrid/rest v0.0.0-20180905234047-875828e14d98/go.mod h1:kXX7q3jZtJXK5c5qK83bSGMdV6tsOE70KbHoqJls4lE=
github.com/sendgrid/sendgrid-go v0.0.0-20180905233524-8cb43f4ca4f5 h1:V18LU+jSbihmDiWfLSzs9FV1d3KVB1gRTkNxgVHmcvg=
github.com/sendgrid/sendgrid-go v0.0.0-20180905233524-8cb43f4ca4f5/go.mod h1:QRQt+LX/NmgVEvmdRw0VT/QgUn499+iza2FnDca9fg8=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As=
github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20171231215028-0fcca4842a8d h1:GrqEEc3+MtHKTsZrdIGVoYDgLpbSRzW1EF+nLu0PcHE=
golang.org/x/crypto v0.0.0-20171231215028-0fcca4842a8d/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=

View File

@@ -3,4 +3,8 @@ CREATE INDEX split_accountId_index ON split (accountId);
CREATE INDEX split_transactionId_index ON split (transactionId);
CREATE INDEX split_date_index ON split (date);
CREATE INDEX split_updated_index ON split (updated);
CREATE INDEX budgetitem_orgId_index ON budgetitem (orgId);
CREATE INDEX budgetitem_orgId_index ON budgetitem (orgId);
CREATE INDEX attachment_transactionId_index ON attachment (transactionId);
CREATE INDEX attachment_orgId_index ON attachment (orgId);
CREATE INDEX attachment_userId_index ON attachment (userId);
CREATE INDEX attachment_uploaded_index ON attachment (uploaded);

166
justfile Normal file
View File

@@ -0,0 +1,166 @@
# OpenAccounting Server - Just recipes
# https://github.com/casey/just
# Default recipe
default:
@just --list
# Variables
image_name := "openaccounting-server"
tag := "latest"
# Build the Go application
build:
@echo "Building OpenAccounting Server..."
go build -o server ./core/
# Run the server locally
run: build
@echo "Starting server locally..."
./server
# Run with custom environment
run-dev: build
@echo "Starting server in development mode..."
OA_DATABASEDRIVER=sqlite OA_DATABASEFILE=./dev.db OA_PORT=8080 ./server
# Run tests
test:
@echo "Running tests..."
go test ./...
# Clean build artifacts
clean:
@echo "Cleaning up..."
rm -f server
rm -f *.db
# Docker recipes
# Build Docker image
docker-build:
@echo "Building Docker image: {{image_name}}:{{tag}}"
docker build -t {{image_name}}:{{tag}} .
# Run container with SQLite (development)
docker-run: docker-build
@echo "Running container with SQLite..."
docker run --rm -p 8080:8080 \
-e OA_DATABASE_DRIVER=sqlite \
-e OA_DATABASE_FILE=/app/data/openaccounting.db \
-v $(pwd)/data:/app/data \
{{image_name}}:{{tag}}
# Run container with MySQL (production example)
docker-run-mysql: docker-build
@echo "Running container with MySQL (requires external MySQL)..."
docker run --rm -p 8080:8080 \
-e OA_DATABASE_DRIVER=mysql \
-e OA_DATABASE_ADDRESS=mysql:3306 \
-e OA_DATABASE=openaccounting \
-e OA_USER=openaccounting \
-e OA_PASSWORD=secret \
{{image_name}}:{{tag}}
# Run with docker-compose (if you create one)
docker-compose-up:
@echo "Starting with docker-compose..."
docker-compose up -d
docker-compose-down:
@echo "Stopping docker-compose..."
docker-compose down
# Development utilities
# Format code
fmt:
@echo "Formatting code..."
go fmt ./...
# Lint code (requires golangci-lint)
lint:
@echo "Linting code..."
golangci-lint run
# Install development dependencies
install-deps:
@echo "Installing development dependencies..."
go mod download
go mod vendor
# Update dependencies
update-deps:
@echo "Updating dependencies..."
go get -u ./...
go mod tidy
go mod vendor
# Database utilities
# Create SQLite database directory
init-db:
@echo "Creating database directory..."
mkdir -p data
# Reset SQLite database
reset-db:
@echo "Resetting SQLite database..."
rm -f *.db data/*.db
# Migration recipes
# Run database migrations manually (if needed)
migrate:
@echo "Running database migrations..."
go run ./core/ --migrate-only || echo "Migration command not implemented yet"
# Production utilities
# Build for production
build-prod:
@echo "Building for production..."
CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -ldflags="-w -s" -o server ./core/
# Create release tarball
release: build-prod
@echo "Creating release package..."
tar -czf openaccounting-server-$(date +%Y%m%d).tar.gz server config.json.sample README.md
# Security scan (requires trivy)
security-scan:
@echo "Scanning Docker image for vulnerabilities..."
trivy image {{image_name}}:{{tag}}
# Show configuration help
config-help:
@echo "OpenAccounting Server Configuration:"
@echo ""
@echo "Environment Variables (prefix with OA_):"
@echo " OA_ADDRESS Server address (default: localhost)"
@echo " OA_PORT Server port (default: 8080)"
@echo " OA_API_PREFIX API prefix (default: /api/v1)"
@echo " OA_DATABASE_DRIVER Database driver: sqlite or mysql (default: sqlite)"
@echo " OA_DATABASE_FILE SQLite database file (default: ./openaccounting.db)"
@echo " OA_DATABASE_ADDRESS MySQL address (e.g., localhost:3306)"
@echo " OA_DATABASE MySQL database name"
@echo " OA_USER Database username"
@echo " OA_PASSWORD Database password (recommended for security)"
@echo " OA_MAILGUN_DOMAIN Mailgun domain"
@echo " OA_MAILGUN_KEY Mailgun API key (recommended for security)"
@echo " OA_MAILGUN_EMAIL Mailgun email"
@echo " OA_MAILGUN_SENDER Mailgun sender name"
@echo ""
@echo "Examples:"
@echo " Development: OA_DATABASE_DRIVER=sqlite OA_PORT=8080 ./server"
@echo " Production: OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret ./server"
# All-in-one development setup
dev-setup: install-deps init-db build
@echo "Development setup complete!"
@echo "Run 'just run-dev' to start the server"
# All-in-one production build
prod-build: clean build-prod docker-build
@echo "Production build complete!"
@echo "Run 'just docker-run' to test the container"

View File

@@ -1,105 +0,0 @@
package main
import (
"encoding/json"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"log"
"os"
)
func main() {
if len(os.Args) != 2 {
log.Fatal("Usage: migrate1.go <upgrade/downgrade>")
}
command := os.Args[1]
if command != "upgrade" && command != "downgrade" {
log.Fatal("Usage: migrate1.go <upgrade/downgrade>")
}
//filename is the path to the json config file
var config types.Config
file, err := os.Open("./config.json")
if err != nil {
log.Fatal(err)
}
decoder := json.NewDecoder(file)
err = decoder.Decode(&config)
if err != nil {
log.Fatal(err)
}
connectionString := config.User + ":" + config.Password + "@/" + config.Database
db, err := db.NewDB(connectionString)
if command == "upgrade" {
err = upgrade(db)
} else {
err = downgrade(db)
}
if err != nil {
log.Fatal(err)
}
log.Println("done")
}
func upgrade(db *db.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
query1 := "ALTER TABLE user ADD COLUMN signupSource VARCHAR(100) NOT NULL AFTER emailVerifyCode"
if _, err = tx.Exec(query1); err != nil {
return
}
return
}
func downgrade(db *db.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
query1 := "ALTER TABLE user DROP COLUMN signupSource"
if _, err = tx.Exec(query1); err != nil {
return
}
return
}

View File

@@ -1,105 +0,0 @@
package main
import (
"encoding/json"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"log"
"os"
)
func main() {
if len(os.Args) != 2 {
log.Fatal("Usage: migrate2.go <upgrade/downgrade>")
}
command := os.Args[1]
if command != "upgrade" && command != "downgrade" {
log.Fatal("Usage: migrate2.go <upgrade/downgrade>")
}
//filename is the path to the json config file
var config types.Config
file, err := os.Open("./config.json")
if err != nil {
log.Fatal(err)
}
decoder := json.NewDecoder(file)
err = decoder.Decode(&config)
if err != nil {
log.Fatal(err)
}
connectionString := config.User + ":" + config.Password + "@/" + config.Database
db, err := db.NewDB(connectionString)
if command == "upgrade" {
err = upgrade(db)
} else {
err = downgrade(db)
}
if err != nil {
log.Fatal(err)
}
log.Println("done")
}
func upgrade(db *db.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
query1 := "ALTER TABLE org ADD COLUMN timezone VARCHAR(100) NOT NULL AFTER `precision`"
if _, err = tx.Exec(query1); err != nil {
return
}
return
}
func downgrade(db *db.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
query1 := "ALTER TABLE org DROP COLUMN timezone"
if _, err = tx.Exec(query1); err != nil {
return
}
return
}

View File

@@ -1,105 +0,0 @@
package main
import (
"encoding/json"
"github.com/openaccounting/oa-server/core/model/db"
"github.com/openaccounting/oa-server/core/model/types"
"log"
"os"
)
func main() {
if len(os.Args) != 2 {
log.Fatal("Usage: migrate3.go <upgrade/downgrade>")
}
command := os.Args[1]
if command != "upgrade" && command != "downgrade" {
log.Fatal("Usage: migrate3.go <upgrade/downgrade>")
}
//filename is the path to the json config file
var config types.Config
file, err := os.Open("./config.json")
if err != nil {
log.Fatal(err)
}
decoder := json.NewDecoder(file)
err = decoder.Decode(&config)
if err != nil {
log.Fatal(err)
}
connectionString := config.User + ":" + config.Password + "@/" + config.Database
db, err := db.NewDB(connectionString)
if command == "upgrade" {
err = upgrade(db)
} else {
err = downgrade(db)
}
if err != nil {
log.Fatal(err)
}
log.Println("done")
}
func upgrade(db *db.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
query1 := "CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;"
if _, err = tx.Exec(query1); err != nil {
return
}
return
}
func downgrade(db *db.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
query1 := "DROP TABLE budgetitem"
if _, err = tx.Exec(query1); err != nil {
return
}
return
}

18
models/account.go Normal file
View File

@@ -0,0 +1,18 @@
package models
// Account represents a financial account
type Account struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
Name string `gorm:"column:name;size:100;not null"`
Parent []byte `gorm:"column:parent;type:BINARY(16);not null"`
Currency string `gorm:"column:currency;size:10;not null"`
Precision int `gorm:"column:precision;not null"`
DebitBalance bool `gorm:"column:debitBalance;not null"`
Org Org `gorm:"foreignKey:OrgID"`
Splits []Split `gorm:"foreignKey:AccountID"`
Balances []Balance `gorm:"foreignKey:AccountID"`
}

13
models/api_key.go Normal file
View File

@@ -0,0 +1,13 @@
package models
// APIKey represents API keys for users
type APIKey struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
Label string `gorm:"column:label;size:300;not null"`
Deleted uint64 `gorm:"column:deleted"`
User User `gorm:"foreignKey:UserID"`
}

28
models/attachment.go Normal file
View File

@@ -0,0 +1,28 @@
package models
import (
"time"
)
type Attachment struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
TransactionID []byte `gorm:"column:transactionId;type:BINARY(16);not null"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
FileName string `gorm:"column:fileName;size:255;not null"`
OriginalName string `gorm:"column:originalName;size:255;not null"`
ContentType string `gorm:"column:contentType;size:100;not null"`
FileSize int64 `gorm:"column:fileSize;not null"`
FilePath string `gorm:"column:filePath;size:500;not null"`
Description string `gorm:"column:description;size:500"`
Uploaded time.Time `gorm:"column:uploaded;not null"`
Deleted bool `gorm:"column:deleted;default:false"`
Transaction Transaction `gorm:"foreignKey:TransactionID"`
Org Org `gorm:"foreignKey:OrgID"`
User User `gorm:"foreignKey:UserID"`
}
func (Attachment) TableName() string {
return "attachment"
}

11
models/balance.go Normal file
View File

@@ -0,0 +1,11 @@
package models
// Balance represents an account balance at a point in time
type Balance struct {
ID uint `gorm:"primaryKey;autoIncrement"`
Date uint64 `gorm:"column:date;not null"`
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
Amount int64 `gorm:"column:amount;not null"`
Account Account `gorm:"foreignKey:AccountID"`
}

45
models/base.go Normal file
View File

@@ -0,0 +1,45 @@
package models
import (
"github.com/google/uuid"
"github.com/openaccounting/oa-server/core/util/id"
"gorm.io/gorm"
)
type Base struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
}
// GetUUID converts binary ID to UUID
func (b *Base) GetUUID() (uuid.UUID, error) {
return id.ToUUID(b.ID)
}
// GetIDString returns string representation of the ID
func (b *Base) GetIDString() string {
return id.String(b.ID)
}
// SetIDFromString parses string UUID into binary ID
func (b *Base) SetIDFromString(s string) error {
binID, err := id.FromString(s)
if err != nil {
return err
}
b.ID = binID
return nil
}
// ValidateID checks if the ID is a valid UUID
func (b *Base) ValidateID() error {
_, err := uuid.FromBytes(b.ID)
return err
}
// BeforeCreate GORM hook to set ID if empty
func (b *Base) BeforeCreate(tx *gorm.DB) error {
if len(b.ID) == 0 {
b.ID = id.New()
}
return nil
}

13
models/budget_item.go Normal file
View File

@@ -0,0 +1,13 @@
package models
// BudgetItem represents budget items
type BudgetItem struct {
ID uint `gorm:"primaryKey;autoIncrement"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Amount int64 `gorm:"column:amount;not null"`
Org Org `gorm:"foreignKey:OrgID"`
Account Account `gorm:"foreignKey:AccountID"`
}

13
models/invite.go Normal file
View File

@@ -0,0 +1,13 @@
package models
// Invite represents organization invitations
type Invite struct {
ID string `gorm:"size:32;primaryKey"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
Email string `gorm:"column:email;size:100;not null"`
Accepted bool `gorm:"column:accepted;not null"`
Org Org `gorm:"foreignKey:OrgID"`
}

15
models/org.go Normal file
View File

@@ -0,0 +1,15 @@
package models
// Org represents an organization
type Org struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
Name string `gorm:"column:name;size:100;not null"`
Currency string `gorm:"column:currency;size:10;not null"`
Precision int `gorm:"column:precision;not null"`
Timezone string `gorm:"column:timezone;size:100;not null"`
Accounts []Account `gorm:"foreignKey:OrgID"`
UserOrgs []UserOrg `gorm:"foreignKey:OrgID"`
}

18
models/permission.go Normal file
View File

@@ -0,0 +1,18 @@
package models
// Permission represents access control rules
type Permission struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
UserID []byte `gorm:"column:userId;type:BINARY(16)"`
TokenID []byte `gorm:"column:tokenId;type:BINARY(16)"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
Type uint `gorm:"column:type;not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
User User `gorm:"foreignKey:UserID"`
Token Token `gorm:"foreignKey:TokenID"`
Org Org `gorm:"foreignKey:OrgID"`
Account Account `gorm:"foreignKey:AccountID"`
}

14
models/price.go Normal file
View File

@@ -0,0 +1,14 @@
package models
// Price represents currency exchange rates
type Price struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
Currency string `gorm:"column:currency;size:10;not null"`
Date uint64 `gorm:"column:date;not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
Price float64 `gorm:"column:price;not null"`
Org Org `gorm:"foreignKey:OrgID"`
}

12
models/session.go Normal file
View File

@@ -0,0 +1,12 @@
package models
// Session represents user sessions
type Session struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
Terminated uint64 `gorm:"column:terminated"`
User User `gorm:"foreignKey:UserID"`
}

17
models/split.go Normal file
View File

@@ -0,0 +1,17 @@
package models
// Split represents a single entry in a transaction
type Split struct {
ID uint `gorm:"primaryKey;autoIncrement"`
TransactionID []byte `gorm:"column:transactionId;type:BINARY(16);not null"`
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
Date uint64 `gorm:"column:date;not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
Amount int64 `gorm:"column:amount;not null"`
NativeAmount int64 `gorm:"column:nativeAmount;not null"`
Deleted bool `gorm:"column:deleted;default:false"`
Transaction Transaction `gorm:"foreignKey:TransactionID"`
Account Account `gorm:"foreignKey:AccountID"`
}

10
models/token.go Normal file
View File

@@ -0,0 +1,10 @@
package models
// Token represents an API token
type Token struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
Name string `gorm:"column:name;size:100"`
UserOrgID uint `gorm:"column:userOrgId;not null"`
UserOrg UserOrg `gorm:"foreignKey:UserOrgID"`
}

18
models/transaction.go Normal file
View File

@@ -0,0 +1,18 @@
package models
// Transaction represents a financial transaction
type Transaction struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
Date uint64 `gorm:"column:date;not null"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
Description string `gorm:"column:description;size:300;not null"`
Data string `gorm:"column:data;type:TEXT;not null"`
Deleted bool `gorm:"column:deleted;default:false"`
Org Org `gorm:"foreignKey:OrgID"`
User User `gorm:"foreignKey:UserID"`
Splits []Split `gorm:"foreignKey:TransactionID"`
}

21
models/user.go Normal file
View File

@@ -0,0 +1,21 @@
package models
// User represents a user account
type User struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
Inserted uint64 `gorm:"column:inserted;not null"`
Updated uint64 `gorm:"column:updated;not null"`
FirstName string `gorm:"column:firstName;size:50;not null"`
LastName string `gorm:"column:lastName;size:50;not null"`
Email string `gorm:"column:email;size:100;not null;unique"`
PasswordHash string `gorm:"column:passwordHash;size:100;not null"`
AgreeToTerms bool `gorm:"column:agreeToTerms;not null"`
PasswordReset string `gorm:"column:passwordReset;size:32;not null"`
EmailVerified bool `gorm:"column:emailVerified;not null"`
EmailVerifyCode string `gorm:"column:emailVerifyCode;size:32;not null"`
SignupSource string `gorm:"column:signupSource;size:100;not null"`
UserOrgs []UserOrg `gorm:"foreignKey:UserID"`
Sessions []Session `gorm:"foreignKey:UserID"`
APIKeys []APIKey `gorm:"foreignKey:UserID"`
}

12
models/user_org.go Normal file
View File

@@ -0,0 +1,12 @@
package models
// UserOrg represents the relationship between users and organizations
type UserOrg struct {
ID uint `gorm:"primaryKey;autoIncrement"`
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
Admin bool `gorm:"column:admin;default:false"`
User User `gorm:"foreignKey:UserID"`
Org Org `gorm:"foreignKey:OrgID"`
}

View File

@@ -30,4 +30,6 @@ CREATE TABLE apikey (id BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL,
CREATE TABLE invite (id VARCHAR(32) NOT NULL, orgId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, updated BIGINT UNSIGNED NOT NULL, email VARCHAR(100) NOT NULL, accepted BOOLEAN NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
CREATE TABLE attachment (id BINARY(16) NOT NULL, transactionId BINARY(16) NOT NULL, orgId BINARY(16) NOT NULL, userId BINARY(16) NOT NULL, fileName VARCHAR(255) NOT NULL, originalName VARCHAR(255) NOT NULL, contentType VARCHAR(100) NOT NULL, fileSize BIGINT NOT NULL, filePath VARCHAR(500) NOT NULL, description VARCHAR(500), uploaded BIGINT UNSIGNED NOT NULL, deleted BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY(id)) ENGINE=InnoDB;

27
vendor/filippo.io/edwards25519/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

14
vendor/filippo.io/edwards25519/README.md generated vendored Normal file
View File

@@ -0,0 +1,14 @@
# filippo.io/edwards25519
```
import "filippo.io/edwards25519"
```
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.

20
vendor/filippo.io/edwards25519/doc.go generated vendored Normal file
View File

@@ -0,0 +1,20 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/internal/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519

427
vendor/filippo.io/edwards25519/edwards25519.go generated vendored Normal file
View File

@@ -0,0 +1,427 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"errors"
"filippo.io/edwards25519/field"
)
// Point types.
type projP1xP1 struct {
X, Y, Z, T field.Element
}
type projP2 struct {
X, Y, Z field.Element
}
// Point represents a point on the edwards25519 curve.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is NOT valid, and it may be used only as a receiver.
type Point struct {
// Make the type not comparable (i.e. used with == or as a map key), as
// equivalent points can be represented by different Go values.
_ incomparable
// The point is internally represented in extended coordinates (X, Y, Z, T)
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
x, y, z, t field.Element
}
type incomparable [0]func()
func checkInitialized(points ...*Point) {
for _, p := range points {
if p.x == (field.Element{}) && p.y == (field.Element{}) {
panic("edwards25519: use of uninitialized Point")
}
}
}
type projCached struct {
YplusX, YminusX, Z, T2d field.Element
}
type affineCached struct {
YplusX, YminusX, T2d field.Element
}
// Constructors.
func (v *projP2) Zero() *projP2 {
v.X.Zero()
v.Y.One()
v.Z.One()
return v
}
// identity is the point at infinity.
var identity, _ = new(Point).SetBytes([]byte{
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
var generator, _ = new(Point).SetBytes([]byte{
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
return new(Point).Set(generator)
}
func (v *projCached) Zero() *projCached {
v.YplusX.One()
v.YminusX.One()
v.Z.One()
v.T2d.Zero()
return v
}
func (v *affineCached) Zero() *affineCached {
v.YplusX.One()
v.YminusX.One()
v.T2d.Zero()
return v
}
// Assignments.
// Set sets v = u, and returns v.
func (v *Point) Set(u *Point) *Point {
*v = *u
return v
}
// Encoding.
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytes(&buf)
}
func (v *Point) bytes(buf *[32]byte) []byte {
checkInitialized(v)
var zInv, x, y field.Element
zInv.Invert(&v.z) // zInv = 1 / Z
x.Multiply(&v.x, &zInv) // x = X / Z
y.Multiply(&v.y, &zInv) // y = Y / Z
out := copyFieldElement(buf, &y)
out[31] |= byte(x.IsNegative() << 7)
return out
}
var feOne = new(field.Element).One()
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
// Note that SetBytes accepts all non-canonical encodings of valid points.
// That is, it follows decoding rules that match most implementations in
// the ecosystem rather than RFC 8032.
func (v *Point) SetBytes(x []byte) (*Point, error) {
// Specifically, the non-canonical encodings that are accepted are
// 1) the ones where the field element is not reduced (see the
// (*field.Element).SetBytes docs) and
// 2) the ones where the x-coordinate is zero and the sign bit is set.
//
// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
// specifically the "Canonical A, R" section.
y, err := new(field.Element).SetBytes(x)
if err != nil {
return nil, errors.New("edwards25519: invalid point encoding length")
}
// -x² + y² = 1 + dx²y²
// x² + dx²y² = x²(dy² + 1) = y² - 1
// x² = (y² - 1) / (dy² + 1)
// u = y² - 1
y2 := new(field.Element).Square(y)
u := new(field.Element).Subtract(y2, feOne)
// v = dy² + 1
vv := new(field.Element).Multiply(y2, d)
vv = vv.Add(vv, feOne)
// x = +√(u/v)
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
if wasSquare == 0 {
return nil, errors.New("edwards25519: invalid point encoding")
}
// Select the negative square root if the sign bit is set.
xxNeg := new(field.Element).Negate(xx)
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
v.x.Set(xx)
v.y.Set(y)
v.z.One()
v.t.Multiply(xx, y) // xy = T / Z
return v, nil
}
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
copy(buf[:], v.Bytes())
return buf[:]
}
// Conversions.
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
v.X.Multiply(&p.X, &p.T)
v.Y.Multiply(&p.Y, &p.Z)
v.Z.Multiply(&p.Z, &p.T)
return v
}
func (v *projP2) FromP3(p *Point) *projP2 {
v.X.Set(&p.x)
v.Y.Set(&p.y)
v.Z.Set(&p.z)
return v
}
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
v.x.Multiply(&p.X, &p.T)
v.y.Multiply(&p.Y, &p.Z)
v.z.Multiply(&p.Z, &p.T)
v.t.Multiply(&p.X, &p.Y)
return v
}
func (v *Point) fromP2(p *projP2) *Point {
v.x.Multiply(&p.X, &p.Z)
v.y.Multiply(&p.Y, &p.Z)
v.z.Square(&p.Z)
v.t.Multiply(&p.X, &p.Y)
return v
}
// d is a constant in the curve equation.
var d, _ = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
var d2 = new(field.Element).Add(d, d)
func (v *projCached) FromP3(p *Point) *projCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.Z.Set(&p.z)
v.T2d.Multiply(&p.t, d2)
return v
}
func (v *affineCached) FromP3(p *Point) *affineCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.T2d.Multiply(&p.t, d2)
var invZ field.Element
invZ.Invert(&p.z)
v.YplusX.Multiply(&v.YplusX, &invZ)
v.YminusX.Multiply(&v.YminusX, &invZ)
v.T2d.Multiply(&v.T2d, &invZ)
return v
}
// (Re)addition and subtraction.
// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Sub(p, qCached)
return v.fromP1xP1(result)
}
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&ZZ2, &TT2d)
v.T.Subtract(&ZZ2, &TT2d)
return v
}
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
v.T.Add(&ZZ2, &TT2d) // flipped sign
return v
}
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&Z2, &TT2d)
v.T.Subtract(&Z2, &TT2d)
return v
}
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&Z2, &TT2d) // flipped sign
v.T.Add(&Z2, &TT2d) // flipped sign
return v
}
// Doubling.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
var XX, YY, ZZ2, XplusYsq field.Element
XX.Square(&p.X)
YY.Square(&p.Y)
ZZ2.Square(&p.Z)
ZZ2.Add(&ZZ2, &ZZ2)
XplusYsq.Add(&p.X, &p.Y)
XplusYsq.Square(&XplusYsq)
v.Y.Add(&YY, &XX)
v.Z.Subtract(&YY, &XX)
v.X.Subtract(&XplusYsq, &v.Y)
v.T.Subtract(&ZZ2, &v.Z)
return v
}
// Negation.
// Negate sets v = -p, and returns v.
func (v *Point) Negate(p *Point) *Point {
checkInitialized(p)
v.x.Negate(&p.x)
v.y.Set(&p.y)
v.z.Set(&p.z)
v.t.Negate(&p.t)
return v
}
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
func (v *Point) Equal(u *Point) int {
checkInitialized(v, u)
var t1, t2, t3, t4 field.Element
t1.Multiply(&v.x, &u.z)
t2.Multiply(&u.x, &v.z)
t3.Multiply(&v.y, &u.z)
t4.Multiply(&u.y, &v.z)
return t1.Equal(&t2) & t3.Equal(&t4)
}
// Constant-time operations
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.Z.Select(&a.Z, &b.Z, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *projCached) CondNeg(cond int) *projCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *affineCached) CondNeg(cond int) *affineCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}

349
vendor/filippo.io/edwards25519/extra.go generated vendored Normal file
View File

@@ -0,0 +1,349 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// This file contains additional functionality that is not included in the
// upstream crypto/internal/edwards25519 package.
import (
"errors"
"filippo.io/edwards25519/field"
)
// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap. Don't change the style without making
// sure it doesn't increase the inliner cost.
var e [4]field.Element
X, Y, Z, T = v.extendedCoordinates(&e)
return
}
func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
checkInitialized(v)
X = e[0].Set(&v.x)
Y = e[1].Set(&v.y)
Z = e[2].Set(&v.z)
T = e[3].Set(&v.t)
return
}
// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
//
// If the coordinates are invalid or don't represent a valid point on the curve,
// SetExtendedCoordinates returns nil and an error and the receiver is
// unchanged. Otherwise, SetExtendedCoordinates returns v.
func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
if !isOnCurve(X, Y, Z, T) {
return nil, errors.New("edwards25519: invalid point coordinates")
}
v.x.Set(X)
v.y.Set(Y)
v.z.Set(Z)
v.t.Set(T)
return v, nil
}
func isOnCurve(X, Y, Z, T *field.Element) bool {
var lhs, rhs field.Element
XX := new(field.Element).Square(X)
YY := new(field.Element).Square(Y)
ZZ := new(field.Element).Square(Z)
TT := new(field.Element).Square(T)
// -x² + y² = 1 + dx²y²
// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
// -X² + Y² = Z² + dT²
lhs.Subtract(YY, XX)
rhs.Multiply(d, TT).Add(&rhs, ZZ)
if lhs.Equal(&rhs) != 1 {
return false
}
// xy = T/Z
// XY/Z² = T/Z
// XY = TZ
lhs.Multiply(X, Y)
rhs.Multiply(T, Z)
return lhs.Equal(&rhs) == 1
}
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
//
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
// while every valid edwards25519 point has a unique u-coordinate Montgomery
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
// to any edwards25519 point, and every other X25519 input corresponds to two
// edwards25519 points.
func (v *Point) BytesMontgomery() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytesMontgomery(&buf)
}
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
checkInitialized(v)
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
// Montgomery u-coordinate
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
var y, recip, u field.Element
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
return copyFieldElement(buf, &u)
}
// MultByCofactor sets v = 8 * p, and returns v.
func (v *Point) MultByCofactor(p *Point) *Point {
checkInitialized(p)
result := projP1xP1{}
pp := (&projP2{}).FromP3(p)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
return v.fromP1xP1(&result)
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert returns zero.
func (s *Scalar) Invert(t *Scalar) *Scalar {
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Multiply(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Multiply(&table[i], &tt)
}
// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[5/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(5 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(9 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
return s
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called MultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].nonAdjacentForm(5)
}
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficent, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

420
vendor/filippo.io/edwards25519/field/fe.go generated vendored Normal file
View File

@@ -0,0 +1,420 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"errors"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
// Between operations, all limbs are expected to be lower than 2^52.
l0 uint64
l1 uint64
l2 uint64
l3 uint64
l4 uint64
}
const maskLow51Bits uint64 = (1 << 51) - 1
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation.
return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
return v.carryPropagate()
}
// Negate sets v = -a, and returns v.
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set sets v = a, and returns v.
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
// not of the right length, SetBytes returns nil and an error, and the
// receiver is unchanged.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032, but
// consistent with most Ed25519 implementations.
func (v *Element) SetBytes(x []byte) (*Element, error) {
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid field element input size")
}
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
v.l0 = binary.LittleEndian.Uint64(x[0:8])
v.l0 &= maskLow51Bits
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
v.l1 &= maskLow51Bits
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
v.l2 &= maskLow51Bits
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
v.l3 &= maskLow51Bits
// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
// Note: not bytes 25:33, shift 4, to avoid overread.
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
v.l4 &= maskLow51Bits
return v, nil
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
v.l3 = (m & a.l3) | (^m & b.l3)
v.l4 = (m & a.l4) | (^m & b.l4)
return v
}
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
t = m & (v.l1 ^ u.l1)
v.l1 ^= t
u.l1 ^= t
t = m & (v.l2 ^ u.l2)
v.l2 ^= t
u.l2 ^= t
t = m & (v.l3 ^ u.l3)
v.l3 ^= t
u.l3 ^= t
t = m & (v.l4 ^ u.l4)
v.l4 ^= t
u.l4 ^= t
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
x3lo, x3hi := mul51(x.l3, y)
x4lo, x4hi := mul51(x.l4, y)
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
v.l1 = x1lo + x0hi
v.l2 = x2lo + x1hi
v.l3 = x3lo + x2hi
v.l4 = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
t0 := new(Element)
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := new(Element).Square(v)
uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
uv7 := new(Element).Multiply(uv3, t0.Square(v2))
rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))
check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2
uNeg := new(Element).Negate(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))
rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(rr) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}

16
vendor/filippo.io/edwards25519/field/fe_amd64.go generated vendored Normal file
View File

@@ -0,0 +1,16 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
package field
// feMul sets out = a * b. It works like feMulGeneric.
//
//go:noescape
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//
//go:noescape
func feSquare(out *Element, a *Element)

379
vendor/filippo.io/edwards25519/field/fe_amd64.s generated vendored Normal file
View File

@@ -0,0 +1,379 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
// r0 = a0×b0
MOVQ (CX), AX
MULQ (BX)
MOVQ AX, DI
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
MULQ 8(BX)
MOVQ AX, R9
MOVQ DX, R8
// r1 += a1×b0
MOVQ 8(CX), AX
MULQ (BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
MULQ 16(BX)
MOVQ AX, R11
MOVQ DX, R10
// r2 += a1×b1
MOVQ 8(CX), AX
MULQ 8(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += a2×b0
MOVQ 16(CX), AX
MULQ (BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
MULQ 24(BX)
MOVQ AX, R13
MOVQ DX, R12
// r3 += a1×b2
MOVQ 8(CX), AX
MULQ 16(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a2×b1
MOVQ 16(CX), AX
MULQ 8(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a3×b0
MOVQ 24(CX), AX
MULQ (BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
MULQ 32(BX)
MOVQ AX, R15
MOVQ DX, R14
// r4 += a1×b3
MOVQ 8(CX), AX
MULQ 24(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a2×b2
MOVQ 16(CX), AX
MULQ 16(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a3×b1
MOVQ 24(CX), AX
MULQ 8(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a4×b0
MOVQ 32(CX), AX
MULQ (BX)
ADDQ AX, R15
ADCQ DX, R14
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, DI, SI
SHLQ $0x0d, R9, R8
SHLQ $0x0d, R11, R10
SHLQ $0x0d, R13, R12
SHLQ $0x0d, R15, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Second reduction chain (carryPropagate)
MOVQ DI, SI
SHRQ $0x33, SI
MOVQ R9, R8
SHRQ $0x33, R8
MOVQ R11, R10
SHRQ $0x33, R10
MOVQ R13, R12
SHRQ $0x33, R12
MOVQ R15, R14
SHRQ $0x33, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Store output
MOVQ out+0(FP), AX
MOVQ DI, (AX)
MOVQ R9, 8(AX)
MOVQ R11, 16(AX)
MOVQ R13, 24(AX)
MOVQ R15, 32(AX)
RET
// func feSquare(out *Element, a *Element)
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX
// r0 = l0×l0
MOVQ (CX), AX
MULQ (CX)
MOVQ AX, SI
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 8(CX)
MOVQ AX, R8
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
MOVQ AX, R10
MOVQ DX, R9
// r2 += l1×l1
MOVQ 8(CX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
MOVQ AX, R12
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 32(CX)
MOVQ AX, R14
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R13
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, SI, BX
SHLQ $0x0d, R8, DI
SHLQ $0x0d, R10, R9
SHLQ $0x0d, R12, R11
SHLQ $0x0d, R14, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Second reduction chain (carryPropagate)
MOVQ SI, BX
SHRQ $0x33, BX
MOVQ R8, DI
SHRQ $0x33, DI
MOVQ R10, R9
SHRQ $0x33, R9
MOVQ R12, R11
SHRQ $0x33, R11
MOVQ R14, R13
SHRQ $0x33, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Store output
MOVQ out+0(FP), AX
MOVQ SI, (AX)
MOVQ R8, 8(AX)
MOVQ R10, 16(AX)
MOVQ R12, 24(AX)
MOVQ R14, 32(AX)
RET

12
vendor/filippo.io/edwards25519/field/fe_amd64_noasm.go generated vendored Normal file
View File

@@ -0,0 +1,12 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package field
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

16
vendor/filippo.io/edwards25519/field/fe_arm64.go generated vendored Normal file
View File

@@ -0,0 +1,16 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field
//go:noescape
func carryPropagate(v *Element)
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}

42
vendor/filippo.io/edwards25519/field/fe_arm64.s generated vendored Normal file
View File

@@ -0,0 +1,42 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET

12
vendor/filippo.io/edwards25519/field/fe_arm64_noasm.go generated vendored Normal file
View File

@@ -0,0 +1,12 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}

50
vendor/filippo.io/edwards25519/field/fe_extra.go generated vendored Normal file
View File

@@ -0,0 +1,50 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "errors"
// This file contains additional functionality that is not included in the
// upstream crypto/ed25519/edwards25519/field package.
// SetWideBytes sets v to x, where x is a 64-byte little-endian encoding, which
// is reduced modulo the field order. If x is not of the right length,
// SetWideBytes returns nil and an error, and the receiver is unchanged.
//
// SetWideBytes is not necessary to select a uniformly distributed value, and is
// only provided for compatibility: SetBytes can be used instead as the chance
// of bias is less than 2⁻²⁵⁰.
func (v *Element) SetWideBytes(x []byte) (*Element, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetWideBytes input size")
}
// Split the 64 bytes into two elements, and extract the most significant
// bit of each, which is ignored by SetBytes.
lo, _ := new(Element).SetBytes(x[:32])
loMSB := uint64(x[31] >> 7)
hi, _ := new(Element).SetBytes(x[32:])
hiMSB := uint64(x[63] >> 7)
// The output we want is
//
// v = lo + loMSB * 2²⁵⁵ + hi * 2²⁵⁶ + hiMSB * 2⁵¹¹
//
// which applying the reduction identity comes out to
//
// v = lo + loMSB * 19 + hi * 2 * 19 + hiMSB * 2 * 19²
//
// l0 will be the sum of a 52 bits value (lo.l0), plus a 5 bits value
// (loMSB * 19), a 6 bits value (hi.l0 * 2 * 19), and a 10 bits value
// (hiMSB * 2 * 19²), so it fits in a uint64.
v.l0 = lo.l0 + loMSB*19 + hi.l0*2*19 + hiMSB*2*19*19
v.l1 = lo.l1 + hi.l1*2*19
v.l2 = lo.l2 + hi.l2*2*19
v.l3 = lo.l3 + hi.l3*2*19
v.l4 = lo.l4 + hi.l4*2*19
return v.carryPropagate(), nil
}

266
vendor/filippo.io/edwards25519/field/fe_generic.go generated vendored Normal file
View File

@@ -0,0 +1,266 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}

343
vendor/filippo.io/edwards25519/scalar.go generated vendored Normal file
View File

@@ -0,0 +1,343 @@
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
// s is the scalar in the Montgomery domain, in the format of the
// fiat-crypto implementation.
s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
return &Scalar{}
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Make a copy of z in case it aliases s.
zCopy := new(Scalar).Set(z)
return s.Multiply(x, y).Add(s, zCopy)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
fiatScalarAdd(&s.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
fiatScalarSub(&s.s, &x.s, &y.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
fiatScalarOpp(&s.s, &x.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
fiatScalarMul(&s.s, &x.s, &y.s)
return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
*s = *x
return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
}
// We have a value x of 512 bits, but our fiatScalarFromBytes function
// expects an input lower than l, which is a little over 252 bits.
//
// Instead of writing a reduction function that operates on wider inputs, we
// can interpret x as the sum of three shorter values a, b, and c.
//
// x = a + b * 2^168 + c * 2^336 mod l
//
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
// with two multiplications and two additions.
s.setShortBytes(x[:21])
t := new(Scalar).setShortBytes(x[21:42])
s.Add(s, t.Multiply(t, scalarTwo168))
t.setShortBytes(x[42:])
s.Add(s, t.Multiply(t, scalarTwo336))
return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
if len(x) >= 32 {
panic("edwards25519: internal error: setShortBytes called with a long string")
}
var buf [32]byte
copy(buf[:], x)
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
}
if !isReduced(x) {
return nil, errors.New("invalid scalar encoding")
}
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
if len(s) != 32 {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
// The description above omits the purpose of the high bits of the clamping
// for brevity, but those are also lost to reductions, and are also
// irrelevant to edwards25519 as they protect against a specific
// implementation bug that was once observed in a generic Montgomery ladder.
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
}
// We need to use the wide reduction from SetUniformBytes, since clamping
// sets the 2^254 bit, making the value higher than the order.
var wideBytes [64]byte
copy(wideBytes[:], x[:])
wideBytes[0] &= 248
wideBytes[31] &= 63
wideBytes[31] |= 64
return s.SetUniformBytes(wideBytes[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var encoded [32]byte
return s.bytes(&encoded)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
var ss fiatScalarNonMontgomeryDomainFieldElement
fiatScalarFromMontgomery(&ss, &s.s)
fiatScalarToBytes(out, (*[4]uint64)(&ss))
return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
var diff fiatScalarMontgomeryDomainFieldElement
fiatScalarSub(&diff, &s.s, &t.s)
var nonzero uint64
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
panic("w must be at least 2 by the definition of NAF")
} else if w > 8 {
panic("NAF digits must fit in int8")
}
var naf [256]int8
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
}
width := uint64(1 << w)
windowMask := uint64(width - 1)
pos := uint(0)
carry := uint64(0)
for pos < 256 {
indexU64 := pos / 64
indexBit := pos % 64
var bitBuf uint64
if indexBit < 64-w {
// This window's bits are contained in a single u64
bitBuf = digits[indexU64] >> indexBit
} else {
// Combine the current 64 bits with bits from the next 64
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
}
// Add carry into the current window
window := carry + (bitBuf & windowMask)
if window&1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0,
// then the next carry should be 0
// If carry == 1 and window & 1 == 0,
// then bit_buf & 1 == 1 so the next carry should be 1
pos += 1
continue
}
if window < width/2 {
carry = 0
naf[pos] = int8(window)
} else {
carry = 1
naf[pos] = int8(window) - int8(width)
}
pos += w
}
return naf
}
func (s *Scalar) signedRadix16() [64]int8 {
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
var digits [64]int8
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(b[i] & 15)
digits[2*i+1] = int8((b[i] >> 4) & 15)
}
// Recenter coefficients:
for i := 0; i < 63; i++ {
carry := (digits[i] + 8) >> 4
digits[i] -= carry << 4
digits[i+1] += carry
}
return digits
}

1147
vendor/filippo.io/edwards25519/scalar_fiat.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

214
vendor/filippo.io/edwards25519/scalarmult.go generated vendored Normal file
View File

@@ -0,0 +1,214 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.signedRadix16()
multiple := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Accumulate the odd components first
v.Set(NewIdentityPoint())
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.fromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
return v
}
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
checkInitialized(q)
var table projLookupTable
table.FromP3(q)
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.signedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
table.SelectInto(multiple, digits[63])
v.Set(NewIdentityPoint())
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.fromP1xP1(tmp1)
return v
}
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
checkInitialized(A)
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.nonAdjacentForm(5)
bNaf := b.nonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
break
}
}
multA := &projCached{}
multB := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

129
vendor/filippo.io/edwards25519/tables.go generated vendored Normal file
View File

@@ -0,0 +1,129 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]affineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]affineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
*dest = v.points[x/2]
}

3
vendor/github.com/aws/aws-sdk-go/NOTICE.txt generated vendored Normal file
View File

@@ -0,0 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

93
vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go generated vendored Normal file
View File

@@ -0,0 +1,93 @@
// Package arn provides a parser for interacting with Amazon Resource Names.
package arn
import (
"errors"
"strings"
)
const (
arnDelimiter = ":"
arnSections = 6
arnPrefix = "arn:"
// zero-indexed
sectionPartition = 1
sectionService = 2
sectionRegion = 3
sectionAccountID = 4
sectionResource = 5
// errors
invalidPrefix = "arn: invalid prefix"
invalidSections = "arn: not enough sections"
)
// ARN captures the individual fields of an Amazon Resource Name.
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
type ARN struct {
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
// (Beijing) region is "aws-cn".
Partition string
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
// namespaces, see
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
Service string
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
// component might be omitted.
Region string
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
// ARNs for some resources don't require an account number, so this component might be omitted.
AccountID string
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
// resource name itself. Some services allows paths for resource names, as described in
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
Resource string
}
// Parse parses an ARN into its constituent parts.
//
// Some example ARNs:
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
// arn:aws:iam::123456789012:user/David
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
func Parse(arn string) (ARN, error) {
if !strings.HasPrefix(arn, arnPrefix) {
return ARN{}, errors.New(invalidPrefix)
}
sections := strings.SplitN(arn, arnDelimiter, arnSections)
if len(sections) != arnSections {
return ARN{}, errors.New(invalidSections)
}
return ARN{
Partition: sections[sectionPartition],
Service: sections[sectionService],
Region: sections[sectionRegion],
AccountID: sections[sectionAccountID],
Resource: sections[sectionResource],
}, nil
}
// IsARN returns whether the given string is an ARN by looking for
// whether the string starts with "arn:" and contains the correct number
// of sections delimited by colons(:).
func IsARN(arn string) bool {
return strings.HasPrefix(arn, arnPrefix) && strings.Count(arn, ":") >= arnSections-1
}
// String returns the canonical representation of the ARN
func (arn ARN) String() string {
return arnPrefix +
arn.Partition + arnDelimiter +
arn.Service + arnDelimiter +
arn.Region + arnDelimiter +
arn.AccountID + arnDelimiter +
arn.Resource
}

164
vendor/github.com/aws/aws-sdk-go/aws/awserr/error.go generated vendored Normal file
View File

@@ -0,0 +1,164 @@
// Package awserr represents API error interface accessors for the SDK.
package awserr
// An Error wraps lower level errors with code, message and an original error.
// The underlying concrete error type may also satisfy other interfaces which
// can be to used to obtain more specific information about the error.
//
// Calling Error() or String() will always include the full information about
// an error based on its underlying type.
//
// Example:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Get error details
// log.Println("Error:", awsErr.Code(), awsErr.Message())
//
// // Prints out full error message, including original error if there was one.
// log.Println("Error:", awsErr.Error())
//
// // Get original error
// if origErr := awsErr.OrigErr(); origErr != nil {
// // operate on original error.
// }
// } else {
// fmt.Println(err.Error())
// }
// }
//
type Error interface {
// Satisfy the generic error interface.
error
// Returns the short phrase depicting the classification of the error.
Code() string
// Returns the error details message.
Message() string
// Returns the original error if one was set. Nil is returned if not set.
OrigErr() error
}
// BatchError is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Deprecated: Replaced with BatchedErrors. Only defined for backwards
// compatibility.
type BatchError interface {
// Satisfy the generic error interface.
error
// Returns the short phrase depicting the classification of the error.
Code() string
// Returns the error details message.
Message() string
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// BatchedErrors is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Replaces BatchError
type BatchedErrors interface {
// Satisfy the base Error interface.
Error
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// New returns an Error object described by the code, message, and origErr.
//
// If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned.
func New(code, message string, origErr error) Error {
var errs []error
if origErr != nil {
errs = append(errs, origErr)
}
return newBaseError(code, message, errs)
}
// NewBatchError returns an BatchedErrors with a collection of errors as an
// array of errors.
func NewBatchError(code, message string, errs []error) BatchedErrors {
return newBaseError(code, message, errs)
}
// A RequestFailure is an interface to extract request failure information from
// an Error such as the request ID of the failed request returned by a service.
// RequestFailures may not always have a requestID value if the request failed
// prior to reaching the service such as a connection error.
//
// Example:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if reqerr, ok := err.(RequestFailure); ok {
// log.Println("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
// } else {
// log.Println("Error:", err.Error())
// }
// }
//
// Combined with awserr.Error:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Generic AWS Error with Code, Message, and original error (if any)
// fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
//
// if reqErr, ok := err.(awserr.RequestFailure); ok {
// // A service error occurred
// fmt.Println(reqErr.StatusCode(), reqErr.RequestID())
// }
// } else {
// fmt.Println(err.Error())
// }
// }
//
type RequestFailure interface {
Error
// The status code of the HTTP response.
StatusCode() int
// The request ID returned by the service for a request failure. This will
// be empty if no request ID is available such as the request failed due
// to a connection error.
RequestID() string
}
// NewRequestFailure returns a wrapped error with additional information for
// request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID)
}
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
awsError
Bytes() []byte
}
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
return &unmarshalError{
awsError: New("UnmarshalError", msg, err),
bytes: bytes,
}
}

221
vendor/github.com/aws/aws-sdk-go/aws/awserr/types.go generated vendored Normal file
View File

@@ -0,0 +1,221 @@
package awserr
import (
"encoding/hex"
"fmt"
)
// SprintError returns a string of the formatted error code.
//
// Both extra and origErr are optional. If they are included their lines
// will be added, but if they are not included their lines will be ignored.
func SprintError(code, message, extra string, origErr error) string {
msg := fmt.Sprintf("%s: %s", code, message)
if extra != "" {
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
}
if origErr != nil {
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
}
return msg
}
// A baseError wraps the code and message which defines an error. It also
// can be used to wrap an original error object.
//
// Should be used as the root for errors satisfying the awserr.Error. Also
// for any error which does not fit into a specific error wrapper type.
type baseError struct {
// Classification of error
code string
// Detailed information about error
message string
// Optional original error this error is based off of. Allows building
// chained errors.
errs []error
}
// newBaseError returns an error object for the code, message, and errors.
//
// code is a short no whitespace phrase depicting the classification of
// the error that is being created.
//
// message is the free flow string containing detailed information about the
// error.
//
// origErrs is the error objects which will be nested under the new errors to
// be returned.
func newBaseError(code, message string, origErrs []error) *baseError {
b := &baseError{
code: code,
message: message,
errs: origErrs,
}
return b
}
// Error returns the string representation of the error.
//
// See ErrorWithExtra for formatting.
//
// Satisfies the error interface.
func (b baseError) Error() string {
size := len(b.errs)
if size > 0 {
return SprintError(b.code, b.message, "", errorList(b.errs))
}
return SprintError(b.code, b.message, "", nil)
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (b baseError) String() string {
return b.Error()
}
// Code returns the short phrase depicting the classification of the error.
func (b baseError) Code() string {
return b.code
}
// Message returns the error details message.
func (b baseError) Message() string {
return b.message
}
// OrigErr returns the original error if one was set. Nil is returned if no
// error was set. This only returns the first element in the list. If the full
// list is needed, use BatchedErrors.
func (b baseError) OrigErr() error {
switch len(b.errs) {
case 0:
return nil
case 1:
return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
"multiple errors occurred", b.errs)
}
}
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (b baseError) OrigErrs() []error {
return b.errs
}
// So that the Error interface type can be included as an anonymous field
// in the requestError struct and not conflict with the error.Error() method.
type awsError Error
// A requestError wraps a request or service error.
//
// Composed of baseError for code, message, and original error.
type requestError struct {
awsError
statusCode int
requestID string
bytes []byte
}
// newRequestError returns a wrapped error with additional information for
// request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
//
// Also wraps original errors via the baseError.
func newRequestError(err Error, statusCode int, requestID string) *requestError {
return &requestError{
awsError: err,
statusCode: statusCode,
requestID: requestID,
}
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: %s",
r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (r requestError) String() string {
return r.Error()
}
// StatusCode returns the wrapped status code for the error
func (r requestError) StatusCode() int {
return r.statusCode
}
// RequestID returns the wrapped requestID
func (r requestError) RequestID() string {
return r.requestID
}
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (r requestError) OrigErrs() []error {
if b, ok := r.awsError.(BatchedErrors); ok {
return b.OrigErrs()
}
return []error{r.OrigErr()}
}
type unmarshalError struct {
awsError
bytes []byte
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
extra := hex.Dump(e.bytes)
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
return e.Error()
}
// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
return e.bytes
}
// An error list that satisfies the golang interface
type errorList []error
// Error returns the string representation of the error.
//
// Satisfies the error interface.
func (e errorList) Error() string {
msg := ""
// How do we want to handle the array size being zero
if size := len(e); size > 0 {
for i := 0; i < size; i++ {
msg += e[i].Error()
// We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n'
if i+1 < size {
msg += "\n"
}
}
}
return msg
}

108
vendor/github.com/aws/aws-sdk-go/aws/awsutil/copy.go generated vendored Normal file
View File

@@ -0,0 +1,108 @@
package awsutil
import (
"io"
"reflect"
"time"
)
// Copy deeply copies a src structure to dst. Useful for copying request and
// response structures.
//
// Can copy between structs of different type, but will only copy fields which
// are assignable, and exist in both structs. Fields which are not assignable,
// or do not exist in both structs are ignored.
func Copy(dst, src interface{}) {
dstval := reflect.ValueOf(dst)
if !dstval.IsValid() {
panic("Copy dst cannot be nil")
}
rcopy(dstval, reflect.ValueOf(src), true)
}
// CopyOf returns a copy of src while also allocating the memory for dst.
// src must be a pointer type or this operation will fail.
func CopyOf(src interface{}) (dst interface{}) {
dsti := reflect.New(reflect.TypeOf(src).Elem())
dst = dsti.Interface()
rcopy(dsti, reflect.ValueOf(src), true)
return
}
// rcopy performs a recursive copy of values from the source to destination.
//
// root is used to skip certain aspects of the copy which are not valid
// for the root node of a object.
func rcopy(dst, src reflect.Value, root bool) {
if !src.IsValid() {
return
}
switch src.Kind() {
case reflect.Ptr:
if _, ok := src.Interface().(io.Reader); ok {
if dst.Kind() == reflect.Ptr && dst.Elem().CanSet() {
dst.Elem().Set(src)
} else if dst.CanSet() {
dst.Set(src)
}
} else {
e := src.Type().Elem()
if dst.CanSet() && !src.IsNil() {
if _, ok := src.Interface().(*time.Time); !ok {
dst.Set(reflect.New(e))
} else {
tempValue := reflect.New(e)
tempValue.Elem().Set(src.Elem())
// Sets time.Time's unexported values
dst.Set(tempValue)
}
}
if src.Elem().IsValid() {
// Keep the current root state since the depth hasn't changed
rcopy(dst.Elem(), src.Elem(), root)
}
}
case reflect.Struct:
t := dst.Type()
for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Name
srcVal := src.FieldByName(name)
dstVal := dst.FieldByName(name)
if srcVal.IsValid() && dstVal.CanSet() {
rcopy(dstVal, srcVal, false)
}
}
case reflect.Slice:
if src.IsNil() {
break
}
s := reflect.MakeSlice(src.Type(), src.Len(), src.Cap())
dst.Set(s)
for i := 0; i < src.Len(); i++ {
rcopy(dst.Index(i), src.Index(i), false)
}
case reflect.Map:
if src.IsNil() {
break
}
s := reflect.MakeMap(src.Type())
dst.Set(s)
for _, k := range src.MapKeys() {
v := src.MapIndex(k)
v2 := reflect.New(v.Type()).Elem()
rcopy(v2, v, false)
dst.SetMapIndex(k, v2)
}
default:
// Assign the value if possible. If its not assignable, the value would
// need to be converted and the impact of that may be unexpected, or is
// not compatible with the dst type.
if src.Type().AssignableTo(dst.Type()) {
dst.Set(src)
}
}
}

27
vendor/github.com/aws/aws-sdk-go/aws/awsutil/equal.go generated vendored Normal file
View File

@@ -0,0 +1,27 @@
package awsutil
import (
"reflect"
)
// DeepEqual returns if the two values are deeply equal like reflect.DeepEqual.
// In addition to this, this method will also dereference the input values if
// possible so the DeepEqual performed will not fail if one parameter is a
// pointer and the other is not.
//
// DeepEqual will not perform indirection of nested values of the input parameters.
func DeepEqual(a, b interface{}) bool {
ra := reflect.Indirect(reflect.ValueOf(a))
rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid {
// Both values must be valid to be equal
return false
}
return reflect.DeepEqual(ra.Interface(), rb.Interface())
}

View File

@@ -0,0 +1,221 @@
package awsutil
import (
"reflect"
"regexp"
"strconv"
"strings"
"github.com/jmespath/go-jmespath"
)
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
// rValuesAtPath returns a slice of values found in value v. The values
// in v are explored recursively so all nested values are collected.
func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTerm bool) []reflect.Value {
pathparts := strings.Split(path, "||")
if len(pathparts) > 1 {
for _, pathpart := range pathparts {
vals := rValuesAtPath(v, pathpart, createPath, caseSensitive, nilTerm)
if len(vals) > 0 {
return vals
}
}
return nil
}
values := []reflect.Value{reflect.Indirect(reflect.ValueOf(v))}
components := strings.Split(path, ".")
for len(values) > 0 && len(components) > 0 {
var index *int64
var indexStar bool
c := strings.TrimSpace(components[0])
if c == "" { // no actual component, illegal syntax
return nil
} else if caseSensitive && c != "*" && strings.ToLower(c[0:1]) == c[0:1] {
// TODO normalize case for user
return nil // don't support unexported fields
}
// parse this component
if m := indexRe.FindStringSubmatch(c); m != nil {
c = m[1]
if m[2] == "" {
index = nil
indexStar = true
} else {
i, _ := strconv.ParseInt(m[2], 10, 32)
index = &i
indexStar = false
}
}
nextvals := []reflect.Value{}
for _, value := range values {
// pull component name out of struct member
if value.Kind() != reflect.Struct {
continue
}
if c == "*" { // pull all members
for i := 0; i < value.NumField(); i++ {
if f := reflect.Indirect(value.Field(i)); f.IsValid() {
nextvals = append(nextvals, f)
}
}
continue
}
value = value.FieldByNameFunc(func(name string) bool {
if c == name {
return true
} else if !caseSensitive && strings.EqualFold(name, c) {
return true
}
return false
})
if nilTerm && value.Kind() == reflect.Ptr && len(components[1:]) == 0 {
if !value.IsNil() {
value.Set(reflect.Zero(value.Type()))
}
return []reflect.Value{value}
}
if createPath && value.Kind() == reflect.Ptr && value.IsNil() {
// TODO if the value is the terminus it should not be created
// if the value to be set to its position is nil.
value.Set(reflect.New(value.Type().Elem()))
value = value.Elem()
} else {
value = reflect.Indirect(value)
}
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !createPath && value.IsNil() {
value = reflect.ValueOf(nil)
}
}
if value.IsValid() {
nextvals = append(nextvals, value)
}
}
values = nextvals
if indexStar || index != nil {
nextvals = []reflect.Value{}
for _, valItem := range values {
value := reflect.Indirect(valItem)
if value.Kind() != reflect.Slice {
continue
}
if indexStar { // grab all indices
for i := 0; i < value.Len(); i++ {
idx := reflect.Indirect(value.Index(i))
if idx.IsValid() {
nextvals = append(nextvals, idx)
}
}
continue
}
// pull out index
i := int(*index)
if i >= value.Len() { // check out of bounds
if createPath {
// TODO resize slice
} else {
continue
}
} else if i < 0 { // support negative indexing
i = value.Len() + i
}
value = reflect.Indirect(value.Index(i))
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !createPath && value.IsNil() {
value = reflect.ValueOf(nil)
}
}
if value.IsValid() {
nextvals = append(nextvals, value)
}
}
values = nextvals
}
components = components[1:]
}
return values
}
// ValuesAtPath returns a list of values at the case insensitive lexical
// path inside of a structure.
func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
result, err := jmespath.Search(path, i)
if err != nil {
return nil, err
}
v := reflect.ValueOf(result)
if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) {
return nil, nil
}
if s, ok := result.([]interface{}); ok {
return s, err
}
if v.Kind() == reflect.Map && v.Len() == 0 {
return nil, nil
}
if v.Kind() == reflect.Slice {
out := make([]interface{}, v.Len())
for i := 0; i < v.Len(); i++ {
out[i] = v.Index(i).Interface()
}
return out, nil
}
return []interface{}{result}, nil
}
// SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) {
rvals := rValuesAtPath(i, path, true, false, v == nil)
for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue
}
setValue(rval, v)
}
}
func setValue(dstVal reflect.Value, src interface{}) {
if dstVal.Kind() == reflect.Ptr {
dstVal = reflect.Indirect(dstVal)
}
srcVal := reflect.ValueOf(src)
if !srcVal.IsValid() { // src is literal nil
if dstVal.CanAddr() {
// Convert to pointer so that pointer's value can be nil'ed
// dstVal = dstVal.Addr()
}
dstVal.Set(reflect.Zero(dstVal.Type()))
} else if srcVal.Kind() == reflect.Ptr {
if srcVal.IsNil() {
srcVal = reflect.Zero(dstVal.Type())
} else {
srcVal = reflect.ValueOf(src).Elem()
}
dstVal.Set(srcVal)
} else {
dstVal.Set(srcVal)
}
}

View File

@@ -0,0 +1,123 @@
package awsutil
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
)
// Prettify returns the string representation of a value.
func Prettify(i interface{}) string {
var buf bytes.Buffer
prettify(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
// prettify will recursively walk value v to build a textual
// representation of the value.
func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
strtype := v.Type().String()
if strtype == "time.Time" {
fmt.Fprintf(buf, "%s", v.Interface())
break
} else if strings.HasPrefix(strtype, "io.") {
buf.WriteString("<buffer>")
break
}
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice || f.Kind() == reflect.Map) && f.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
ft, ok := v.Type().FieldByName(n)
if !ok {
panic(fmt.Sprintf("expected to find field %v on type %v, but was not found", n, v.Type()))
}
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
prettify(val, indent+2, buf)
}
if i < len(names)-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
strtype := v.Type().String()
if strtype == "[]uint8" {
fmt.Fprintf(buf, "<binary> len %d", v.Len())
break
}
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
}
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
prettify(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
}
}
buf.WriteString(nl + id + "]")
case reflect.Map:
buf.WriteString("{\n")
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
prettify(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
if !v.IsValid() {
fmt.Fprint(buf, "<invalid value>")
return
}
format := "%v"
switch v.Interface().(type) {
case string:
format = "%q"
case io.ReadSeeker, io.Reader:
format = "buffer(%p)"
}
fmt.Fprintf(buf, format, v.Interface())
}
}

View File

@@ -0,0 +1,90 @@
package awsutil
import (
"bytes"
"fmt"
"reflect"
"strings"
)
// StringValue returns the string representation of a value.
//
// Deprecated: Use Prettify instead.
func StringValue(i interface{}) string {
var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
buf.WriteString("{\n")
for i := 0; i < v.Type().NumField(); i++ {
ft := v.Type().Field(i)
fv := v.Field(i)
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
continue // ignore unexported fields
}
if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
continue // ignore unset fields
}
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(ft.Name + ": ")
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
stringValue(fv, indent+2, buf)
}
buf.WriteString(",\n")
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
}
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
stringValue(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
}
}
buf.WriteString(nl + id + "]")
case reflect.Map:
buf.WriteString("{\n")
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
stringValue(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
format := "%v"
switch v.Interface().(type) {
case string:
format = "%q"
}
fmt.Fprintf(buf, format, v.Interface())
}
}

94
vendor/github.com/aws/aws-sdk-go/aws/client/client.go generated vendored Normal file
View File

@@ -0,0 +1,94 @@
package client
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
// A Config provides configuration to a service client instance.
type Config struct {
Config *aws.Config
Handlers request.Handlers
PartitionID string
Endpoint string
SigningRegion string
SigningName string
ResolvedRegion string
// States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors
// to determine if the signin name can be overridden based on metadata the
// service has.
SigningNameDerived bool
}
// ConfigProvider provides a generic way for a service client to receive
// the ClientConfig without circular dependencies.
type ConfigProvider interface {
ClientConfig(serviceName string, cfgs ...*aws.Config) Config
}
// ConfigNoResolveEndpointProvider same as ConfigProvider except it will not
// resolve the endpoint automatically. The service client's endpoint must be
// provided via the aws.Config.Endpoint field.
type ConfigNoResolveEndpointProvider interface {
ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) Config
}
// A Client implements the base client request and response handling
// used by all service clients.
type Client struct {
request.Retryer
metadata.ClientInfo
Config aws.Config
Handlers request.Handlers
}
// New will return a pointer to a new initialized service client.
func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, options ...func(*Client)) *Client {
svc := &Client{
Config: cfg,
ClientInfo: info,
Handlers: handlers.Copy(),
}
switch retryer, ok := cfg.Retryer.(request.Retryer); {
case ok:
svc.Retryer = retryer
case cfg.Retryer != nil && cfg.Logger != nil:
s := fmt.Sprintf("WARNING: %T does not implement request.Retryer; using DefaultRetryer instead", cfg.Retryer)
cfg.Logger.Log(s)
fallthrough
default:
maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = DefaultRetryerMaxNumRetries
}
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
}
svc.AddDebugHandlers()
for _, option := range options {
option(svc)
}
return svc
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
func (c *Client) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request {
return request.New(c.Config, c.ClientInfo, c.Handlers, c.Retryer, operation, params, data)
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (c *Client) AddDebugHandlers() {
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
}

View File

@@ -0,0 +1,177 @@
package client
import (
"math"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, you can implement the
// request.Retryer interface.
//
type DefaultRetryer struct {
// Num max Retries is the number of max retries that will be performed.
// By default, this is zero.
NumMaxRetries int
// MinRetryDelay is the minimum retry delay after which retry will be performed.
// If not set, the value is 0ns.
MinRetryDelay time.Duration
// MinThrottleRetryDelay is the minimum retry delay when throttled.
// If not set, the value is 0ns.
MinThrottleDelay time.Duration
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
// If not set, the value is 0ns.
MaxRetryDelay time.Duration
// MaxThrottleDelay is the maximum retry delay when throttled.
// If not set, the value is 0ns.
MaxThrottleDelay time.Duration
}
const (
// DefaultRetryerMaxNumRetries sets maximum number of retries
DefaultRetryerMaxNumRetries = 3
// DefaultRetryerMinRetryDelay sets minimum retry delay
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
// DefaultRetryerMaxRetryDelay sets maximum retry delay
DefaultRetryerMaxRetryDelay = 300 * time.Second
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
DefaultRetryerMaxThrottleDelay = 300 * time.Second
)
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
// setRetryerDefaults sets the default values of the retryer if not set
func (d *DefaultRetryer) setRetryerDefaults() {
if d.MinRetryDelay == 0 {
d.MinRetryDelay = DefaultRetryerMinRetryDelay
}
if d.MaxRetryDelay == 0 {
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
}
if d.MinThrottleDelay == 0 {
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
}
if d.MaxThrottleDelay == 0 {
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
}
}
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// if number of max retries is zero, no retries will be performed.
if d.NumMaxRetries == 0 {
return 0
}
// Sets default value for retryer members
d.setRetryerDefaults()
// minDelay is the minimum retryer delay
minDelay := d.MinRetryDelay
var initialDelay time.Duration
isThrottle := r.IsErrorThrottle()
if isThrottle {
if delay, ok := getRetryAfterDelay(r); ok {
initialDelay = delay
}
minDelay = d.MinThrottleDelay
}
retryCount := r.RetryCount
// maxDelay the maximum retryer delay
maxDelay := d.MaxRetryDelay
if isThrottle {
maxDelay = d.MaxThrottleDelay
}
var delay time.Duration
// Logic to cap the retry count based on the minDelay provided
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
if actualRetryCount < 63-retryCount {
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
if delay > maxDelay {
delay = getJitterDelay(maxDelay / 2)
}
} else {
delay = getJitterDelay(maxDelay / 2)
}
return delay + initialDelay
}
// getJitterDelay returns a jittered delay for retry
func getJitterDelay(duration time.Duration) time.Duration {
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
}
// ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldRetry returns false if number of max retries is 0.
if d.NumMaxRetries == 0 {
return false
}
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable != nil {
return *r.Retryable
}
return r.IsErrorRetryable() || r.IsErrorThrottle()
}
// This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request
func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) {
return 0, false
}
delayStr := r.HTTPResponse.Header.Get("Retry-After")
if len(delayStr) == 0 {
return 0, false
}
delay, err := strconv.Atoi(delayStr)
if err != nil {
return 0, false
}
return time.Duration(delay) * time.Second, true
}
// Will look at the status code to see if the retry header pertains to
// the status code.
func canUseRetryAfterHeader(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 503:
default:
return false
}
return true
}

Some files were not shown because too many files have changed in this diff Show More