You've already forked openaccounting-server
mirror of
https://github.com/openaccounting/oa-server.git
synced 2025-12-17 04:40:41 +13:00
Compare commits
26 Commits
master
...
6558a09258
| Author | SHA1 | Date | |
|---|---|---|---|
| 6558a09258 | |||
| f99a866e13 | |||
| e3152d9f40 | |||
| e78098ad45 | |||
| 7c43726abf | |||
| b7ac4b0152 | |||
| 1b115fe0ff | |||
| a87df47231 | |||
| 8b0a72c81f | |||
| f64f83e66f | |||
| f5f0853040 | |||
| 04653f2f02 | |||
| 3b89d8137e | |||
| d10686e70f | |||
| c335c834ba | |||
| b5ea2095e4 | |||
| 88c996a383 | |||
| 8c7088040d | |||
| 77ab4b0e1d | |||
| 62dea0e53c | |||
| d2ea9960bf | |||
| f547d8d75b | |||
| 0d1cb22044 | |||
| bd3f101fb4 | |||
| e865c4c1a2 | |||
| 51deace1da |
39
.dockerignore
Normal file
39
.dockerignore
Normal file
@@ -0,0 +1,39 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
README.md
|
||||
*.md
|
||||
|
||||
# Docker
|
||||
Dockerfile
|
||||
.dockerignore
|
||||
|
||||
# Build artifacts
|
||||
server
|
||||
*.exe
|
||||
|
||||
# Development files
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Local config and data
|
||||
config.json
|
||||
*.db
|
||||
data/
|
||||
|
||||
# Test files
|
||||
*_test.go
|
||||
test*
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.log
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Dependencies (will be downloaded)
|
||||
vendor/
|
||||
38
.env.storage.example
Normal file
38
.env.storage.example
Normal file
@@ -0,0 +1,38 @@
|
||||
# OpenAccounting Storage Configuration
|
||||
# Copy this file to .env and modify as needed
|
||||
|
||||
# Database Configuration
|
||||
OA_DATABASEDRIVER=sqlite
|
||||
OA_DATABASEFILE=./openaccounting.db
|
||||
OA_ADDRESS=localhost
|
||||
OA_PORT=8080
|
||||
OA_APIPREFIX=/api/v1
|
||||
|
||||
# Storage Backend Configuration
|
||||
# Options: local, s3, b2
|
||||
OA_STORAGE_BACKEND=local
|
||||
|
||||
# Local Storage Configuration
|
||||
OA_STORAGE_LOCAL_ROOTDIR=./uploads
|
||||
OA_STORAGE_LOCAL_BASEURL=
|
||||
|
||||
# Amazon S3 Storage Configuration (uncomment if using S3)
|
||||
# OA_STORAGE_S3_REGION=us-east-1
|
||||
# OA_STORAGE_S3_BUCKET=my-openaccounting-attachments
|
||||
# OA_STORAGE_S3_PREFIX=attachments
|
||||
# OA_STORAGE_S3_ACCESSKEYID=AKIA...
|
||||
# OA_STORAGE_S3_SECRETACCESSKEY=...
|
||||
# OA_STORAGE_S3_ENDPOINT=
|
||||
# OA_STORAGE_S3_PATHSTYLE=false
|
||||
|
||||
# Backblaze B2 Storage Configuration (uncomment if using B2)
|
||||
# OA_STORAGE_B2_ACCOUNTID=your-b2-account-id
|
||||
# OA_STORAGE_B2_APPLICATIONKEY=your-b2-application-key
|
||||
# OA_STORAGE_B2_BUCKET=my-openaccounting-attachments
|
||||
# OA_STORAGE_B2_PREFIX=attachments
|
||||
|
||||
# Email Configuration (optional)
|
||||
# OA_MAILGUNDOMAIN=
|
||||
# OA_MAILGUNKEY=
|
||||
# OA_MAILGUNEMAIL=
|
||||
# OA_MAILGUNSENDER=
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -97,3 +97,6 @@ config.json
|
||||
*.csr
|
||||
*.sublime-project
|
||||
*.sublime-workspace
|
||||
.vscode/
|
||||
server
|
||||
attachments/
|
||||
|
||||
64
Dockerfile
Normal file
64
Dockerfile
Normal file
@@ -0,0 +1,64 @@
|
||||
# Build stage
|
||||
FROM golang:1.24-alpine AS builder
|
||||
|
||||
# Install build dependencies for CGO (needed for SQLite)
|
||||
RUN apk add --no-cache git gcc musl-dev
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy go mod files
|
||||
COPY go.mod go.sum ./
|
||||
|
||||
# Download dependencies
|
||||
RUN go mod download
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -o server ./core/
|
||||
|
||||
# Final stage
|
||||
FROM alpine:latest
|
||||
|
||||
# Install ca-certificates for HTTPS and sqlite for database
|
||||
RUN apk --no-cache add ca-certificates sqlite
|
||||
|
||||
# Create app user for security
|
||||
RUN adduser -D -s /bin/sh appuser
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy binary from builder stage
|
||||
COPY --from=builder /app/server .
|
||||
|
||||
# Create data directory for SQLite
|
||||
RUN mkdir -p /app/data && chown appuser:appuser /app/data
|
||||
|
||||
# Copy config sample (optional)
|
||||
COPY config.json.sample .
|
||||
|
||||
# Change ownership to app user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Expose port (default 8080, can be overridden with OA_PORT)
|
||||
EXPOSE 8080
|
||||
|
||||
# Health check - requires Accept-Version header
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider --header="Accept-Version: v1" http://localhost:8080/ || exit 1
|
||||
|
||||
# Set default environment variables
|
||||
ENV OA_DATABASE_DRIVER=sqlite \
|
||||
OA_DATABASE_FILE=/app/data/openaccounting.db \
|
||||
OA_ADDRESS=0.0.0.0 \
|
||||
OA_PORT=8080 \
|
||||
OA_API_PREFIX=/api/v1
|
||||
|
||||
# Run the application
|
||||
CMD ["./server"]
|
||||
485
README.md
485
README.md
@@ -1,30 +1,491 @@
|
||||
# Open Accounting Server
|
||||
|
||||
Open Accounting Server is a modern financial accounting system built with Go, featuring GORM integration, Viper configuration management, and Docker support.
|
||||
|
||||
## Features
|
||||
|
||||
- **GORM Integration**: Modern ORM with SQLite and MySQL support
|
||||
- **Viper Configuration**: Flexible config management with environment variables
|
||||
- **Modular Storage**: S3-compatible attachment storage (Local, AWS S3, Backblaze B2, Cloudflare R2, MinIO)
|
||||
- **Docker Ready**: Containerized deployment with multi-stage builds
|
||||
- **SQLite Support**: Easy local development and testing
|
||||
- **Security**: Environment variable support for sensitive data
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Go 1.8+
|
||||
2. MySQL 5.7+
|
||||
- **Go 1.24+** (updated from 1.8+)
|
||||
- **SQLite** (for development) or **MySQL 5.7+** (for production)
|
||||
- **Docker** (optional, for containerized deployment)
|
||||
- **Just** (optional, for build automation)
|
||||
|
||||
## Database setup
|
||||
## Quick Start
|
||||
|
||||
Use schema.sql and indexes.sql to create a MySQL database to store Open Accounting data.
|
||||
### Using Just (Recommended)
|
||||
|
||||
```bash
|
||||
# Setup development environment
|
||||
just dev-setup
|
||||
|
||||
# Run in development mode
|
||||
just run-dev
|
||||
|
||||
# Build and run with Docker
|
||||
just docker-run
|
||||
```
|
||||
|
||||
### Manual Setup
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
go mod download
|
||||
|
||||
# Run with SQLite (development)
|
||||
OA_DATABASE_DRIVER=sqlite ./server
|
||||
|
||||
# Run with MySQL (production)
|
||||
OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret ./server
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Copy config.json.sample to config.json and edit to match your information.
|
||||
The server now uses **Viper** for advanced configuration management with multiple sources:
|
||||
|
||||
## Run
|
||||
### Configuration Sources (in order of precedence)
|
||||
|
||||
`go run core/server.go`
|
||||
1. **Environment Variables** (highest priority)
|
||||
2. **Config Files**: `config.json`, `config.yaml`, `config.toml`
|
||||
3. **Default Values** (lowest priority)
|
||||
|
||||
## Build
|
||||
### Config File Locations
|
||||
|
||||
`go build core/server.go`
|
||||
- `./config.json` (current directory)
|
||||
- `/etc/openaccounting/config.json`
|
||||
- `~/.openaccounting/config.json`
|
||||
|
||||
### Environment Variables
|
||||
|
||||
All configuration can be overridden with environment variables using the `OA_` prefix:
|
||||
|
||||
| Environment Variable | Config Field | Default | Description |
|
||||
|---------------------|--------------|---------|-------------|
|
||||
| `OA_ADDRESS` | Address | `localhost` | Server bind address |
|
||||
| `OA_PORT` | Port | `8080` | Server port |
|
||||
| `OA_API_PREFIX` | ApiPrefix | `/api/v1` | API route prefix |
|
||||
| `OA_DATABASE_DRIVER` | DatabaseDriver | `sqlite` | Database type: `sqlite` or `mysql` |
|
||||
| `OA_DATABASE_FILE` | DatabaseFile | `./openaccounting.db` | SQLite database file |
|
||||
| `OA_DATABASE_ADDRESS` | DatabaseAddress | `localhost:3306` | MySQL server address |
|
||||
| `OA_DATABASE` | Database | | MySQL database name |
|
||||
| `OA_USER` | User | | Database username |
|
||||
| `OA_PASSWORD` | Password | | Database password ⚠️ |
|
||||
| `OA_MAILGUN_DOMAIN` | MailgunDomain | | Mailgun domain |
|
||||
| `OA_MAILGUN_KEY` | MailgunKey | | Mailgun API key ⚠️ |
|
||||
| `OA_MAILGUN_EMAIL` | MailgunEmail | | Mailgun email |
|
||||
| `OA_MAILGUN_SENDER` | MailgunSender | | Mailgun sender name |
|
||||
|
||||
#### Storage Configuration
|
||||
|
||||
| Environment Variable | Config Field | Default | Description |
|
||||
|---------------------|--------------|---------|-------------|
|
||||
| `OA_STORAGE_BACKEND` | Storage.Backend | `local` | Storage backend: `local` or `s3` |
|
||||
|
||||
**Local Storage**
|
||||
| Environment Variable | Config Field | Default | Description |
|
||||
|---------------------|--------------|---------|-------------|
|
||||
| `OA_STORAGE_LOCAL_ROOTDIR` | Storage.Local.RootDir | `./uploads` | Root directory for file storage |
|
||||
| `OA_STORAGE_LOCAL_BASEURL` | Storage.Local.BaseURL | | Base URL for serving files |
|
||||
|
||||
**S3-Compatible Storage** (AWS S3, Backblaze B2, Cloudflare R2, MinIO)
|
||||
| Environment Variable | Config Field | Default | Description |
|
||||
|---------------------|--------------|---------|-------------|
|
||||
| `OA_STORAGE_S3_REGION` | Storage.S3.Region | | Region (use "auto" for Cloudflare R2) |
|
||||
| `OA_STORAGE_S3_BUCKET` | Storage.S3.Bucket | | Bucket name |
|
||||
| `OA_STORAGE_S3_PREFIX` | Storage.S3.Prefix | | Optional prefix for all objects |
|
||||
| `OA_STORAGE_S3_ACCESSKEYID` | Storage.S3.AccessKeyID | | Access Key ID ⚠️ |
|
||||
| `OA_STORAGE_S3_SECRETACCESSKEY` | Storage.S3.SecretAccessKey | | Secret Access Key ⚠️ |
|
||||
| `OA_STORAGE_S3_ENDPOINT` | Storage.S3.Endpoint | | Custom endpoint (see examples below) |
|
||||
| `OA_STORAGE_S3_PATHSTYLE` | Storage.S3.PathStyle | `false` | Use path-style addressing |
|
||||
|
||||
**S3-Compatible Service Endpoints:**
|
||||
- **AWS S3**: Leave endpoint empty, set appropriate region
|
||||
- **Backblaze B2**: `https://s3.us-west-004.backblazeb2.com` (replace region as needed)
|
||||
- **Cloudflare R2**: `https://<account-id>.r2.cloudflarestorage.com`
|
||||
- **MinIO**: `http://localhost:9000` (or your MinIO server URL)
|
||||
|
||||
⚠️ **Security**: Always use environment variables for sensitive data like passwords and API keys.
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
#### Development (SQLite)
|
||||
```bash
|
||||
# Minimal - uses defaults
|
||||
./server
|
||||
|
||||
# Custom database file and port
|
||||
OA_DATABASE_FILE=./dev.db OA_PORT=9090 ./server
|
||||
```
|
||||
|
||||
#### Production (MySQL)
|
||||
```bash
|
||||
# With environment variables (recommended)
|
||||
export OA_DATABASE_DRIVER=mysql
|
||||
export OA_DATABASE_ADDRESS=db.example.com:3306
|
||||
export OA_DATABASE=openaccounting_prod
|
||||
export OA_USER=openaccounting
|
||||
export OA_PASSWORD=secure_password
|
||||
export OA_MAILGUN_KEY=key-abc123
|
||||
./server
|
||||
|
||||
# Or inline
|
||||
OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret OA_MAILGUN_KEY=key-123 ./server
|
||||
```
|
||||
|
||||
#### Storage Configuration Examples
|
||||
```bash
|
||||
# Local storage (default)
|
||||
export OA_STORAGE_BACKEND=local
|
||||
export OA_STORAGE_LOCAL_ROOTDIR=./uploads
|
||||
./server
|
||||
|
||||
# AWS S3
|
||||
export OA_STORAGE_BACKEND=s3
|
||||
export OA_STORAGE_S3_REGION=us-west-2
|
||||
export OA_STORAGE_S3_BUCKET=my-app-attachments
|
||||
export OA_STORAGE_S3_ACCESSKEYID=your-access-key
|
||||
export OA_STORAGE_S3_SECRETACCESSKEY=your-secret-key
|
||||
./server
|
||||
|
||||
# Backblaze B2 (S3-compatible)
|
||||
export OA_STORAGE_BACKEND=s3
|
||||
export OA_STORAGE_S3_REGION=us-west-004
|
||||
export OA_STORAGE_S3_BUCKET=my-app-attachments
|
||||
export OA_STORAGE_S3_ACCESSKEYID=your-b2-key-id
|
||||
export OA_STORAGE_S3_SECRETACCESSKEY=your-b2-application-key
|
||||
export OA_STORAGE_S3_ENDPOINT=https://s3.us-west-004.backblazeb2.com
|
||||
export OA_STORAGE_S3_PATHSTYLE=true
|
||||
./server
|
||||
|
||||
# Cloudflare R2
|
||||
export OA_STORAGE_BACKEND=s3
|
||||
export OA_STORAGE_S3_REGION=auto
|
||||
export OA_STORAGE_S3_BUCKET=my-app-attachments
|
||||
export OA_STORAGE_S3_ACCESSKEYID=your-r2-access-key
|
||||
export OA_STORAGE_S3_SECRETACCESSKEY=your-r2-secret-key
|
||||
export OA_STORAGE_S3_ENDPOINT=https://your-account-id.r2.cloudflarestorage.com
|
||||
./server
|
||||
|
||||
# MinIO (self-hosted)
|
||||
export OA_STORAGE_BACKEND=s3
|
||||
export OA_STORAGE_S3_REGION=us-east-1
|
||||
export OA_STORAGE_S3_BUCKET=my-app-attachments
|
||||
export OA_STORAGE_S3_ACCESSKEYID=minioadmin
|
||||
export OA_STORAGE_S3_SECRETACCESSKEY=minioadmin
|
||||
export OA_STORAGE_S3_ENDPOINT=http://localhost:9000
|
||||
export OA_STORAGE_S3_PATHSTYLE=true
|
||||
./server
|
||||
```
|
||||
|
||||
#### Docker
|
||||
```bash
|
||||
# SQLite with volume mount
|
||||
docker run -p 8080:8080 \
|
||||
-e OA_DATABASE_DRIVER=sqlite \
|
||||
-v ./data:/app/data \
|
||||
openaccounting-server:latest
|
||||
|
||||
# MySQL with environment variables
|
||||
docker run -p 8080:8080 \
|
||||
-e OA_DATABASE_DRIVER=mysql \
|
||||
-e OA_DATABASE_ADDRESS=mysql:3306 \
|
||||
-e OA_PASSWORD=secret \
|
||||
openaccounting-server:latest
|
||||
|
||||
# With AWS S3 storage
|
||||
docker run -p 8080:8080 \
|
||||
-e OA_STORAGE_BACKEND=s3 \
|
||||
-e OA_STORAGE_S3_REGION=us-west-2 \
|
||||
-e OA_STORAGE_S3_BUCKET=my-attachments \
|
||||
-e OA_STORAGE_S3_ACCESSKEYID=your-key \
|
||||
-e OA_STORAGE_S3_SECRETACCESSKEY=your-secret \
|
||||
openaccounting-server:latest
|
||||
|
||||
# With Cloudflare R2 storage
|
||||
docker run -p 8080:8080 \
|
||||
-e OA_STORAGE_BACKEND=s3 \
|
||||
-e OA_STORAGE_S3_REGION=auto \
|
||||
-e OA_STORAGE_S3_BUCKET=my-attachments \
|
||||
-e OA_STORAGE_S3_ACCESSKEYID=your-r2-key \
|
||||
-e OA_STORAGE_S3_SECRETACCESSKEY=your-r2-secret \
|
||||
-e OA_STORAGE_S3_ENDPOINT=https://account-id.r2.cloudflarestorage.com \
|
||||
openaccounting-server:latest
|
||||
```
|
||||
|
||||
## Database Setup
|
||||
|
||||
### SQLite (Development)
|
||||
|
||||
SQLite databases are created automatically. No manual setup required.
|
||||
|
||||
```bash
|
||||
# Uses ./openaccounting.db by default
|
||||
OA_DATABASE_DRIVER=sqlite ./server
|
||||
|
||||
# Custom location
|
||||
OA_DATABASE_DRIVER=sqlite OA_DATABASE_FILE=./data/myapp.db ./server
|
||||
```
|
||||
|
||||
### MySQL (Production)
|
||||
|
||||
Use the provided schema files to create your MySQL database:
|
||||
|
||||
```sql
|
||||
-- Create database and user
|
||||
CREATE DATABASE openaccounting;
|
||||
CREATE USER 'openaccounting'@'%' IDENTIFIED BY 'secure_password';
|
||||
GRANT ALL PRIVILEGES ON openaccounting.* TO 'openaccounting'@'%';
|
||||
```
|
||||
|
||||
The server will automatically create tables and run migrations on startup.
|
||||
|
||||
## Building
|
||||
|
||||
### Local Build
|
||||
|
||||
```bash
|
||||
# Development build
|
||||
go build -o server ./core/
|
||||
|
||||
# Production build (optimized)
|
||||
CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -ldflags="-w -s" -o server ./core/
|
||||
```
|
||||
|
||||
### Docker Build
|
||||
|
||||
```bash
|
||||
# Build image
|
||||
docker build -t openaccounting-server:latest .
|
||||
|
||||
# Multi-platform build
|
||||
docker buildx build --platform linux/amd64,linux/arm64 -t openaccounting-server:latest .
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
### Development
|
||||
|
||||
```bash
|
||||
# Local with SQLite
|
||||
just run-dev
|
||||
|
||||
# Or manually
|
||||
OA_DATABASE_DRIVER=sqlite OA_PORT=8080 ./server
|
||||
```
|
||||
|
||||
### Production
|
||||
|
||||
```bash
|
||||
# With Docker Compose (recommended)
|
||||
docker-compose up -d
|
||||
|
||||
# Or manually with environment file
|
||||
export $(cat .env | xargs)
|
||||
./server
|
||||
```
|
||||
|
||||
## Just Recipes
|
||||
|
||||
This project includes a `justfile` with common tasks:
|
||||
|
||||
```bash
|
||||
just --list # Show all available recipes
|
||||
just build # Build the application
|
||||
just run-dev # Run in development mode
|
||||
just docker-build # Build Docker image
|
||||
just docker-run # Run container
|
||||
just test # Run tests
|
||||
just config-help # Show configuration help
|
||||
just dev-setup # Complete development setup
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
The server provides a REST API at `/api/v1/` (configurable via `OA_API_PREFIX`).
|
||||
|
||||
### Health Check
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/api/v1/health
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Prerequisites
|
||||
|
||||
```bash
|
||||
# Install Go dependencies
|
||||
go mod download
|
||||
|
||||
# Install development tools (optional)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
just test
|
||||
# or
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
just fmt
|
||||
|
||||
# Lint code (requires golangci-lint)
|
||||
just lint
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
If you are interested in running Open Accounting via Docker, @alokmenghrajani has created a [repo](https://github.com/alokmenghrajani/openaccounting-docker) for this.
|
||||
### Official Images
|
||||
|
||||
## Help
|
||||
Docker images are available with multi-stage builds for optimal size and security:
|
||||
|
||||
[Join our Slack chatroom](https://join.slack.com/t/openaccounting/shared_invite/zt-23zy988e8-93HP1GfLDB7osoQ6umpfiA) and talk with us!
|
||||
- Non-root user for security
|
||||
- Alpine Linux base for minimal attack surface
|
||||
- Health checks included
|
||||
- Volume support for data persistence
|
||||
|
||||
### Environment Variables in Docker
|
||||
|
||||
```dockerfile
|
||||
ENV OA_DATABASE_DRIVER=sqlite \
|
||||
OA_DATABASE_FILE=/app/data/openaccounting.db \
|
||||
OA_ADDRESS=0.0.0.0 \
|
||||
OA_PORT=8080
|
||||
```
|
||||
|
||||
### Data Persistence
|
||||
|
||||
```bash
|
||||
# Mount volume for SQLite data
|
||||
docker run -v ./data:/app/data openaccounting-server:latest
|
||||
|
||||
# Use named volume
|
||||
docker volume create openaccounting-data
|
||||
docker run -v openaccounting-data:/app/data openaccounting-server:latest
|
||||
```
|
||||
|
||||
## Deployment
|
||||
|
||||
### Docker Compose
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
openaccounting:
|
||||
image: openaccounting-server:latest
|
||||
ports:
|
||||
- "8080:8080"
|
||||
environment:
|
||||
OA_DATABASE_DRIVER: mysql
|
||||
OA_DATABASE_ADDRESS: mysql:3306
|
||||
OA_DATABASE: openaccounting
|
||||
OA_USER: openaccounting
|
||||
OA_PASSWORD: ${DB_PASSWORD}
|
||||
depends_on:
|
||||
- mysql
|
||||
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
environment:
|
||||
MYSQL_DATABASE: openaccounting
|
||||
MYSQL_USER: openaccounting
|
||||
MYSQL_PASSWORD: ${DB_PASSWORD}
|
||||
MYSQL_ROOT_PASSWORD: ${DB_ROOT_PASSWORD}
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
```
|
||||
|
||||
### Kubernetes
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: openaccounting-server
|
||||
spec:
|
||||
replicas: 3
|
||||
selector:
|
||||
matchLabels:
|
||||
app: openaccounting-server
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: openaccounting-server
|
||||
spec:
|
||||
containers:
|
||||
- name: openaccounting-server
|
||||
image: openaccounting-server:latest
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
env:
|
||||
- name: OA_DATABASE_DRIVER
|
||||
value: "mysql"
|
||||
- name: OA_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: openaccounting-secrets
|
||||
key: db-password
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Config file not found**: The server will use environment variables and defaults if no config file is found
|
||||
2. **Database connection failed**: Check your database credentials and connectivity
|
||||
3. **Permission denied**: Ensure proper file permissions for SQLite database files
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
# Enable verbose logging
|
||||
OA_LOG_LEVEL=debug ./server
|
||||
|
||||
# Check configuration
|
||||
just config-help
|
||||
```
|
||||
|
||||
### Health Checks
|
||||
|
||||
```bash
|
||||
# Application health
|
||||
curl http://localhost:8080/api/v1/health
|
||||
|
||||
# Docker health check
|
||||
docker inspect --format='{{.State.Health.Status}}' container_name
|
||||
```
|
||||
|
||||
## Migration from Legacy Setup
|
||||
|
||||
The server maintains backward compatibility with existing `config.json` files while adding Viper features:
|
||||
|
||||
1. Existing `config.json` files continue to work
|
||||
2. Add environment variables for sensitive data
|
||||
3. Use SQLite for easier local development
|
||||
4. Leverage Docker for production deployments
|
||||
|
||||
## Help & Support
|
||||
|
||||
- **Documentation**: This README and inline code comments
|
||||
- **Issues**: GitHub Issues for bug reports and feature requests
|
||||
- **Community**: [Join our Slack chatroom](https://join.slack.com/t/openaccounting/shared_invite/zt-23zy988e8-93HP1GfLDB7osoQ6umpfiA)
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
239
STORAGE.md
Normal file
239
STORAGE.md
Normal file
@@ -0,0 +1,239 @@
|
||||
# Modular Storage System
|
||||
|
||||
The OpenAccounting server now supports multiple storage backends for file attachments. This allows you to choose between local filesystem storage for simple deployments or cloud storage for production/multi-user environments.
|
||||
|
||||
## Supported Storage Backends
|
||||
|
||||
### 1. Local Filesystem Storage
|
||||
Perfect for self-hosted deployments or development environments.
|
||||
|
||||
**Configuration:**
|
||||
```json
|
||||
{
|
||||
"storage": {
|
||||
"backend": "local",
|
||||
"local": {
|
||||
"root_dir": "./uploads",
|
||||
"base_url": "https://yourapp.com/files"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Environment Variables:**
|
||||
```bash
|
||||
OA_STORAGE_BACKEND=local
|
||||
OA_STORAGE_LOCAL_ROOT_DIR=./uploads
|
||||
OA_STORAGE_LOCAL_BASE_URL=https://yourapp.com/files
|
||||
```
|
||||
|
||||
### 2. Amazon S3 Storage
|
||||
Reliable cloud storage for production deployments.
|
||||
|
||||
**Configuration:**
|
||||
```json
|
||||
{
|
||||
"storage": {
|
||||
"backend": "s3",
|
||||
"s3": {
|
||||
"region": "us-east-1",
|
||||
"bucket": "my-openaccounting-attachments",
|
||||
"prefix": "attachments",
|
||||
"access_key_id": "AKIA...",
|
||||
"secret_access_key": "...",
|
||||
"endpoint": "",
|
||||
"path_style": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Environment Variables:**
|
||||
```bash
|
||||
OA_STORAGE_BACKEND=s3
|
||||
OA_STORAGE_S3_REGION=us-east-1
|
||||
OA_STORAGE_S3_BUCKET=my-openaccounting-attachments
|
||||
OA_STORAGE_S3_PREFIX=attachments
|
||||
OA_STORAGE_S3_ACCESS_KEY_ID=AKIA...
|
||||
OA_STORAGE_S3_SECRET_ACCESS_KEY=...
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Automatic presigned URL generation
|
||||
- Configurable expiry times
|
||||
- Support for S3-compatible services (MinIO, DigitalOcean Spaces)
|
||||
- IAM role support (leave credentials empty to use IAM)
|
||||
|
||||
### 3. Backblaze B2 Storage
|
||||
Cost-effective cloud storage alternative to S3.
|
||||
|
||||
**Configuration:**
|
||||
```json
|
||||
{
|
||||
"storage": {
|
||||
"backend": "b2",
|
||||
"b2": {
|
||||
"account_id": "your-b2-account-id",
|
||||
"application_key": "your-b2-application-key",
|
||||
"bucket": "my-openaccounting-attachments",
|
||||
"prefix": "attachments"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Environment Variables:**
|
||||
```bash
|
||||
OA_STORAGE_BACKEND=b2
|
||||
OA_STORAGE_B2_ACCOUNT_ID=your-b2-account-id
|
||||
OA_STORAGE_B2_APPLICATION_KEY=your-b2-application-key
|
||||
OA_STORAGE_B2_BUCKET=my-openaccounting-attachments
|
||||
OA_STORAGE_B2_PREFIX=attachments
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
The storage system provides both legacy and new endpoints:
|
||||
|
||||
### New Storage-Agnostic Endpoints
|
||||
|
||||
**Upload Attachment:**
|
||||
```
|
||||
POST /api/v1/attachments
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
transactionId: uuid
|
||||
description: string (optional)
|
||||
file: binary data
|
||||
```
|
||||
|
||||
**Get Attachment Metadata:**
|
||||
```
|
||||
GET /api/v1/attachments/{id}
|
||||
```
|
||||
|
||||
**Get Download URL:**
|
||||
```
|
||||
GET /api/v1/attachments/{id}/url
|
||||
```
|
||||
|
||||
**Download File:**
|
||||
```
|
||||
GET /api/v1/attachments/{id}?download=true
|
||||
```
|
||||
|
||||
**Delete Attachment:**
|
||||
```
|
||||
DELETE /api/v1/attachments/{id}
|
||||
```
|
||||
|
||||
### Legacy Endpoints (Still Supported)
|
||||
The original transaction-scoped endpoints remain available for backward compatibility:
|
||||
- `GET/POST /api/v1/orgs/{orgId}/transactions/{transactionId}/attachments`
|
||||
|
||||
## Security Features
|
||||
|
||||
- **File type validation** - Only allowed MIME types are accepted
|
||||
- **File size limits** - Configurable maximum file size (default 10MB)
|
||||
- **Path traversal protection** - Prevents directory traversal attacks
|
||||
- **Access control** - Files are linked to users and organizations
|
||||
- **Presigned URLs** - Time-limited access for cloud storage
|
||||
|
||||
## File Organization
|
||||
|
||||
Files are automatically organized by date:
|
||||
```
|
||||
uploads/
|
||||
├── 2025/
|
||||
│ ├── 01/
|
||||
│ │ ├── 15/
|
||||
│ │ │ ├── uuid1.pdf
|
||||
│ │ │ └── uuid2.png
|
||||
│ │ └── 16/
|
||||
│ │ └── uuid3.jpg
|
||||
```
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Development (Local Storage)
|
||||
```json
|
||||
{
|
||||
"storage": {
|
||||
"backend": "local",
|
||||
"local": {
|
||||
"root_dir": "./dev-uploads"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Production (S3 with IAM)
|
||||
```json
|
||||
{
|
||||
"storage": {
|
||||
"backend": "s3",
|
||||
"s3": {
|
||||
"region": "us-west-2",
|
||||
"bucket": "prod-openaccounting-files",
|
||||
"prefix": "attachments"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Cost-Optimized (Backblaze B2)
|
||||
```json
|
||||
{
|
||||
"storage": {
|
||||
"backend": "b2",
|
||||
"b2": {
|
||||
"account_id": "${B2_ACCOUNT_ID}",
|
||||
"application_key": "${B2_APP_KEY}",
|
||||
"bucket": "openaccounting-prod"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Migration Between Storage Backends
|
||||
|
||||
When changing storage backends, existing attachments will remain in the old storage location. The database records contain the storage path, so files can be accessed until migrated.
|
||||
|
||||
To migrate:
|
||||
1. Update configuration to new backend
|
||||
2. Restart server
|
||||
3. New uploads will use the new backend
|
||||
4. Optional: Run migration script to move existing files
|
||||
|
||||
## Environment-Specific Considerations
|
||||
|
||||
### Self-Hosted
|
||||
- Use local storage for simplicity
|
||||
- Ensure backup strategy includes upload directory
|
||||
- Consider disk space management
|
||||
|
||||
### Cloud Deployment
|
||||
- Use S3 or B2 for reliability and scalability
|
||||
- Configure proper IAM policies
|
||||
- Enable versioning and lifecycle policies
|
||||
|
||||
### Multi-Region
|
||||
- Use cloud storage with appropriate region selection
|
||||
- Consider CDN integration for better performance
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Storage backend not initialized:**
|
||||
- Check configuration syntax
|
||||
- Verify credentials for cloud backends
|
||||
- Ensure storage directories/buckets exist
|
||||
|
||||
**Permission denied:**
|
||||
- Check file system permissions for local storage
|
||||
- Verify IAM policies for S3
|
||||
- Confirm B2 application key permissions
|
||||
|
||||
**Large file uploads failing:**
|
||||
- Check `MaxFileSize` configuration
|
||||
- Verify network timeouts
|
||||
- Consider multipart upload for large files
|
||||
17
config.b2.json.sample
Normal file
17
config.b2.json.sample
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"weburl": "https://yourapp.com",
|
||||
"address": "localhost",
|
||||
"port": 8080,
|
||||
"apiprefix": "/api/v1",
|
||||
"databasedriver": "sqlite",
|
||||
"databasefile": "./openaccounting.db",
|
||||
"storage": {
|
||||
"backend": "b2",
|
||||
"b2": {
|
||||
"account_id": "your-b2-account-id",
|
||||
"application_key": "your-b2-application-key",
|
||||
"bucket": "my-openaccounting-attachments",
|
||||
"prefix": "attachments"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,31 @@
|
||||
{
|
||||
"_comment_config": "OpenAccounting Server Configuration - now supports Viper for multiple config sources",
|
||||
"_comment_viper": "You can override any setting with environment variables using OA_ prefix (e.g., OA_PASSWORD, OA_MAILGUN_KEY)",
|
||||
|
||||
"WebUrl": "https://domain.com",
|
||||
"Address": "",
|
||||
"Address": "localhost",
|
||||
"Port": 8080,
|
||||
"ApiPrefix": "",
|
||||
"ApiPrefix": "/api/v1",
|
||||
"KeyFile": "",
|
||||
"CertFile": "",
|
||||
"DatabaseAddress": "",
|
||||
|
||||
"_comment_database": "Database configuration - choose 'sqlite' for local testing or 'mysql' for production",
|
||||
"DatabaseDriver": "sqlite",
|
||||
"DatabaseFile": "./data/openaccounting.db",
|
||||
"DatabaseAddress": "localhost:3306",
|
||||
"Database": "openaccounting",
|
||||
"User": "openaccounting",
|
||||
"Password": "openaccounting",
|
||||
"Password": "",
|
||||
"_comment_password": "SECURITY: Set password via OA_PASSWORD environment variable instead of this file",
|
||||
|
||||
"_comment_mailgun": "Mailgun configuration for email sending",
|
||||
"MailgunDomain": "mg.domain.com",
|
||||
"MailgunKey": "",
|
||||
"_comment_mailgun_key": "SECURITY: Set Mailgun key via OA_MAILGUN_KEY environment variable",
|
||||
"MailgunEmail": "noreply@domain.com",
|
||||
"MailgunSender": "Sender"
|
||||
"MailgunSender": "Sender",
|
||||
|
||||
"_comment_env_examples": "Environment variable examples:",
|
||||
"_example_development": "OA_DATABASE_DRIVER=sqlite OA_DATABASE_FILE=./dev.db ./server",
|
||||
"_example_production": "OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret OA_MAILGUN_KEY=key-123 ./server"
|
||||
}
|
||||
|
||||
19
config.mysql.json.sample
Normal file
19
config.mysql.json.sample
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"WebUrl": "https://domain.com",
|
||||
"Address": "",
|
||||
"Port": 8080,
|
||||
"ApiPrefix": "",
|
||||
"KeyFile": "",
|
||||
"CertFile": "",
|
||||
"_comment_database": "MySQL configuration for production use",
|
||||
"DatabaseDriver": "mysql",
|
||||
"DatabaseAddress": "localhost:3306",
|
||||
"Database": "openaccounting",
|
||||
"User": "openaccounting",
|
||||
"Password": "openaccounting",
|
||||
"_comment_mailgun": "Mailgun configuration for email sending",
|
||||
"MailgunDomain": "mg.domain.com",
|
||||
"MailgunKey": "",
|
||||
"MailgunEmail": "noreply@domain.com",
|
||||
"MailgunSender": "Sender"
|
||||
}
|
||||
20
config.s3.json.sample
Normal file
20
config.s3.json.sample
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"weburl": "https://yourapp.com",
|
||||
"address": "localhost",
|
||||
"port": 8080,
|
||||
"apiprefix": "/api/v1",
|
||||
"databasedriver": "sqlite",
|
||||
"databasefile": "./openaccounting.db",
|
||||
"storage": {
|
||||
"backend": "s3",
|
||||
"s3": {
|
||||
"region": "us-east-1",
|
||||
"bucket": "my-openaccounting-attachments",
|
||||
"prefix": "attachments",
|
||||
"access_key_id": "",
|
||||
"secret_access_key": "",
|
||||
"endpoint": "",
|
||||
"path_style": false
|
||||
}
|
||||
}
|
||||
}
|
||||
15
config.storage.json.sample
Normal file
15
config.storage.json.sample
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"weburl": "https://yourapp.com",
|
||||
"address": "localhost",
|
||||
"port": 8080,
|
||||
"apiprefix": "/api/v1",
|
||||
"databasedriver": "sqlite",
|
||||
"databasefile": "./openaccounting.db",
|
||||
"storage": {
|
||||
"backend": "local",
|
||||
"local": {
|
||||
"root_dir": "./uploads",
|
||||
"base_url": "https://yourapp.com/files"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -12,48 +12,7 @@ import (
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
)
|
||||
|
||||
/**
|
||||
* @api {get} /orgs/:orgId/accounts Get Accounts by Org id
|
||||
* @apiVersion 1.4.0
|
||||
* @apiName GetOrgAccounts
|
||||
* @apiGroup Account
|
||||
*
|
||||
* @apiHeader {String} Authorization HTTP Basic Auth
|
||||
* @apiHeader {String} Accept-Version ^1.4.0 semver versioning
|
||||
*
|
||||
* @apiSuccess {String} id Id of the Account.
|
||||
* @apiSuccess {String} orgId Id of the Org.
|
||||
* @apiSuccess {Date} inserted Date Account was created
|
||||
* @apiSuccess {Date} updated Date Account was updated
|
||||
* @apiSuccess {String} name Name of the Account.
|
||||
* @apiSuccess {String} parent Id of the parent Account.
|
||||
* @apiSuccess {String} currency Three letter currency code.
|
||||
* @apiSuccess {Number} precision How many digits the currency goes out to.
|
||||
* @apiSuccess {Boolean} debitBalance True if Account has a debit balance.
|
||||
* @apiSuccess {Number} balance Current Account balance in this Account's currency
|
||||
* @apiSuccess {Number} nativeBalance Current Account balance in the Org's currency
|
||||
*
|
||||
* @apiSuccessExample Success-Response:
|
||||
* HTTP/1.1 200 OK
|
||||
* [
|
||||
* {
|
||||
* "id": "22222222222222222222222222222222",
|
||||
* "orgId": "11111111111111111111111111111111",
|
||||
* "inserted": "2018-09-11T18:05:04.420Z",
|
||||
* "updated": "2018-09-11T18:05:04.420Z",
|
||||
* "name": "Cash",
|
||||
* "parent": "11111111111111111111111111111111",
|
||||
* "currency": "USD",
|
||||
* "precision": 2,
|
||||
* "debitBalance": true,
|
||||
* "balance": 10000,
|
||||
* "nativeBalance": 10000
|
||||
* }
|
||||
* ]
|
||||
*
|
||||
* @apiUse NotAuthorizedError
|
||||
* @apiUse InternalServerError
|
||||
*/
|
||||
// GetOrgAccounts /**
|
||||
func GetOrgAccounts(w rest.ResponseWriter, r *rest.Request) {
|
||||
user := r.Env["USER"].(*types.User)
|
||||
orgId := r.PathParam("orgId")
|
||||
@@ -208,7 +167,7 @@ func PostAccount(w rest.ResponseWriter, r *rest.Request) {
|
||||
user := r.Env["USER"].(*types.User)
|
||||
orgId := r.PathParam("orgId")
|
||||
|
||||
content, err := ioutil.ReadAll(r.Body)
|
||||
content, err := io.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
|
||||
313
core/api/attachment.go
Normal file
313
core/api/attachment.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ant0ine/go-json-rest/rest"
|
||||
"github.com/openaccounting/oa-server/core/model"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxFileSize = 10 * 1024 * 1024 // 10MB
|
||||
MaxFilesPerTx = 10
|
||||
AttachmentDir = "attachments"
|
||||
)
|
||||
|
||||
var AllowedMimeTypes = map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"application/pdf": true,
|
||||
"text/plain": true,
|
||||
"text/csv": true,
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true, // .xlsx
|
||||
"application/vnd.ms-excel": true, // .xls
|
||||
}
|
||||
|
||||
func PostAttachment(w rest.ResponseWriter, r *rest.Request) {
|
||||
orgId := r.PathParam("orgId")
|
||||
transactionId := r.PathParam("transactionId")
|
||||
|
||||
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
|
||||
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
// Parse multipart form
|
||||
err := r.ParseMultipartForm(MaxFileSize)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
files := r.MultipartForm.File["files"]
|
||||
if len(files) == 0 {
|
||||
rest.Error(w, "No files provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) > MaxFilesPerTx {
|
||||
rest.Error(w, fmt.Sprintf("Too many files. Maximum %d files allowed", MaxFilesPerTx), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify transaction exists and user has permission
|
||||
tx, err := model.Instance.GetTransaction(transactionId, orgId, user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Transaction not found or access denied", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if tx == nil {
|
||||
rest.Error(w, "Transaction not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var attachments []*types.Attachment
|
||||
var description string
|
||||
if desc := r.FormValue("description"); desc != "" {
|
||||
description = desc
|
||||
}
|
||||
|
||||
for _, fileHeader := range files {
|
||||
attachment, err := processFileUpload(fileHeader, transactionId, orgId, user.Id, description)
|
||||
if err != nil {
|
||||
// Clean up any successfully uploaded files
|
||||
for _, att := range attachments {
|
||||
os.Remove(att.FilePath)
|
||||
}
|
||||
rest.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Save attachment to database
|
||||
createdAttachment, err := model.Instance.CreateAttachment(attachment)
|
||||
if err != nil {
|
||||
// Clean up file and any previously uploaded files
|
||||
os.Remove(attachment.FilePath)
|
||||
for _, att := range attachments {
|
||||
os.Remove(att.FilePath)
|
||||
}
|
||||
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
attachments = append(attachments, createdAttachment)
|
||||
}
|
||||
|
||||
w.WriteJson(map[string]interface{}{
|
||||
"attachments": attachments,
|
||||
"count": len(attachments),
|
||||
})
|
||||
}
|
||||
|
||||
func GetAttachments(w rest.ResponseWriter, r *rest.Request) {
|
||||
orgId := r.PathParam("orgId")
|
||||
transactionId := r.PathParam("transactionId")
|
||||
|
||||
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
|
||||
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionId, orgId, user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to retrieve attachments", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteJson(attachments)
|
||||
}
|
||||
|
||||
func GetAttachment(w rest.ResponseWriter, r *rest.Request) {
|
||||
orgId := r.PathParam("orgId")
|
||||
transactionId := r.PathParam("transactionId")
|
||||
attachmentId := r.PathParam("attachmentId")
|
||||
|
||||
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
|
||||
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteJson(attachment)
|
||||
}
|
||||
|
||||
func DownloadAttachment(w rest.ResponseWriter, r *rest.Request) {
|
||||
orgId := r.PathParam("orgId")
|
||||
transactionId := r.PathParam("transactionId")
|
||||
attachmentId := r.PathParam("attachmentId")
|
||||
|
||||
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
|
||||
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(attachment.FilePath); os.IsNotExist(err) {
|
||||
rest.Error(w, "File not found on disk", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for file download
|
||||
w.Header().Set("Content-Type", attachment.ContentType)
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
|
||||
|
||||
// Open and serve file
|
||||
file, err := os.Open(attachment.FilePath)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to open file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
io.Copy(w.(http.ResponseWriter), file)
|
||||
}
|
||||
|
||||
func DeleteAttachment(w rest.ResponseWriter, r *rest.Request) {
|
||||
orgId := r.PathParam("orgId")
|
||||
transactionId := r.PathParam("transactionId")
|
||||
attachmentId := r.PathParam("attachmentId")
|
||||
|
||||
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
|
||||
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
err := model.Instance.DeleteAttachment(attachmentId, transactionId, orgId, user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to delete attachment or access denied", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteJson(map[string]string{"status": "deleted"})
|
||||
}
|
||||
|
||||
func processFileUpload(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
|
||||
// Validate file size
|
||||
if fileHeader.Size > MaxFileSize {
|
||||
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
|
||||
}
|
||||
|
||||
// Open uploaded file
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open uploaded file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Validate file type from header
|
||||
contentType := fileHeader.Header.Get("Content-Type")
|
||||
if !AllowedMimeTypes[contentType] {
|
||||
return nil, fmt.Errorf("file type %s not allowed", contentType)
|
||||
}
|
||||
|
||||
// Validate file type by detecting content (more secure)
|
||||
buffer := make([]byte, 512)
|
||||
n, err := file.Read(buffer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file for content detection: %v", err)
|
||||
}
|
||||
|
||||
// Reset file pointer to beginning
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset file pointer: %v", err)
|
||||
}
|
||||
|
||||
detectedType := http.DetectContentType(buffer[:n])
|
||||
if !AllowedMimeTypes[detectedType] {
|
||||
return nil, fmt.Errorf("detected file type %s not allowed (header claimed %s)", detectedType, contentType)
|
||||
}
|
||||
|
||||
// Generate unique filename
|
||||
attachmentId := util.NewUUID()
|
||||
ext := filepath.Ext(fileHeader.Filename)
|
||||
fileName := attachmentId + ext
|
||||
|
||||
// Create attachments directory if it doesn't exist
|
||||
uploadDir := filepath.Join(AttachmentDir, orgId, transactionId)
|
||||
if err := os.MkdirAll(uploadDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create upload directory: %v", err)
|
||||
}
|
||||
|
||||
// Create file path
|
||||
filePath := filepath.Join(uploadDir, fileName)
|
||||
|
||||
// Create destination file
|
||||
dst, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create destination file: %v", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Copy file contents
|
||||
if _, err := io.Copy(dst, file); err != nil {
|
||||
return nil, fmt.Errorf("failed to save file: %v", err)
|
||||
}
|
||||
|
||||
// Create attachment object
|
||||
attachment := &types.Attachment{
|
||||
Id: attachmentId,
|
||||
TransactionId: transactionId,
|
||||
OrgId: orgId,
|
||||
UserId: userId,
|
||||
FileName: fileName,
|
||||
OriginalName: fileHeader.Filename,
|
||||
ContentType: contentType,
|
||||
FileSize: fileHeader.Size,
|
||||
FilePath: filePath,
|
||||
Description: description,
|
||||
Uploaded: time.Now(),
|
||||
Deleted: false,
|
||||
}
|
||||
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func sanitizeFilename(filename string) string {
|
||||
// Remove potentially dangerous characters
|
||||
filename = strings.ReplaceAll(filename, "..", "")
|
||||
filename = strings.ReplaceAll(filename, "/", "")
|
||||
filename = strings.ReplaceAll(filename, "\\", "")
|
||||
filename = strings.ReplaceAll(filename, "\x00", "") // null bytes
|
||||
filename = strings.ReplaceAll(filename, "\r", "") // carriage return
|
||||
filename = strings.ReplaceAll(filename, "\n", "") // newline
|
||||
|
||||
// Limit filename length
|
||||
if len(filename) > 255 {
|
||||
ext := filepath.Ext(filename)
|
||||
base := filename[:255-len(ext)]
|
||||
filename = base + ext
|
||||
}
|
||||
|
||||
return filename
|
||||
}
|
||||
306
core/api/attachment_integration_test.go
Normal file
306
core/api/attachment_integration_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
"github.com/openaccounting/oa-server/database"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupTestDatabase(t *testing.T) (*gorm.DB, func()) {
|
||||
// Create temporary database file
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
// Set global DB for database package
|
||||
database.DB = db
|
||||
|
||||
// Run migrations
|
||||
err = database.AutoMigrate()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run auto migrations: %v", err)
|
||||
}
|
||||
|
||||
err = database.Migrate()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run custom migrations: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup function
|
||||
cleanup := func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
func setupTestData(t *testing.T, db *gorm.DB) (orgID, userID, transactionID string) {
|
||||
// Use hardcoded UUIDs without dashes for hex format
|
||||
orgID = "550e8400e29b41d4a716446655440000"
|
||||
userID = "550e8400e29b41d4a716446655440001"
|
||||
transactionID = "550e8400e29b41d4a716446655440002"
|
||||
accountID := "550e8400e29b41d4a716446655440003"
|
||||
|
||||
// Insert test data using raw SQL for reliability
|
||||
now := time.Now()
|
||||
|
||||
// Insert org
|
||||
err := db.Exec("INSERT INTO orgs (id, inserted, updated, name, currency, `precision`, timezone) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?)",
|
||||
orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Org", "USD", 2, "UTC").Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert org: %v", err)
|
||||
}
|
||||
|
||||
// Insert user
|
||||
err = db.Exec("INSERT INTO users (id, inserted, updated, firstName, lastName, email, passwordHash, agreeToTerms, passwordReset, emailVerified, emailVerifyCode, signupSource) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test", "User", "test@example.com", "hashedpassword", true, "", true, "", "test").Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert user: %v", err)
|
||||
}
|
||||
|
||||
// Insert user-org relationship
|
||||
err = db.Exec("INSERT INTO user_orgs (userId, orgId, admin) VALUES (UNHEX(?), UNHEX(?), ?)",
|
||||
userID, orgID, false).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert user-org: %v", err)
|
||||
}
|
||||
|
||||
// Insert account
|
||||
err = db.Exec("INSERT INTO accounts (id, orgId, inserted, updated, name, parent, currency, `precision`, debitBalance) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?, ?)",
|
||||
accountID, orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Account", []byte{}, "USD", 2, true).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert account: %v", err)
|
||||
}
|
||||
|
||||
// Insert transaction
|
||||
err = db.Exec("INSERT INTO transactions (id, orgId, userId, inserted, updated, date, description, data, deleted) VALUES (UNHEX(?), UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
|
||||
transactionID, orgID, userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Transaction", "", false).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert transaction: %v", err)
|
||||
}
|
||||
|
||||
// Insert split
|
||||
err = db.Exec("INSERT INTO splits (transactionId, accountId, date, inserted, updated, amount, nativeAmount, deleted) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
|
||||
transactionID, accountID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), 100, 100, false).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert split: %v", err)
|
||||
}
|
||||
|
||||
return orgID, userID, transactionID
|
||||
}
|
||||
|
||||
func createTestFile(t *testing.T) (string, []byte) {
|
||||
content := []byte("This is a test file content for attachment testing")
|
||||
tmpDir := t.TempDir()
|
||||
filePath := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
err := os.WriteFile(filePath, content, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
return filePath, content
|
||||
}
|
||||
|
||||
|
||||
func TestAttachmentIntegration(t *testing.T) {
|
||||
db, cleanup := setupTestDatabase(t)
|
||||
defer cleanup()
|
||||
|
||||
orgID, userID, transactionID := setupTestData(t, db)
|
||||
|
||||
// Set up the model instance for the API handlers
|
||||
bc := &util.StandardBcrypt{}
|
||||
|
||||
// Use the existing datastore model which has the attachment implementation
|
||||
// We need to create it with the database connection
|
||||
datastoreModel := model.NewModel(nil, bc, types.Config{})
|
||||
model.Instance = datastoreModel
|
||||
|
||||
t.Run("Database Integration Test", func(t *testing.T) {
|
||||
// Test direct database operations first
|
||||
filePath, originalContent := createTestFile(t)
|
||||
defer os.Remove(filePath)
|
||||
|
||||
// Create attachment record directly
|
||||
attachmentID := id.String(id.New())
|
||||
uploadTime := time.Now()
|
||||
|
||||
attachment := types.Attachment{
|
||||
Id: attachmentID,
|
||||
TransactionId: transactionID,
|
||||
OrgId: orgID,
|
||||
UserId: userID,
|
||||
FileName: "stored_test.txt",
|
||||
OriginalName: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
FileSize: int64(len(originalContent)),
|
||||
FilePath: "uploads/test/" + attachmentID + ".txt",
|
||||
Description: "Test attachment description",
|
||||
Uploaded: uploadTime,
|
||||
Deleted: false,
|
||||
}
|
||||
|
||||
// Insert using the existing model
|
||||
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, createdAttachment)
|
||||
assert.Equal(t, attachmentID, createdAttachment.Id)
|
||||
|
||||
// Verify database persistence
|
||||
var dbAttachment types.Attachment
|
||||
err = db.Raw("SELECT HEX(id) as id, HEX(transactionId) as transactionId, HEX(orgId) as orgId, HEX(userId) as userId, fileName, originalName, contentType, fileSize, filePath, description, uploaded, deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&dbAttachment).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, attachmentID, dbAttachment.Id)
|
||||
assert.Equal(t, transactionID, dbAttachment.TransactionId)
|
||||
assert.Equal(t, "Test attachment description", dbAttachment.Description)
|
||||
|
||||
// Test retrieval
|
||||
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, retrievedAttachment)
|
||||
assert.Equal(t, attachmentID, retrievedAttachment.Id)
|
||||
|
||||
// Test listing by transaction
|
||||
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionID, orgID, userID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, attachments, 1)
|
||||
assert.Equal(t, attachmentID, attachments[0].Id)
|
||||
|
||||
// Test soft deletion
|
||||
err = model.Instance.DeleteAttachment(attachmentID, transactionID, orgID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify soft deletion in database
|
||||
var deletedAttachment types.Attachment
|
||||
err = db.Raw("SELECT deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&deletedAttachment).Error
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, deletedAttachment.Deleted)
|
||||
|
||||
// Verify attachment is no longer accessible
|
||||
retrievedAttachment, err = model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, retrievedAttachment)
|
||||
})
|
||||
|
||||
t.Run("File Upload Integration Test", func(t *testing.T) {
|
||||
// Test file upload functionality
|
||||
filePath, originalContent := createTestFile(t)
|
||||
defer os.Remove(filePath)
|
||||
|
||||
// Create upload directory
|
||||
uploadDir := "uploads/test"
|
||||
os.MkdirAll(uploadDir, 0755)
|
||||
defer os.RemoveAll("uploads")
|
||||
|
||||
// Simulate file upload process
|
||||
attachmentID := id.String(id.New())
|
||||
storedFilePath := filepath.Join(uploadDir, attachmentID+".txt")
|
||||
|
||||
// Copy file to upload location
|
||||
err := copyFile(filePath, storedFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create attachment record
|
||||
attachment := types.Attachment{
|
||||
Id: attachmentID,
|
||||
TransactionId: transactionID,
|
||||
OrgId: orgID,
|
||||
UserId: userID,
|
||||
FileName: filepath.Base(storedFilePath),
|
||||
OriginalName: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
FileSize: int64(len(originalContent)),
|
||||
FilePath: storedFilePath,
|
||||
Description: "Uploaded test file",
|
||||
Uploaded: time.Now(),
|
||||
Deleted: false,
|
||||
}
|
||||
|
||||
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, createdAttachment)
|
||||
|
||||
// Verify file exists
|
||||
_, err = os.Stat(storedFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify database record
|
||||
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, storedFilePath, retrievedAttachment.FilePath)
|
||||
assert.Equal(t, int64(len(originalContent)), retrievedAttachment.FileSize)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to copy files
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, sourceFile)
|
||||
return err
|
||||
}
|
||||
|
||||
func TestAttachmentValidation(t *testing.T) {
|
||||
db, cleanup := setupTestDatabase(t)
|
||||
defer cleanup()
|
||||
|
||||
orgID, userID, transactionID := setupTestData(t, db)
|
||||
|
||||
// Set up the model instance
|
||||
bc := &util.StandardBcrypt{}
|
||||
gormModel := model.NewGormModel(db, bc, types.Config{})
|
||||
model.Instance = gormModel
|
||||
|
||||
t.Run("Invalid attachment data", func(t *testing.T) {
|
||||
// Test with missing required fields
|
||||
attachment := types.Attachment{
|
||||
// Missing ID
|
||||
TransactionId: transactionID,
|
||||
OrgId: orgID,
|
||||
UserId: userID,
|
||||
}
|
||||
|
||||
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, createdAttachment)
|
||||
})
|
||||
|
||||
t.Run("Non-existent attachment retrieval", func(t *testing.T) {
|
||||
nonExistentID := id.String(id.New())
|
||||
|
||||
attachment, err := model.Instance.GetAttachment(nonExistentID, transactionID, orgID, userID)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, attachment)
|
||||
})
|
||||
}
|
||||
289
core/api/attachment_storage.go
Normal file
289
core/api/attachment_storage.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ant0ine/go-json-rest/rest"
|
||||
"github.com/openaccounting/oa-server/core/model"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/storage"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
)
|
||||
|
||||
// AttachmentHandler handles attachment operations with configurable storage
|
||||
type AttachmentHandler struct {
|
||||
storage storage.Storage
|
||||
}
|
||||
|
||||
// Global attachment handler instance (will be initialized during server startup)
|
||||
var attachmentHandler *AttachmentHandler
|
||||
|
||||
// InitializeAttachmentHandler initializes the global attachment handler with storage backend
|
||||
func InitializeAttachmentHandler(storageConfig storage.Config) error {
|
||||
storageBackend, err := storage.NewStorage(storageConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize storage backend: %w", err)
|
||||
}
|
||||
|
||||
attachmentHandler = &AttachmentHandler{
|
||||
storage: storageBackend,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostAttachmentWithStorage handles file upload using the configured storage backend
|
||||
func PostAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
|
||||
if attachmentHandler == nil {
|
||||
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
transactionId := r.FormValue("transactionId")
|
||||
if transactionId == "" {
|
||||
rest.Error(w, "Transaction ID is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !util.IsValidUUID(transactionId) {
|
||||
rest.Error(w, "Invalid transaction ID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
// Parse multipart form
|
||||
err := r.ParseMultipartForm(MaxFileSize)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
files := r.MultipartForm.File["file"]
|
||||
if len(files) == 0 {
|
||||
rest.Error(w, "No file provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
fileHeader := files[0] // Take the first file
|
||||
|
||||
// Verify transaction exists and user has permission
|
||||
tx, err := model.Instance.GetTransaction(transactionId, "", user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Transaction not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if tx == nil {
|
||||
rest.Error(w, "Transaction not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the file upload
|
||||
attachment, err := attachmentHandler.processFileUploadWithStorage(fileHeader, transactionId, tx.OrgId, user.Id, r.FormValue("description"))
|
||||
if err != nil {
|
||||
rest.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Save attachment to database
|
||||
createdAttachment, err := model.Instance.CreateAttachment(attachment)
|
||||
if err != nil {
|
||||
// Clean up the stored file on database error
|
||||
attachmentHandler.storage.Delete(attachment.FilePath)
|
||||
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.WriteJson(createdAttachment)
|
||||
}
|
||||
|
||||
// GetAttachmentWithStorage retrieves an attachment using the configured storage backend
|
||||
func GetAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
|
||||
if attachmentHandler == nil {
|
||||
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
attachmentId := r.PathParam("id")
|
||||
if !util.IsValidUUID(attachmentId) {
|
||||
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
// Get attachment from database
|
||||
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Attachment not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a download request
|
||||
if r.URL.Query().Get("download") == "true" {
|
||||
// Stream the file directly to the client
|
||||
err := attachmentHandler.streamFile(w, attachment)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to retrieve file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return attachment metadata
|
||||
w.WriteJson(attachment)
|
||||
}
|
||||
|
||||
// GetAttachmentDownloadURL returns a download URL for an attachment
|
||||
func GetAttachmentDownloadURL(w rest.ResponseWriter, r *rest.Request) {
|
||||
if attachmentHandler == nil {
|
||||
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
attachmentId := r.PathParam("id")
|
||||
if !util.IsValidUUID(attachmentId) {
|
||||
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
// Get attachment from database
|
||||
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Attachment not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate download URL (valid for 1 hour)
|
||||
url, err := attachmentHandler.storage.GetURL(attachment.FilePath, time.Hour)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to generate download URL", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]string{
|
||||
"url": url,
|
||||
"expiresIn": "3600", // 1 hour in seconds
|
||||
}
|
||||
|
||||
w.WriteJson(response)
|
||||
}
|
||||
|
||||
// DeleteAttachmentWithStorage deletes an attachment using the configured storage backend
|
||||
func DeleteAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
|
||||
if attachmentHandler == nil {
|
||||
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
attachmentId := r.PathParam("id")
|
||||
if !util.IsValidUUID(attachmentId) {
|
||||
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user := r.Env["USER"].(*types.User)
|
||||
|
||||
// Get attachment from database first
|
||||
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Attachment not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete from database (soft delete)
|
||||
err = model.Instance.DeleteAttachment(attachmentId, attachment.TransactionId, attachment.OrgId, user.Id)
|
||||
if err != nil {
|
||||
rest.Error(w, "Failed to delete attachment", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete from storage backend
|
||||
// Note: For production, you might want to delay physical deletion
|
||||
// and run a cleanup job later to handle any issues
|
||||
err = attachmentHandler.storage.Delete(attachment.FilePath)
|
||||
if err != nil {
|
||||
// Log the error but don't fail the request since database deletion succeeded
|
||||
// The file can be cleaned up later by a maintenance job
|
||||
fmt.Printf("Warning: Failed to delete file from storage: %v\n", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.WriteJson(map[string]string{"status": "deleted"})
|
||||
}
|
||||
|
||||
// processFileUploadWithStorage processes a file upload using the storage backend
|
||||
func (h *AttachmentHandler) processFileUploadWithStorage(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
|
||||
// Validate file size
|
||||
if fileHeader.Size > MaxFileSize {
|
||||
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
|
||||
}
|
||||
|
||||
// Validate content type
|
||||
contentType := fileHeader.Header.Get("Content-Type")
|
||||
if !AllowedMimeTypes[contentType] {
|
||||
return nil, fmt.Errorf("unsupported file type: %s", contentType)
|
||||
}
|
||||
|
||||
// Open the file
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Store the file using the storage backend
|
||||
storagePath, err := h.storage.Store(fileHeader.Filename, file, contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store file: %w", err)
|
||||
}
|
||||
|
||||
// Create attachment record
|
||||
attachment := &types.Attachment{
|
||||
Id: id.String(id.New()),
|
||||
TransactionId: transactionId,
|
||||
OrgId: orgId,
|
||||
UserId: userId,
|
||||
FileName: storagePath, // Store the storage path/key
|
||||
OriginalName: fileHeader.Filename,
|
||||
ContentType: contentType,
|
||||
FileSize: fileHeader.Size,
|
||||
FilePath: storagePath, // For backward compatibility
|
||||
Description: description,
|
||||
Uploaded: time.Now(),
|
||||
Deleted: false,
|
||||
}
|
||||
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
// streamFile streams a file from storage to the HTTP response
|
||||
func (h *AttachmentHandler) streamFile(w rest.ResponseWriter, attachment *types.Attachment) error {
|
||||
// Get file from storage
|
||||
reader, err := h.storage.Retrieve(attachment.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve file: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Set appropriate headers
|
||||
w.Header().Set("Content-Type", attachment.ContentType)
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
|
||||
|
||||
// If we know the file size, set Content-Length
|
||||
if attachment.FileSize > 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", attachment.FileSize))
|
||||
}
|
||||
|
||||
// Stream the file to the client
|
||||
_, err = io.Copy(w.(http.ResponseWriter), reader)
|
||||
return err
|
||||
}
|
||||
@@ -31,6 +31,17 @@ func GetRouter(auth *AuthMiddleware, prefix string) (rest.App, error) {
|
||||
rest.Post(prefix+"/orgs/:orgId/transactions", auth.RequireAuth(PostTransaction)),
|
||||
rest.Put(prefix+"/orgs/:orgId/transactions/:transactionId", auth.RequireAuth(PutTransaction)),
|
||||
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId", auth.RequireAuth(DeleteTransaction)),
|
||||
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments", auth.RequireAuth(GetAttachments)),
|
||||
rest.Post(prefix+"/orgs/:orgId/transactions/:transactionId/attachments", auth.RequireAuth(PostAttachment)),
|
||||
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(GetAttachment)),
|
||||
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId/download", auth.RequireAuth(DownloadAttachment)),
|
||||
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(DeleteAttachment)),
|
||||
|
||||
// New storage-based attachment endpoints
|
||||
rest.Post(prefix+"/attachments", auth.RequireAuth(PostAttachmentWithStorage)),
|
||||
rest.Get(prefix+"/attachments/:id", auth.RequireAuth(GetAttachmentWithStorage)),
|
||||
rest.Get(prefix+"/attachments/:id/url", auth.RequireAuth(GetAttachmentDownloadURL)),
|
||||
rest.Delete(prefix+"/attachments/:id", auth.RequireAuth(DeleteAttachmentWithStorage)),
|
||||
rest.Get(prefix+"/orgs/:orgId/prices", auth.RequireAuth(GetPrices)),
|
||||
rest.Post(prefix+"/orgs/:orgId/prices", auth.RequireAuth(PostPrice)),
|
||||
rest.Delete(prefix+"/orgs/:orgId/prices/:priceId", auth.RequireAuth(DeletePrice)),
|
||||
|
||||
@@ -14,6 +14,22 @@ type AuthService struct {
|
||||
bcrypt util.Bcrypt
|
||||
}
|
||||
|
||||
// AuthRepository interface for dependency injection
|
||||
type AuthRepository interface {
|
||||
GetVerifiedUserByEmail(string) (*types.User, error)
|
||||
GetUserByActiveSession(string) (*types.User, error)
|
||||
GetUserByApiKey(string) (*types.User, error)
|
||||
GetUserByEmailVerifyCode(string) (*types.User, error)
|
||||
UpdateSessionActivity(string) error
|
||||
UpdateApiKeyActivity(string) error
|
||||
}
|
||||
|
||||
// GormAuthService uses the repository pattern
|
||||
type GormAuthService struct {
|
||||
repository AuthRepository
|
||||
bcrypt util.Bcrypt
|
||||
}
|
||||
|
||||
type Interface interface {
|
||||
Authenticate(string, string) (*types.User, error)
|
||||
AuthenticateUser(email string, password string) (*types.User, error)
|
||||
@@ -28,6 +44,12 @@ func NewAuthService(db db.Datastore, bcrypt util.Bcrypt) *AuthService {
|
||||
return authService
|
||||
}
|
||||
|
||||
func NewGormAuthService(repository AuthRepository, bcrypt util.Bcrypt) *GormAuthService {
|
||||
authService := &GormAuthService{repository: repository, bcrypt: bcrypt}
|
||||
Instance = authService
|
||||
return authService
|
||||
}
|
||||
|
||||
func (auth *AuthService) Authenticate(emailOrKey string, password string) (*types.User, error) {
|
||||
// authenticate via session, apikey or user
|
||||
user, err := auth.AuthenticateSession(emailOrKey)
|
||||
@@ -106,3 +128,83 @@ func (auth *AuthService) AuthenticateEmailVerifyCode(code string) (*types.User,
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GormAuthService implementations
|
||||
func (auth *GormAuthService) Authenticate(emailOrKey string, password string) (*types.User, error) {
|
||||
// authenticate via session, apikey or user
|
||||
user, err := auth.AuthenticateSession(emailOrKey)
|
||||
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
user, err = auth.AuthenticateApiKey(emailOrKey)
|
||||
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
user, err = auth.AuthenticateUser(emailOrKey, password)
|
||||
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
user, err = auth.AuthenticateEmailVerifyCode(emailOrKey)
|
||||
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("Unauthorized")
|
||||
}
|
||||
|
||||
func (auth *GormAuthService) AuthenticateUser(email string, password string) (*types.User, error) {
|
||||
u, err := auth.repository.GetVerifiedUserByEmail(email)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("Invalid email or password")
|
||||
}
|
||||
|
||||
err = auth.bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("Invalid email or password")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (auth *GormAuthService) AuthenticateSession(id string) (*types.User, error) {
|
||||
u, err := auth.repository.GetUserByActiveSession(id)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("Invalid session")
|
||||
}
|
||||
|
||||
auth.repository.UpdateSessionActivity(id)
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (auth *GormAuthService) AuthenticateApiKey(id string) (*types.User, error) {
|
||||
u, err := auth.repository.GetUserByApiKey(id)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("Access denied")
|
||||
}
|
||||
|
||||
auth.repository.UpdateApiKeyActivity(id)
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (auth *GormAuthService) AuthenticateEmailVerifyCode(code string) (*types.User, error) {
|
||||
u, err := auth.repository.GetUserByEmailVerifyCode(code)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("Access denied")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
@@ -2,12 +2,13 @@ package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TdUser struct {
|
||||
@@ -28,18 +29,19 @@ func (td *TdUser) GetVerifiedUserByEmail(email string) (*types.User, error) {
|
||||
|
||||
func (td *TdUser) GetVerifiedUserByEmail_1(email string) (*types.User, error) {
|
||||
return &types.User{
|
||||
"1",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
"John",
|
||||
"Doe",
|
||||
"johndoe@email.com",
|
||||
"password",
|
||||
"$2a$10$KrtvADe7jwrmYIe3GXFbNupOQaPIvyOKeng5826g4VGOD47TpAisG",
|
||||
true,
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
Id: "1",
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
Email: "johndoe@email.com",
|
||||
Password: "password",
|
||||
PasswordHash: "$2a$10$KrtvADe7jwrmYIe3GXFbNupOQaPIvyOKeng5826g4VGOD47TpAisG",
|
||||
AgreeToTerms: true,
|
||||
PasswordReset: "",
|
||||
EmailVerified: false,
|
||||
EmailVerifyCode: "",
|
||||
SignupSource: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,31 @@ type Datastore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// DeleteBudget implements db.Datastore.
|
||||
func (_m *Datastore) DeleteBudget(string) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// GetBudget implements db.Datastore.
|
||||
func (_m *Datastore) GetBudget(string) (*types.Budget, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// GetUserByEmailVerifyCode implements db.Datastore.
|
||||
func (_m *Datastore) GetUserByEmailVerifyCode(string) (*types.User, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// InsertAndReplaceBudget implements db.Datastore.
|
||||
func (_m *Datastore) InsertAndReplaceBudget(*types.Budget) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Ping implements db.Datastore.
|
||||
func (_m *Datastore) Ping() error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// AcceptInvite provides a mock function with given fields: _a0, _a1
|
||||
func (_m *Datastore) AcceptInvite(_a0 *types.Invite, _a1 string) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
@@ -968,3 +993,74 @@ func (_m *Datastore) VerifyUser(_a0 string) error {
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Attachment interface mock methods
|
||||
func (_m *Datastore) InsertAttachment(_a0 *types.Attachment) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*types.Attachment) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
func (_m *Datastore) GetAttachment(_a0 string, _a1 string, _a2 string) (*types.Attachment, error) {
|
||||
ret := _m.Called(_a0, _a1, _a2)
|
||||
|
||||
var r0 *types.Attachment
|
||||
if rf, ok := ret.Get(0).(func(string, string, string) *types.Attachment); ok {
|
||||
r0 = rf(_a0, _a1, _a2)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*types.Attachment)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string, string, string) error); ok {
|
||||
r1 = rf(_a0, _a1, _a2)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (_m *Datastore) GetAttachmentsByTransaction(_a0 string, _a1 string) ([]*types.Attachment, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 []*types.Attachment
|
||||
if rf, ok := ret.Get(0).(func(string, string) []*types.Attachment); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*types.Attachment)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string, string) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (_m *Datastore) DeleteAttachment(_a0 string, _a1 string, _a2 string) error {
|
||||
ret := _m.Called(_a0, _a1, _a2)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, string, string) error); ok {
|
||||
r0 = rf(_a0, _a1, _a2)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -408,6 +408,15 @@ func (model *Model) accountsContainWriteAccess(accounts []*types.Account, accoun
|
||||
return false
|
||||
}
|
||||
|
||||
func (model *Model) accountsContainReadAccess(accounts []*types.Account, accountId string) bool {
|
||||
for _, account := range accounts {
|
||||
if account.Id == accountId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (model *Model) getAccountFromList(accounts []*types.Account, accountId string) *types.Account {
|
||||
for _, account := range accounts {
|
||||
if account.Id == accountId {
|
||||
|
||||
@@ -162,6 +162,10 @@ func TestCreateAccount(t *testing.T) {
|
||||
|
||||
td := &TdAccount{}
|
||||
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
|
||||
// Mock GetSplitCountByAccountId for parent account check
|
||||
if test.account.Parent != "" {
|
||||
td.On("GetSplitCountByAccountId", test.account.Parent).Return(int64(0), nil)
|
||||
}
|
||||
|
||||
model := NewModel(td, nil, types.Config{})
|
||||
|
||||
@@ -206,6 +210,10 @@ func TestUpdateAccount(t *testing.T) {
|
||||
|
||||
td := &TdAccount{}
|
||||
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
|
||||
// Mock GetSplitCountByAccountId for parent account check
|
||||
if test.account.Parent != "" {
|
||||
td.On("GetSplitCountByAccountId", test.account.Parent).Return(int64(0), nil)
|
||||
}
|
||||
|
||||
model := NewModel(td, nil, types.Config{})
|
||||
|
||||
|
||||
163
core/model/attachment.go
Normal file
163
core/model/attachment.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
)
|
||||
|
||||
type AttachmentInterface interface {
|
||||
CreateAttachment(*types.Attachment) (*types.Attachment, error)
|
||||
GetAttachmentsByTransaction(string, string, string) ([]*types.Attachment, error)
|
||||
GetAttachment(string, string, string, string) (*types.Attachment, error)
|
||||
DeleteAttachment(string, string, string, string) error
|
||||
}
|
||||
|
||||
func (model *Model) CreateAttachment(attachment *types.Attachment) (*types.Attachment, error) {
|
||||
if attachment.Id == "" {
|
||||
return nil, errors.New("attachment ID required")
|
||||
}
|
||||
|
||||
if attachment.TransactionId == "" {
|
||||
return nil, errors.New("transaction ID required")
|
||||
}
|
||||
|
||||
if attachment.OrgId == "" {
|
||||
return nil, errors.New("organization ID required")
|
||||
}
|
||||
|
||||
if attachment.UserId == "" {
|
||||
return nil, errors.New("user ID required")
|
||||
}
|
||||
|
||||
if attachment.FileName == "" {
|
||||
return nil, errors.New("file name required")
|
||||
}
|
||||
|
||||
if attachment.FilePath == "" {
|
||||
return nil, errors.New("file path required")
|
||||
}
|
||||
|
||||
// Set upload timestamp
|
||||
attachment.Uploaded = time.Now()
|
||||
attachment.Deleted = false
|
||||
|
||||
// Save to database
|
||||
err := model.db.InsertAttachment(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func (model *Model) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
|
||||
if transactionId == "" {
|
||||
return nil, errors.New("transaction ID required")
|
||||
}
|
||||
|
||||
if orgId == "" {
|
||||
return nil, errors.New("organization ID required")
|
||||
}
|
||||
|
||||
if userId == "" {
|
||||
return nil, errors.New("user ID required")
|
||||
}
|
||||
|
||||
// First verify the user has access to the transaction
|
||||
tx, err := model.GetTransaction(transactionId, orgId, userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx == nil {
|
||||
return nil, errors.New("transaction not found or access denied")
|
||||
}
|
||||
|
||||
// Get attachments for the transaction
|
||||
attachments, err := model.db.GetAttachmentsByTransaction(transactionId, orgId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return attachments, nil
|
||||
}
|
||||
|
||||
func (model *Model) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
|
||||
if attachmentId == "" {
|
||||
return nil, errors.New("attachment ID required")
|
||||
}
|
||||
|
||||
if transactionId == "" {
|
||||
return nil, errors.New("transaction ID required")
|
||||
}
|
||||
|
||||
if orgId == "" {
|
||||
return nil, errors.New("organization ID required")
|
||||
}
|
||||
|
||||
if userId == "" {
|
||||
return nil, errors.New("user ID required")
|
||||
}
|
||||
|
||||
// First verify the user has access to the transaction
|
||||
tx, err := model.GetTransaction(transactionId, orgId, userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx == nil {
|
||||
return nil, errors.New("transaction not found or access denied")
|
||||
}
|
||||
|
||||
// Get the attachment
|
||||
attachment, err := model.db.GetAttachment(attachmentId, transactionId, orgId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func (model *Model) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
|
||||
if attachmentId == "" {
|
||||
return errors.New("attachment ID required")
|
||||
}
|
||||
|
||||
if transactionId == "" {
|
||||
return errors.New("transaction ID required")
|
||||
}
|
||||
|
||||
if orgId == "" {
|
||||
return errors.New("organization ID required")
|
||||
}
|
||||
|
||||
if userId == "" {
|
||||
return errors.New("user ID required")
|
||||
}
|
||||
|
||||
// First verify the user has access to the transaction
|
||||
tx, err := model.GetTransaction(transactionId, orgId, userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tx == nil {
|
||||
return errors.New("transaction not found or access denied")
|
||||
}
|
||||
|
||||
// Verify the attachment exists and belongs to the transaction
|
||||
attachment, err := model.db.GetAttachment(attachmentId, transactionId, orgId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if attachment == nil {
|
||||
return errors.New("attachment not found")
|
||||
}
|
||||
|
||||
// Soft delete the attachment
|
||||
err = model.db.DeleteAttachment(attachmentId, transactionId, orgId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
)
|
||||
|
||||
@@ -18,8 +19,8 @@ func (model *Model) GetBudget(orgId string, userId string) (*types.Budget, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if belongs == false {
|
||||
return nil, errors.New("User does not belong to org")
|
||||
if !belongs {
|
||||
return nil, errors.New("user does not belong to org")
|
||||
}
|
||||
|
||||
return model.db.GetBudget(orgId)
|
||||
@@ -32,8 +33,8 @@ func (model *Model) CreateBudget(budget *types.Budget, userId string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if belongs == false {
|
||||
return errors.New("User does not belong to org")
|
||||
if !belongs {
|
||||
return errors.New("user does not belong to org")
|
||||
}
|
||||
|
||||
if budget.OrgId == "" {
|
||||
@@ -50,8 +51,8 @@ func (model *Model) DeleteBudget(orgId string, userId string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if belongs == false {
|
||||
return errors.New("User does not belong to org")
|
||||
if !belongs {
|
||||
return errors.New("user does not belong to org")
|
||||
}
|
||||
|
||||
return model.db.DeleteBudget(orgId)
|
||||
|
||||
126
core/model/db/attachment.go
Normal file
126
core/model/db/attachment.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
)
|
||||
|
||||
const attachmentFields = "LOWER(HEX(id)),LOWER(HEX(transactionId)),LOWER(HEX(orgId)),LOWER(HEX(userId)),fileName,originalName,contentType,fileSize,filePath,description,uploaded,deleted"
|
||||
|
||||
type AttachmentInterface interface {
|
||||
InsertAttachment(*types.Attachment) error
|
||||
GetAttachment(string, string, string) (*types.Attachment, error)
|
||||
GetAttachmentsByTransaction(string, string) ([]*types.Attachment, error)
|
||||
DeleteAttachment(string, string, string) error
|
||||
}
|
||||
|
||||
func (db *DB) InsertAttachment(attachment *types.Attachment) error {
|
||||
query := "INSERT INTO attachment(id,transactionId,orgId,userId,fileName,originalName,contentType,fileSize,filePath,description,uploaded,deleted) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),UNHEX(?),?,?,?,?,?,?,?,?)"
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
attachment.Id,
|
||||
attachment.TransactionId,
|
||||
attachment.OrgId,
|
||||
attachment.UserId,
|
||||
attachment.FileName,
|
||||
attachment.OriginalName,
|
||||
attachment.ContentType,
|
||||
attachment.FileSize,
|
||||
attachment.FilePath,
|
||||
attachment.Description,
|
||||
util.TimeToMs(attachment.Uploaded),
|
||||
attachment.Deleted,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) GetAttachment(attachmentId, transactionId, orgId string) (*types.Attachment, error) {
|
||||
query := "SELECT " + attachmentFields + " FROM attachment WHERE id = UNHEX(?) AND transactionId = UNHEX(?) AND orgId = UNHEX(?) AND deleted = false"
|
||||
row := db.QueryRow(query, attachmentId, transactionId, orgId)
|
||||
|
||||
return db.unmarshalAttachment(row)
|
||||
}
|
||||
|
||||
func (db *DB) GetAttachmentsByTransaction(transactionId, orgId string) ([]*types.Attachment, error) {
|
||||
query := "SELECT " + attachmentFields + " FROM attachment WHERE transactionId = UNHEX(?) AND orgId = UNHEX(?) AND deleted = false ORDER BY uploaded DESC"
|
||||
rows, err := db.Query(query, transactionId, orgId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.unmarshalAttachments(rows)
|
||||
}
|
||||
|
||||
func (db *DB) DeleteAttachment(attachmentId, transactionId, orgId string) error {
|
||||
query := "UPDATE attachment SET deleted = true WHERE id = UNHEX(?) AND transactionId = UNHEX(?) AND orgId = UNHEX(?)"
|
||||
_, err := db.Exec(query, attachmentId, transactionId, orgId)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalAttachment(row *sql.Row) (*types.Attachment, error) {
|
||||
attachment := &types.Attachment{}
|
||||
var uploaded int64
|
||||
|
||||
err := row.Scan(
|
||||
&attachment.Id,
|
||||
&attachment.TransactionId,
|
||||
&attachment.OrgId,
|
||||
&attachment.UserId,
|
||||
&attachment.FileName,
|
||||
&attachment.OriginalName,
|
||||
&attachment.ContentType,
|
||||
&attachment.FileSize,
|
||||
&attachment.FilePath,
|
||||
&attachment.Description,
|
||||
&uploaded,
|
||||
&attachment.Deleted,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attachment.Uploaded = util.MsToTime(uploaded)
|
||||
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalAttachments(rows *sql.Rows) ([]*types.Attachment, error) {
|
||||
defer rows.Close()
|
||||
|
||||
attachments := []*types.Attachment{}
|
||||
|
||||
for rows.Next() {
|
||||
attachment := &types.Attachment{}
|
||||
var uploaded int64
|
||||
|
||||
err := rows.Scan(
|
||||
&attachment.Id,
|
||||
&attachment.TransactionId,
|
||||
&attachment.OrgId,
|
||||
&attachment.UserId,
|
||||
&attachment.FileName,
|
||||
&attachment.OriginalName,
|
||||
&attachment.ContentType,
|
||||
&attachment.FileSize,
|
||||
&attachment.FilePath,
|
||||
&attachment.Description,
|
||||
&uploaded,
|
||||
&attachment.Deleted,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attachment.Uploaded = util.MsToTime(uploaded)
|
||||
attachments = append(attachments, attachment)
|
||||
}
|
||||
|
||||
return attachments, nil
|
||||
}
|
||||
@@ -15,6 +15,7 @@ type Datastore interface {
|
||||
OrgInterface
|
||||
AccountInterface
|
||||
TransactionInterface
|
||||
AttachmentInterface
|
||||
PriceInterface
|
||||
SessionInterface
|
||||
ApiKeyInterface
|
||||
|
||||
353
core/model/gorm_model.go
Normal file
353
core/model/gorm_model.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/repository"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/openaccounting/oa-server/database"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormModel is the GORM-based implementation of the Model
|
||||
type GormModel struct {
|
||||
repository *repository.GormRepository
|
||||
bcrypt util.Bcrypt
|
||||
config types.Config
|
||||
}
|
||||
|
||||
// NewGormModel creates a new GORM-based model
|
||||
func NewGormModel(gormDB *gorm.DB, bcrypt util.Bcrypt, config types.Config) *GormModel {
|
||||
repo := repository.NewGormRepository(gormDB)
|
||||
return &GormModel{
|
||||
repository: repo,
|
||||
bcrypt: bcrypt,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateModel creates a new model using the existing database connection
|
||||
func CreateGormModel(bcrypt util.Bcrypt, config types.Config) (*GormModel, error) {
|
||||
// Use the existing database connection
|
||||
if database.DB == nil {
|
||||
return nil, errors.New("database connection not initialized")
|
||||
}
|
||||
|
||||
return NewGormModel(database.DB, bcrypt, config), nil
|
||||
}
|
||||
|
||||
// Implement the Interface by delegating to the business logic layer
|
||||
// The business logic layer (existing model methods) will call the repository
|
||||
|
||||
// UserInterface methods - delegate to existing business logic
|
||||
func (m *GormModel) CreateUser(user *types.User) error {
|
||||
// The existing business logic in user.go will be updated to use the repository
|
||||
// For now, delegate directly to repository for basic operations
|
||||
return m.repository.InsertUser(user)
|
||||
}
|
||||
|
||||
func (m *GormModel) VerifyUser(code string) error {
|
||||
return m.repository.VerifyUser(code)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateUser(user *types.User) error {
|
||||
return m.repository.UpdateUser(user)
|
||||
}
|
||||
|
||||
func (m *GormModel) ResetPassword(email string) error {
|
||||
// This would need the full business logic from the original model
|
||||
// For now, simplified implementation
|
||||
user, err := m.repository.GetVerifiedUserByEmail(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.PasswordReset, err = util.NewGuid()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.repository.UpdateUserResetPassword(user)
|
||||
}
|
||||
|
||||
func (m *GormModel) ConfirmResetPassword(password string, code string) (*types.User, error) {
|
||||
user, err := m.repository.GetUserByResetCode(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
passwordHash, err := m.bcrypt.GenerateFromPassword([]byte(password), m.bcrypt.GetDefaultCost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user.PasswordHash = string(passwordHash)
|
||||
user.Password = ""
|
||||
|
||||
err = m.repository.UpdateUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// AccountInterface methods - delegate to repository
|
||||
func (m *GormModel) CreateAccount(account *types.Account, userId string) error {
|
||||
return m.repository.InsertAccount(account)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateAccount(account *types.Account, userId string) error {
|
||||
return m.repository.UpdateAccount(account)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteAccount(id string, userId string, orgId string) error {
|
||||
return m.repository.DeleteAccount(id)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAccounts(orgId string, userId string, tokenId string) ([]*types.Account, error) {
|
||||
return m.repository.GetAccountsByOrgId(orgId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAccountsWithBalances(orgId string, userId string, tokenId string, date time.Time) ([]*types.Account, error) {
|
||||
accounts, err := m.repository.GetAccountsByOrgId(orgId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add balance calculations
|
||||
err = m.repository.AddBalances(accounts, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAccount(orgId, accId, userId, tokenId string) (*types.Account, error) {
|
||||
return m.repository.GetAccount(accId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAccountWithBalance(orgId, accId, userId, tokenId string, date time.Time) (*types.Account, error) {
|
||||
account, err := m.repository.GetAccount(accId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add balance calculation
|
||||
err = m.repository.AddBalance(account, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Complete OrgInterface implementation
|
||||
func (m *GormModel) CreateOrg(org *types.Org, userId string) error {
|
||||
// Get default accounts - this needs to be implemented properly
|
||||
accounts := []*types.Account{} // Empty for now, should create default chart of accounts
|
||||
return m.repository.CreateOrg(org, userId, accounts)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetOrg(orgId, userId string) (*types.Org, error) {
|
||||
return m.repository.GetOrg(orgId, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetOrgs(userId string) ([]*types.Org, error) {
|
||||
return m.repository.GetOrgs(userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateOrg(org *types.Org, userId string) error {
|
||||
return m.repository.UpdateOrg(org)
|
||||
}
|
||||
|
||||
func (m *GormModel) CreateInvite(invite *types.Invite, userId string) error {
|
||||
return m.repository.InsertInvite(invite)
|
||||
}
|
||||
|
||||
func (m *GormModel) AcceptInvite(invite *types.Invite, userId string) error {
|
||||
return m.repository.AcceptInvite(invite, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetInvites(orgId, userId string) ([]*types.Invite, error) {
|
||||
return m.repository.GetInvites(orgId)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteInvite(inviteId, userId string) error {
|
||||
return m.repository.DeleteInvite(inviteId)
|
||||
}
|
||||
|
||||
// SessionInterface implementation
|
||||
func (m *GormModel) CreateSession(session *types.Session) error {
|
||||
return m.repository.InsertSession(session)
|
||||
}
|
||||
|
||||
func (m *GormModel) InsertSession(session *types.Session) error {
|
||||
return m.repository.InsertSession(session)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteSession(sessionId, userId string) error {
|
||||
return m.repository.DeleteSession(sessionId, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateSessionActivity(sessionId string) error {
|
||||
return m.repository.UpdateSessionActivity(sessionId)
|
||||
}
|
||||
|
||||
// ApiKeyInterface implementation
|
||||
func (m *GormModel) CreateApiKey(apiKey *types.ApiKey) error {
|
||||
return m.repository.InsertApiKey(apiKey)
|
||||
}
|
||||
|
||||
func (m *GormModel) InsertApiKey(apiKey *types.ApiKey) error {
|
||||
return m.repository.InsertApiKey(apiKey)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateApiKey(apiKey *types.ApiKey) error {
|
||||
return m.repository.UpdateApiKey(apiKey)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteApiKey(keyId, userId string) error {
|
||||
return m.repository.DeleteApiKey(keyId, userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetApiKeys(userId string) ([]*types.ApiKey, error) {
|
||||
return m.repository.GetApiKeys(userId)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateApiKeyActivity(keyId string) error {
|
||||
return m.repository.UpdateApiKeyActivity(keyId)
|
||||
}
|
||||
|
||||
// TransactionInterface implementation
|
||||
func (m *GormModel) CreateTransaction(transaction *types.Transaction) error {
|
||||
return m.repository.InsertTransaction(transaction)
|
||||
}
|
||||
|
||||
func (m *GormModel) UpdateTransaction(transactionId string, transaction *types.Transaction) error {
|
||||
return m.repository.DeleteAndInsertTransaction(transactionId, transaction)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetTransactionsByAccount(accountId, orgId, userId string, options *types.QueryOptions) ([]*types.Transaction, error) {
|
||||
return m.repository.GetTransactionsByAccount(accountId, options)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetTransactionsByOrg(orgId, userId string, options *types.QueryOptions) ([]*types.Transaction, error) {
|
||||
return m.repository.GetTransactionsByOrg(orgId, options, []string{})
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteTransaction(transactionId, orgId, userId string) error {
|
||||
return m.repository.DeleteTransaction(transactionId)
|
||||
}
|
||||
|
||||
func (m *GormModel) InsertTransaction(transaction *types.Transaction) error {
|
||||
return m.repository.InsertTransaction(transaction)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetTransaction(transactionId, orgId, userId string) (*types.Transaction, error) {
|
||||
// For now, delegate to repository - in a full implementation, this would include permission checking
|
||||
return m.repository.GetTransactionById(transactionId)
|
||||
}
|
||||
|
||||
// AttachmentInterface implementation
|
||||
func (m *GormModel) CreateAttachment(attachment *types.Attachment) (*types.Attachment, error) {
|
||||
if attachment.Id == "" {
|
||||
return nil, errors.New("attachment ID required")
|
||||
}
|
||||
|
||||
// Set upload timestamp
|
||||
attachment.Uploaded = time.Now()
|
||||
attachment.Deleted = false
|
||||
|
||||
// For GORM implementation, we'd need to implement repository methods
|
||||
// For now, return an error indicating not implemented
|
||||
return nil, errors.New("attachment operations not yet implemented for GORM model")
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
|
||||
return nil, errors.New("attachment operations not yet implemented for GORM model")
|
||||
}
|
||||
|
||||
func (m *GormModel) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
|
||||
return nil, errors.New("attachment operations not yet implemented for GORM model")
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
|
||||
return errors.New("attachment operations not yet implemented for GORM model")
|
||||
}
|
||||
|
||||
func (m *GormModel) GetTransactionById(id string) (*types.Transaction, error) {
|
||||
return m.repository.GetTransactionById(id)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteAndInsertTransaction(id string, transaction *types.Transaction) error {
|
||||
return m.repository.DeleteAndInsertTransaction(id, transaction)
|
||||
}
|
||||
|
||||
// PriceInterface implementation
|
||||
func (m *GormModel) CreatePrice(price *types.Price, userId string) error {
|
||||
return m.repository.InsertPrice(price)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeletePrice(priceId, userId string) error {
|
||||
// Stub implementation - would need proper implementation
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GormModel) GetPricesNearestInTime(orgId string, date time.Time, currency string) ([]*types.Price, error) {
|
||||
// Stub implementation - would need proper implementation based on specific logic
|
||||
return m.repository.GetPrices(orgId, date)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetPricesByCurrency(orgId, currency, userId string) ([]*types.Price, error) {
|
||||
// Stub implementation - would need proper implementation based on specific logic
|
||||
return m.repository.GetPrices(orgId, time.Now())
|
||||
}
|
||||
|
||||
func (m *GormModel) GetPrices(orgId string, date time.Time) ([]*types.Price, error) {
|
||||
return m.repository.GetPrices(orgId, date)
|
||||
}
|
||||
|
||||
func (m *GormModel) InsertPrice(price *types.Price) error {
|
||||
return m.repository.InsertPrice(price)
|
||||
}
|
||||
|
||||
// SystemHealthInteface implementation
|
||||
func (m *GormModel) PingDatabase() error {
|
||||
return m.repository.Ping()
|
||||
}
|
||||
|
||||
func (m *GormModel) Ping() error {
|
||||
return m.repository.Ping()
|
||||
}
|
||||
|
||||
// BudgetInterface implementation
|
||||
func (m *GormModel) GetBudget(orgId, userId string) (*types.Budget, error) {
|
||||
// Stub implementation - would need proper implementation
|
||||
return &types.Budget{}, nil
|
||||
}
|
||||
|
||||
func (m *GormModel) CreateBudget(budget *types.Budget, userId string) error {
|
||||
return m.repository.InsertBudget(budget)
|
||||
}
|
||||
|
||||
func (m *GormModel) DeleteBudget(budgetId, userId string) error {
|
||||
// Stub implementation - would need proper implementation
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GormModel) InsertBudget(budget *types.Budget) error {
|
||||
return m.repository.InsertBudget(budget)
|
||||
}
|
||||
|
||||
func (m *GormModel) GetBudgets(orgId string) ([]*types.Budget, error) {
|
||||
return m.repository.GetBudgets(orgId)
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
func (m *GormModel) GetOrgUserIds(orgId string) ([]string, error) {
|
||||
return m.repository.GetOrgUserIds(orgId)
|
||||
}
|
||||
@@ -14,16 +14,19 @@ type Model struct {
|
||||
config types.Config
|
||||
}
|
||||
|
||||
|
||||
type Interface interface {
|
||||
UserInterface
|
||||
OrgInterface
|
||||
AccountInterface
|
||||
TransactionInterface
|
||||
AttachmentInterface
|
||||
PriceInterface
|
||||
SessionInterface
|
||||
ApiKeyInterface
|
||||
SystemHealthInteface
|
||||
BudgetInterface
|
||||
GetTransaction(string, string, string) (*types.Transaction, error)
|
||||
}
|
||||
|
||||
func NewModel(db db.Datastore, bcrypt util.Bcrypt, config types.Config) *Model {
|
||||
@@ -31,3 +34,4 @@ func NewModel(db db.Datastore, bcrypt util.Bcrypt, config types.Config) *Model {
|
||||
Instance = model
|
||||
return model
|
||||
}
|
||||
|
||||
|
||||
@@ -2,44 +2,45 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/mocks"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCreatePrice(t *testing.T) {
|
||||
|
||||
price := types.Price{
|
||||
"1",
|
||||
"2",
|
||||
"BTC",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
6700,
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
Currency: "BTC",
|
||||
Date: time.Unix(0, 0),
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
Price: 6700,
|
||||
}
|
||||
|
||||
badPrice := types.Price{
|
||||
"1",
|
||||
"2",
|
||||
"",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
6700,
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
Currency: "",
|
||||
Date: time.Unix(0, 0),
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
Price: 6700,
|
||||
}
|
||||
|
||||
badOrg := types.Price{
|
||||
"1",
|
||||
"1",
|
||||
"BTC",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
6700,
|
||||
Id: "1",
|
||||
OrgId: "1",
|
||||
Currency: "BTC",
|
||||
Date: time.Unix(0, 0),
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
Price: 6700,
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
@@ -89,13 +90,13 @@ func TestCreatePrice(t *testing.T) {
|
||||
func TestDeletePrice(t *testing.T) {
|
||||
|
||||
price := types.Price{
|
||||
"1",
|
||||
"2",
|
||||
"BTC",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
6700,
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
Currency: "BTC",
|
||||
Date: time.Unix(0, 0),
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
Price: 6700,
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
|
||||
@@ -3,9 +3,10 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/ws"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TransactionInterface interface {
|
||||
@@ -105,7 +106,7 @@ func (model *Model) GetTransactionsByAccount(orgId string, userId string, accoun
|
||||
}
|
||||
|
||||
if !model.accountsContainWriteAccess(userAccounts, accountId) {
|
||||
return nil, errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", accountId))
|
||||
return nil, fmt.Errorf("%s %s", "user does not have permission to access account", accountId)
|
||||
}
|
||||
|
||||
return model.db.GetTransactionsByAccount(accountId, options)
|
||||
@@ -142,7 +143,7 @@ func (model *Model) DeleteTransaction(id string, userId string, orgId string) (e
|
||||
|
||||
for _, split := range transaction.Splits {
|
||||
if !model.accountsContainWriteAccess(userAccounts, split.AccountId) {
|
||||
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", split.AccountId))
|
||||
return fmt.Errorf("%s %s", "user does not have permission to access account", split.AccountId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,6 +169,31 @@ func (model *Model) getTransactionById(id string) (*types.Transaction, error) {
|
||||
return model.db.GetTransactionById(id)
|
||||
}
|
||||
|
||||
func (model *Model) GetTransaction(transactionId, orgId, userId string) (*types.Transaction, error) {
|
||||
transaction, err := model.getTransactionById(transactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if transaction == nil || transaction.OrgId != orgId {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if user has access to all accounts in the transaction
|
||||
userAccounts, err := model.GetAccounts(orgId, userId, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, split := range transaction.Splits {
|
||||
if !model.accountsContainReadAccess(userAccounts, split.AccountId) {
|
||||
return nil, fmt.Errorf("user does not have permission to access account %s", split.AccountId)
|
||||
}
|
||||
}
|
||||
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
func (model *Model) checkSplits(transaction *types.Transaction) (err error) {
|
||||
if len(transaction.Splits) < 2 {
|
||||
return errors.New("at least 2 splits are required")
|
||||
@@ -189,13 +215,13 @@ func (model *Model) checkSplits(transaction *types.Transaction) (err error) {
|
||||
|
||||
for _, split := range transaction.Splits {
|
||||
if !model.accountsContainWriteAccess(userAccounts, split.AccountId) {
|
||||
return errors.New(fmt.Sprintf("%s %s", "user does not have permission to access account", split.AccountId))
|
||||
return fmt.Errorf("%s %s", "user does not have permission to access account", split.AccountId)
|
||||
}
|
||||
|
||||
account := model.getAccountFromList(userAccounts, split.AccountId)
|
||||
|
||||
if account.HasChildren == true {
|
||||
return errors.New("Cannot use parent account for split")
|
||||
if !account.HasChildren {
|
||||
return errors.New("cannot use parent account for split")
|
||||
}
|
||||
|
||||
if account.Currency == org.Currency && split.NativeAmount != split.Amount {
|
||||
|
||||
@@ -2,12 +2,13 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TdTransaction struct {
|
||||
@@ -57,72 +58,72 @@ func TestCreateTransaction(t *testing.T) {
|
||||
"successful": {
|
||||
err: nil,
|
||||
tx: &types.Transaction{
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
"description",
|
||||
"",
|
||||
false,
|
||||
[]*types.Split{
|
||||
&types.Split{"1", "1", 1000, 1000},
|
||||
&types.Split{"1", "2", -1000, -1000},
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
UserId: "3",
|
||||
Date: time.Now(),
|
||||
Inserted: time.Now(),
|
||||
Updated: time.Now(),
|
||||
Description: "description",
|
||||
Data: "",
|
||||
Deleted: false,
|
||||
Splits: []*types.Split{
|
||||
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 1000},
|
||||
&types.Split{TransactionId: "1", AccountId: "2", Amount: -1000, NativeAmount: -1000},
|
||||
},
|
||||
},
|
||||
},
|
||||
"bad split amounts": {
|
||||
err: errors.New("splits must add up to 0"),
|
||||
tx: &types.Transaction{
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
"description",
|
||||
"",
|
||||
false,
|
||||
[]*types.Split{
|
||||
&types.Split{"1", "1", 1000, 1000},
|
||||
&types.Split{"1", "2", -500, -500},
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
UserId: "3",
|
||||
Date: time.Now(),
|
||||
Inserted: time.Now(),
|
||||
Updated: time.Now(),
|
||||
Description: "description",
|
||||
Data: "",
|
||||
Deleted: false,
|
||||
Splits: []*types.Split{
|
||||
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 1000},
|
||||
&types.Split{TransactionId: "1", AccountId: "2", Amount: -500, NativeAmount: -500},
|
||||
},
|
||||
},
|
||||
},
|
||||
"lacking permission": {
|
||||
err: errors.New("user does not have permission to access account 3"),
|
||||
tx: &types.Transaction{
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
"description",
|
||||
"",
|
||||
false,
|
||||
[]*types.Split{
|
||||
&types.Split{"1", "1", 1000, 1000},
|
||||
&types.Split{"1", "3", -1000, -1000},
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
UserId: "3",
|
||||
Date: time.Now(),
|
||||
Inserted: time.Now(),
|
||||
Updated: time.Now(),
|
||||
Description: "description",
|
||||
Data: "",
|
||||
Deleted: false,
|
||||
Splits: []*types.Split{
|
||||
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 1000},
|
||||
&types.Split{TransactionId: "1", AccountId: "3", Amount: -1000, NativeAmount: -1000},
|
||||
},
|
||||
},
|
||||
},
|
||||
"nativeAmount mismatch": {
|
||||
err: errors.New("nativeAmount must equal amount for native currency splits"),
|
||||
tx: &types.Transaction{
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
"description",
|
||||
"",
|
||||
false,
|
||||
[]*types.Split{
|
||||
&types.Split{"1", "1", 1000, 500},
|
||||
&types.Split{"1", "2", -1000, -500},
|
||||
Id: "1",
|
||||
OrgId: "2",
|
||||
UserId: "3",
|
||||
Date: time.Now(),
|
||||
Inserted: time.Now(),
|
||||
Updated: time.Now(),
|
||||
Description: "description",
|
||||
Data: "",
|
||||
Deleted: false,
|
||||
Splits: []*types.Split{
|
||||
&types.Split{TransactionId: "1", AccountId: "1", Amount: 1000, NativeAmount: 500},
|
||||
&types.Split{TransactionId: "1", AccountId: "2", Amount: -1000, NativeAmount: -500},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
20
core/model/types/attachment.go
Normal file
20
core/model/types/attachment.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Attachment struct {
|
||||
Id string `json:"id"`
|
||||
TransactionId string `json:"transactionId"`
|
||||
OrgId string `json:"orgId"`
|
||||
UserId string `json:"userId"`
|
||||
FileName string `json:"fileName"`
|
||||
OriginalName string `json:"originalName"`
|
||||
ContentType string `json:"contentType"`
|
||||
FileSize int64 `json:"fileSize"`
|
||||
FilePath string `json:"filePath"`
|
||||
Description string `json:"description"`
|
||||
Uploaded time.Time `json:"uploaded"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
@@ -1,18 +1,27 @@
|
||||
package types
|
||||
|
||||
import "github.com/openaccounting/oa-server/core/storage"
|
||||
|
||||
type Config struct {
|
||||
WebUrl string
|
||||
Address string
|
||||
Port int
|
||||
ApiPrefix string
|
||||
KeyFile string
|
||||
CertFile string
|
||||
DatabaseAddress string
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
MailgunDomain string
|
||||
MailgunKey string
|
||||
MailgunEmail string
|
||||
MailgunSender string
|
||||
WebUrl string `mapstructure:"weburl"`
|
||||
Address string `mapstructure:"address"`
|
||||
Port int `mapstructure:"port"`
|
||||
ApiPrefix string `mapstructure:"apiprefix"`
|
||||
KeyFile string `mapstructure:"keyfile"`
|
||||
CertFile string `mapstructure:"certfile"`
|
||||
// Database configuration
|
||||
DatabaseDriver string `mapstructure:"databasedriver"` // "mysql" or "sqlite"
|
||||
DatabaseAddress string `mapstructure:"databaseaddress"`
|
||||
Database string `mapstructure:"database"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"` // Sensitive: use OA_PASSWORD env var
|
||||
// SQLite specific
|
||||
DatabaseFile string `mapstructure:"databasefile"`
|
||||
// Email configuration
|
||||
MailgunDomain string `mapstructure:"mailgundomain"`
|
||||
MailgunKey string `mapstructure:"mailgunkey"` // Sensitive: use OA_MAILGUN_KEY env var
|
||||
MailgunEmail string `mapstructure:"mailgunemail"`
|
||||
MailgunSender string `mapstructure:"mailgunsender"`
|
||||
// Storage configuration
|
||||
Storage storage.Config `mapstructure:"storage"`
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (model *Model) CreateUser(user *types.User) error {
|
||||
return errors.New("email required")
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(".+@.+\\..+")
|
||||
re := regexp.MustCompile(`.+@.+\..+`)
|
||||
|
||||
if re.FindString(user.Email) == "" {
|
||||
return errors.New("invalid email address")
|
||||
@@ -47,7 +47,7 @@ func (model *Model) CreateUser(user *types.User) error {
|
||||
return errors.New("password required")
|
||||
}
|
||||
|
||||
if user.AgreeToTerms != true {
|
||||
if !user.AgreeToTerms {
|
||||
return errors.New("must agree to terms")
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (model *Model) ResetPassword(email string) error {
|
||||
|
||||
if err != nil {
|
||||
// Don't send back error so people can't try to find user accounts
|
||||
log.Printf("Invalid email for reset password " + email)
|
||||
log.Printf("Invalid email for reset password %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ func (model *Model) ConfirmResetPassword(password string, code string) (*types.U
|
||||
user, err := model.db.GetUserByResetCode(code)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("Invalid code")
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
|
||||
passwordHash, err := model.bcrypt.GenerateFromPassword([]byte(password), model.bcrypt.GetDefaultCost())
|
||||
|
||||
@@ -2,12 +2,13 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/mocks"
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TdUser struct {
|
||||
@@ -39,33 +40,35 @@ func TestCreateUser(t *testing.T) {
|
||||
// EmailVerifyCode string `json:"-"`
|
||||
|
||||
user := types.User{
|
||||
"0",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
"John",
|
||||
"Doe",
|
||||
"johndoe@email.com",
|
||||
"password",
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
Id: "0",
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
Email: "johndoe@email.com",
|
||||
Password: "password",
|
||||
PasswordHash: "",
|
||||
AgreeToTerms: true,
|
||||
PasswordReset: "",
|
||||
EmailVerified: false,
|
||||
EmailVerifyCode: "",
|
||||
SignupSource: "",
|
||||
}
|
||||
|
||||
badUser := types.User{
|
||||
"0",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
"John",
|
||||
"Doe",
|
||||
"",
|
||||
"password",
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
Id: "0",
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
Email: "",
|
||||
Password: "password",
|
||||
PasswordHash: "",
|
||||
AgreeToTerms: true,
|
||||
PasswordReset: "",
|
||||
EmailVerified: false,
|
||||
EmailVerifyCode: "",
|
||||
SignupSource: "",
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
@@ -109,33 +112,35 @@ func TestCreateUser(t *testing.T) {
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
|
||||
user := types.User{
|
||||
"0",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
"John2",
|
||||
"Doe",
|
||||
"johndoe@email.com",
|
||||
"password",
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
Id: "0",
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
FirstName: "John2",
|
||||
LastName: "Doe",
|
||||
Email: "johndoe@email.com",
|
||||
Password: "password",
|
||||
PasswordHash: "",
|
||||
AgreeToTerms: true,
|
||||
PasswordReset: "",
|
||||
EmailVerified: false,
|
||||
EmailVerifyCode: "",
|
||||
SignupSource: "",
|
||||
}
|
||||
|
||||
badUser := types.User{
|
||||
"0",
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, 0),
|
||||
"John2",
|
||||
"Doe",
|
||||
"johndoe@email.com",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
Id: "0",
|
||||
Inserted: time.Unix(0, 0),
|
||||
Updated: time.Unix(0, 0),
|
||||
FirstName: "John2",
|
||||
LastName: "Doe",
|
||||
Email: "johndoe@email.com",
|
||||
Password: "",
|
||||
PasswordHash: "",
|
||||
AgreeToTerms: true,
|
||||
PasswordReset: "",
|
||||
EmailVerified: false,
|
||||
EmailVerifyCode: "",
|
||||
SignupSource: "",
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
|
||||
375
core/repository/gorm_repository.go
Normal file
375
core/repository/gorm_repository.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/openaccounting/oa-server/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormRepository implements the same interfaces as core/model/db but uses GORM
|
||||
type GormRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// Note: GormRepository implements most of the Datastore interface
|
||||
// Some methods like DeleteAndInsertTransaction need to be added for full compatibility
|
||||
|
||||
// NewGormRepository creates a new GORM repository
|
||||
func NewGormRepository(db *gorm.DB) *GormRepository {
|
||||
return &GormRepository{db: db}
|
||||
}
|
||||
|
||||
// UserInterface implementation
|
||||
func (r *GormRepository) InsertUser(user *types.User) error {
|
||||
user.Inserted = time.Now()
|
||||
user.Updated = user.Inserted
|
||||
user.PasswordReset = ""
|
||||
|
||||
// Convert types.User to models.User
|
||||
gormUser := &models.User{
|
||||
ID: []byte(user.Id), // Convert string ID to []byte
|
||||
Inserted: uint64(util.TimeToMs(user.Inserted)),
|
||||
Updated: uint64(util.TimeToMs(user.Updated)),
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
Email: user.Email,
|
||||
PasswordHash: user.PasswordHash,
|
||||
AgreeToTerms: user.AgreeToTerms,
|
||||
PasswordReset: user.PasswordReset,
|
||||
EmailVerified: user.EmailVerified,
|
||||
EmailVerifyCode: user.EmailVerifyCode,
|
||||
SignupSource: user.SignupSource,
|
||||
}
|
||||
|
||||
result := r.db.Create(gormUser)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected < 1 {
|
||||
return errors.New("unable to insert user into db")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) VerifyUser(code string) error {
|
||||
result := r.db.Model(&models.User{}).
|
||||
Where("email_verify_code = ?", code).
|
||||
Updates(map[string]interface{}{
|
||||
"updated": util.TimeToMs(time.Now()),
|
||||
"email_verified": true,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("invalid code")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateUser(user *types.User) error {
|
||||
user.Updated = time.Now()
|
||||
|
||||
result := r.db.Model(&models.User{}).
|
||||
Where("id = ?", []byte(user.Id)).
|
||||
Updates(map[string]interface{}{
|
||||
"updated": util.TimeToMs(user.Updated),
|
||||
"password_hash": user.PasswordHash,
|
||||
"password_reset": "",
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateUserResetPassword(user *types.User) error {
|
||||
user.Updated = time.Now()
|
||||
|
||||
result := r.db.Model(&models.User{}).
|
||||
Where("id = ?", []byte(user.Id)).
|
||||
Updates(map[string]interface{}{
|
||||
"updated": util.TimeToMs(user.Updated),
|
||||
"password_reset": user.PasswordReset,
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetVerifiedUserByEmail(email string) (*types.User, error) {
|
||||
var gormUser models.User
|
||||
result := r.db.Where("email = ? AND email_verified = ?",
|
||||
strings.TrimSpace(strings.ToLower(email)), true).
|
||||
First(&gormUser)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormUserToTypesUser(&gormUser), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetUserByActiveSession(sessionId string) (*types.User, error) {
|
||||
var gormUser models.User
|
||||
result := r.db.Table("users").
|
||||
Select("users.*").
|
||||
Joins("JOIN sessions ON sessions.user_id = users.id").
|
||||
Where("sessions.terminated IS NULL AND sessions.id = ?", []byte(sessionId)).
|
||||
First(&gormUser)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormUserToTypesUser(&gormUser), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetUserByApiKey(keyId string) (*types.User, error) {
|
||||
var gormUser models.User
|
||||
result := r.db.Table("users").
|
||||
Select("users.*").
|
||||
Joins("JOIN api_keys ON api_keys.user_id = users.id").
|
||||
Where("api_keys.deleted_at IS NULL AND api_keys.id = ?", []byte(keyId)).
|
||||
First(&gormUser)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormUserToTypesUser(&gormUser), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetUserByResetCode(code string) (*types.User, error) {
|
||||
var gormUser models.User
|
||||
result := r.db.Where("password_reset = ?", code).First(&gormUser)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormUserToTypesUser(&gormUser), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetUserByEmailVerifyCode(code string) (*types.User, error) {
|
||||
// only allow this for 3 days
|
||||
minInserted := (time.Now().UnixNano() / 1000000) - (3 * 24 * 60 * 60 * 1000)
|
||||
|
||||
var gormUser models.User
|
||||
result := r.db.Where("email_verify_code = ? AND inserted > ?", code, minInserted).
|
||||
First(&gormUser)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormUserToTypesUser(&gormUser), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetOrgAdmins(orgId string) ([]*types.User, error) {
|
||||
var gormUsers []models.User
|
||||
result := r.db.Table("users").
|
||||
Select("users.*").
|
||||
Joins("JOIN user_orgs ON user_orgs.user_id = users.id").
|
||||
Where("user_orgs.admin = ? AND user_orgs.org_id = ?", true, []byte(orgId)).
|
||||
Find(&gormUsers)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
users := make([]*types.User, len(gormUsers))
|
||||
for i, gormUser := range gormUsers {
|
||||
users[i] = r.convertGormUserToTypesUser(&gormUser)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// Helper function to convert GORM User to types.User
|
||||
func (r *GormRepository) convertGormUserToTypesUser(gormUser *models.User) *types.User {
|
||||
return &types.User{
|
||||
Id: string(gormUser.ID),
|
||||
Inserted: util.MsToTime(int64(gormUser.Inserted)),
|
||||
Updated: util.MsToTime(int64(gormUser.Updated)),
|
||||
FirstName: gormUser.FirstName,
|
||||
LastName: gormUser.LastName,
|
||||
Email: gormUser.Email,
|
||||
PasswordHash: gormUser.PasswordHash,
|
||||
AgreeToTerms: gormUser.AgreeToTerms,
|
||||
PasswordReset: gormUser.PasswordReset,
|
||||
EmailVerified: gormUser.EmailVerified,
|
||||
EmailVerifyCode: gormUser.EmailVerifyCode,
|
||||
SignupSource: gormUser.SignupSource,
|
||||
}
|
||||
}
|
||||
|
||||
// AccountInterface implementation
|
||||
func (r *GormRepository) InsertAccount(account *types.Account) error {
|
||||
account.Inserted = time.Now()
|
||||
account.Updated = account.Inserted
|
||||
|
||||
// Convert types.Account to models.Account
|
||||
gormAccount := &models.Account{
|
||||
ID: []byte(account.Id),
|
||||
OrgID: []byte(account.OrgId),
|
||||
Inserted: uint64(util.TimeToMs(account.Inserted)),
|
||||
Updated: uint64(util.TimeToMs(account.Updated)),
|
||||
Name: account.Name,
|
||||
Parent: []byte(account.Parent),
|
||||
Currency: account.Currency,
|
||||
Precision: account.Precision,
|
||||
DebitBalance: account.DebitBalance,
|
||||
}
|
||||
|
||||
return r.db.Create(gormAccount).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateAccount(account *types.Account) error {
|
||||
account.Updated = time.Now()
|
||||
|
||||
result := r.db.Model(&models.Account{}).
|
||||
Where("id = ?", []byte(account.Id)).
|
||||
Updates(map[string]interface{}{
|
||||
"updated": util.TimeToMs(account.Updated),
|
||||
"name": account.Name,
|
||||
"parent": []byte(account.Parent),
|
||||
"currency": account.Currency,
|
||||
"precision": account.Precision,
|
||||
"debit_balance": account.DebitBalance,
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetAccount(id string) (*types.Account, error) {
|
||||
var gormAccount models.Account
|
||||
result := r.db.Where("id = ?", []byte(id)).First(&gormAccount)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormAccountToTypesAccount(&gormAccount), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetAccountsByOrgId(orgId string) ([]*types.Account, error) {
|
||||
var gormAccounts []models.Account
|
||||
result := r.db.Where("org_id = ?", []byte(orgId)).Find(&gormAccounts)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
accounts := make([]*types.Account, len(gormAccounts))
|
||||
for i, gormAccount := range gormAccounts {
|
||||
accounts[i] = r.convertGormAccountToTypesAccount(&gormAccount)
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetPermissionedAccountIds(orgId, userId, tokenId string) ([]string, error) {
|
||||
var accountIds []string
|
||||
result := r.db.Table("permissions").
|
||||
Select("DISTINCT LOWER(HEX(account_id)) as account_id").
|
||||
Where("org_id = ? AND user_id = ?", []byte(orgId), []byte(userId)).
|
||||
Pluck("account_id", &accountIds)
|
||||
|
||||
return accountIds, result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetSplitCountByAccountId(id string) (int64, error) {
|
||||
var count int64
|
||||
result := r.db.Model(&models.Split{}).
|
||||
Where("account_id = ?", []byte(id)).
|
||||
Count(&count)
|
||||
|
||||
return count, result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetChildCountByAccountId(id string) (int64, error) {
|
||||
var count int64
|
||||
result := r.db.Model(&models.Account{}).
|
||||
Where("parent = ?", []byte(id)).
|
||||
Count(&count)
|
||||
|
||||
return count, result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteAccount(id string) error {
|
||||
return r.db.Where("id = ?", []byte(id)).Delete(&models.Account{}).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetRootAccount(orgId string) (*types.Account, error) {
|
||||
var gormAccount models.Account
|
||||
result := r.db.Where("org_id = ? AND parent = ?", []byte(orgId), []byte{0}).
|
||||
First(&gormAccount)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormAccountToTypesAccount(&gormAccount), nil
|
||||
}
|
||||
|
||||
// Balance-related methods (simplified implementations)
|
||||
func (r *GormRepository) AddBalances(accounts []*types.Account, date time.Time) error {
|
||||
// Implementation would need to be completed based on your balance calculation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) AddNativeBalancesCost(accounts []*types.Account, date time.Time) error {
|
||||
// Implementation would need to be completed based on your balance calculation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) AddNativeBalancesNearestInTime(accounts []*types.Account, date time.Time) error {
|
||||
// Implementation would need to be completed based on your balance calculation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) AddBalance(account *types.Account, date time.Time) error {
|
||||
// Implementation would need to be completed based on your balance calculation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) AddNativeBalanceCost(account *types.Account, date time.Time) error {
|
||||
// Implementation would need to be completed based on your balance calculation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) AddNativeBalanceNearestInTime(account *types.Account, date time.Time) error {
|
||||
// Implementation would need to be completed based on your balance calculation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to convert GORM Account to types.Account
|
||||
func (r *GormRepository) convertGormAccountToTypesAccount(gormAccount *models.Account) *types.Account {
|
||||
return &types.Account{
|
||||
Id: string(gormAccount.ID),
|
||||
OrgId: string(gormAccount.OrgID),
|
||||
Inserted: util.MsToTime(int64(gormAccount.Inserted)),
|
||||
Updated: util.MsToTime(int64(gormAccount.Updated)),
|
||||
Name: gormAccount.Name,
|
||||
Parent: string(gormAccount.Parent),
|
||||
Currency: gormAccount.Currency,
|
||||
Precision: gormAccount.Precision,
|
||||
DebitBalance: gormAccount.DebitBalance,
|
||||
// Balance fields would be populated by the AddBalance methods
|
||||
}
|
||||
}
|
||||
|
||||
// Escape method for SQL injection protection (GORM handles this automatically)
|
||||
func (r *GormRepository) Escape(sql string) string {
|
||||
// GORM handles SQL injection protection automatically
|
||||
// This method is kept for interface compatibility
|
||||
return sql
|
||||
}
|
||||
462
core/repository/gorm_repository_interfaces.go
Normal file
462
core/repository/gorm_repository_interfaces.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/openaccounting/oa-server/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OrgInterface implementation
|
||||
func (r *GormRepository) CreateOrg(org *types.Org, userId string, accounts []*types.Account) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
org.Inserted = time.Now()
|
||||
org.Updated = org.Inserted
|
||||
|
||||
// Create org
|
||||
gormOrg := &models.Org{
|
||||
ID: []byte(org.Id),
|
||||
Inserted: uint64(util.TimeToMs(org.Inserted)),
|
||||
Updated: uint64(util.TimeToMs(org.Updated)),
|
||||
Name: org.Name,
|
||||
Currency: org.Currency,
|
||||
Precision: org.Precision,
|
||||
Timezone: org.Timezone,
|
||||
}
|
||||
|
||||
if err := tx.Create(gormOrg).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create accounts
|
||||
for _, account := range accounts {
|
||||
gormAccount := &models.Account{
|
||||
ID: []byte(account.Id),
|
||||
OrgID: []byte(account.OrgId),
|
||||
Inserted: uint64(util.TimeToMs(time.Now())),
|
||||
Updated: uint64(util.TimeToMs(time.Now())),
|
||||
Name: account.Name,
|
||||
Parent: []byte(account.Parent),
|
||||
Currency: account.Currency,
|
||||
Precision: account.Precision,
|
||||
DebitBalance: account.DebitBalance,
|
||||
}
|
||||
if err := tx.Create(gormAccount).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create userorg association
|
||||
userOrg := &models.UserOrg{
|
||||
UserID: []byte(userId),
|
||||
OrgID: []byte(org.Id),
|
||||
Admin: true,
|
||||
}
|
||||
|
||||
return tx.Create(userOrg).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateOrg(org *types.Org) error {
|
||||
org.Updated = time.Now()
|
||||
|
||||
return r.db.Model(&models.Org{}).
|
||||
Where("id = ?", []byte(org.Id)).
|
||||
Updates(map[string]interface{}{
|
||||
"updated": util.TimeToMs(org.Updated),
|
||||
"name": org.Name,
|
||||
"currency": org.Currency,
|
||||
"precision": org.Precision,
|
||||
"timezone": org.Timezone,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetOrg(orgId, userId string) (*types.Org, error) {
|
||||
var gormOrg models.Org
|
||||
result := r.db.Table("orgs").
|
||||
Select("orgs.*").
|
||||
Joins("JOIN user_orgs ON user_orgs.org_id = orgs.id").
|
||||
Where("orgs.id = ? AND user_orgs.user_id = ?", []byte(orgId), []byte(userId)).
|
||||
First(&gormOrg)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormOrgToTypesOrg(&gormOrg), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetOrgs(userId string) ([]*types.Org, error) {
|
||||
var gormOrgs []models.Org
|
||||
result := r.db.Table("orgs").
|
||||
Select("orgs.*").
|
||||
Joins("JOIN user_orgs ON user_orgs.org_id = orgs.id").
|
||||
Where("user_orgs.user_id = ?", []byte(userId)).
|
||||
Find(&gormOrgs)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
orgs := make([]*types.Org, len(gormOrgs))
|
||||
for i, gormOrg := range gormOrgs {
|
||||
orgs[i] = r.convertGormOrgToTypesOrg(&gormOrg)
|
||||
}
|
||||
|
||||
return orgs, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetOrgUserIds(orgId string) ([]string, error) {
|
||||
var userIds []string
|
||||
result := r.db.Table("user_orgs").
|
||||
Select("LOWER(HEX(user_id)) as user_id").
|
||||
Where("org_id = ?", []byte(orgId)).
|
||||
Pluck("user_id", &userIds)
|
||||
|
||||
return userIds, result.Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) InsertInvite(invite *types.Invite) error {
|
||||
invite.Inserted = time.Now()
|
||||
invite.Updated = invite.Inserted
|
||||
|
||||
gormInvite := &models.Invite{
|
||||
ID: invite.Id,
|
||||
OrgID: []byte(invite.OrgId),
|
||||
Inserted: uint64(util.TimeToMs(invite.Inserted)),
|
||||
Updated: uint64(util.TimeToMs(invite.Updated)),
|
||||
Email: invite.Email,
|
||||
Accepted: invite.Accepted,
|
||||
}
|
||||
|
||||
return r.db.Create(gormInvite).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) AcceptInvite(invite *types.Invite, userId string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Update invite
|
||||
if err := tx.Model(&models.Invite{}).
|
||||
Where("id = ?", []byte(invite.Id)).
|
||||
Update("accepted", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create userorg association
|
||||
userOrg := &models.UserOrg{
|
||||
UserID: []byte(userId),
|
||||
OrgID: []byte(invite.OrgId),
|
||||
Admin: false,
|
||||
}
|
||||
|
||||
return tx.Create(userOrg).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetInvites(orgId string) ([]*types.Invite, error) {
|
||||
var gormInvites []models.Invite
|
||||
result := r.db.Where("org_id = ?", []byte(orgId)).Find(&gormInvites)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
invites := make([]*types.Invite, len(gormInvites))
|
||||
for i, gormInvite := range gormInvites {
|
||||
invites[i] = r.convertGormInviteToTypesInvite(&gormInvite)
|
||||
}
|
||||
|
||||
return invites, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetInvite(id string) (*types.Invite, error) {
|
||||
var gormInvite models.Invite
|
||||
result := r.db.Where("id = ?", id).First(&gormInvite)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return r.convertGormInviteToTypesInvite(&gormInvite), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteInvite(id string) error {
|
||||
return r.db.Where("id = ?", id).Delete(&models.Invite{}).Error
|
||||
}
|
||||
|
||||
// SessionInterface implementation (basic stubs)
|
||||
func (r *GormRepository) InsertSession(session *types.Session) error {
|
||||
var terminated uint64
|
||||
if session.Terminated.Valid {
|
||||
terminated = uint64(util.TimeToMs(session.Terminated.Time))
|
||||
}
|
||||
|
||||
gormSession := &models.Session{
|
||||
ID: []byte(session.Id),
|
||||
UserID: []byte(session.UserId),
|
||||
Inserted: uint64(util.TimeToMs(time.Now())),
|
||||
Updated: uint64(util.TimeToMs(time.Now())),
|
||||
Terminated: terminated,
|
||||
}
|
||||
return r.db.Create(gormSession).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) TerminateSession(sessionId string) error {
|
||||
return r.db.Model(&models.Session{}).
|
||||
Where("id = ?", []byte(sessionId)).
|
||||
Update("terminated", uint64(util.TimeToMs(time.Now()))).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteSession(sessionId, userId string) error {
|
||||
return r.db.Where("id = ? AND user_id = ?", []byte(sessionId), []byte(userId)).
|
||||
Delete(&models.Session{}).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateSessionActivity(sessionId string) error {
|
||||
return r.db.Model(&models.Session{}).
|
||||
Where("id = ?", []byte(sessionId)).
|
||||
Update("updated", uint64(util.TimeToMs(time.Now()))).Error
|
||||
}
|
||||
|
||||
// APIKey interface (basic stubs)
|
||||
func (r *GormRepository) InsertApiKey(apiKey *types.ApiKey) error {
|
||||
gormApiKey := &models.APIKey{
|
||||
ID: []byte(apiKey.Id),
|
||||
UserID: []byte(apiKey.UserId),
|
||||
Inserted: uint64(util.TimeToMs(time.Now())),
|
||||
Updated: uint64(util.TimeToMs(time.Now())),
|
||||
Label: apiKey.Label,
|
||||
}
|
||||
return r.db.Create(gormApiKey).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteApiKey(keyId, userId string) error {
|
||||
return r.db.Where("id = ? AND user_id = ?", []byte(keyId), []byte(userId)).
|
||||
Delete(&models.APIKey{}).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateApiKey(apiKey *types.ApiKey) error {
|
||||
return r.db.Model(&models.APIKey{}).
|
||||
Where("id = ?", []byte(apiKey.Id)).
|
||||
Updates(map[string]interface{}{
|
||||
"updated": uint64(util.TimeToMs(time.Now())),
|
||||
"label": apiKey.Label,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetApiKeys(userId string) ([]*types.ApiKey, error) {
|
||||
var gormApiKeys []models.APIKey
|
||||
result := r.db.Where("user_id = ? AND deleted_at IS NULL", []byte(userId)).
|
||||
Find(&gormApiKeys)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
apiKeys := make([]*types.ApiKey, len(gormApiKeys))
|
||||
for i, gormApiKey := range gormApiKeys {
|
||||
apiKeys[i] = r.convertGormApiKeyToTypesApiKey(&gormApiKey)
|
||||
}
|
||||
|
||||
return apiKeys, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) UpdateApiKeyActivity(keyId string) error {
|
||||
return r.db.Model(&models.APIKey{}).
|
||||
Where("id = ?", []byte(keyId)).
|
||||
Update("updated", uint64(util.TimeToMs(time.Now()))).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) convertGormApiKeyToTypesApiKey(gormApiKey *models.APIKey) *types.ApiKey {
|
||||
return &types.ApiKey{
|
||||
Id: string(gormApiKey.ID),
|
||||
Inserted: util.MsToTime(int64(gormApiKey.Inserted)),
|
||||
Updated: util.MsToTime(int64(gormApiKey.Updated)),
|
||||
UserId: string(gormApiKey.UserID),
|
||||
Label: gormApiKey.Label,
|
||||
}
|
||||
}
|
||||
|
||||
// TransactionInterface implementation
|
||||
func (r *GormRepository) InsertTransaction(transaction *types.Transaction) error {
|
||||
gormTransaction := &models.Transaction{
|
||||
ID: []byte(transaction.Id),
|
||||
OrgID: []byte(transaction.OrgId),
|
||||
UserID: []byte(transaction.UserId),
|
||||
Date: uint64(transaction.Date.Unix()),
|
||||
Inserted: uint64(util.TimeToMs(time.Now())),
|
||||
Updated: uint64(util.TimeToMs(time.Now())),
|
||||
Description: transaction.Description,
|
||||
Data: transaction.Data,
|
||||
}
|
||||
return r.db.Create(gormTransaction).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetTransactionById(id string) (*types.Transaction, error) {
|
||||
var gormTransaction models.Transaction
|
||||
result := r.db.Where("id = ?", []byte(id)).First(&gormTransaction)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return r.convertGormTransactionToTypesTransaction(&gormTransaction), nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetTransactionsByAccount(accountId string, options *types.QueryOptions) ([]*types.Transaction, error) {
|
||||
var gormTransactions []models.Transaction
|
||||
query := r.db.Table("transactions").
|
||||
Joins("JOIN splits ON splits.transaction_id = transactions.id").
|
||||
Where("splits.account_id = ?", []byte(accountId))
|
||||
|
||||
if options != nil {
|
||||
// Apply query options like limit, skip, date range, etc.
|
||||
if options.Limit > 0 {
|
||||
query = query.Limit(int(options.Limit))
|
||||
}
|
||||
if options.Skip > 0 {
|
||||
query = query.Offset(int(options.Skip))
|
||||
}
|
||||
}
|
||||
|
||||
result := query.Find(&gormTransactions)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
transactions := make([]*types.Transaction, len(gormTransactions))
|
||||
for i, gormTx := range gormTransactions {
|
||||
transactions[i] = r.convertGormTransactionToTypesTransaction(&gormTx)
|
||||
}
|
||||
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetTransactionsByOrg(orgId string, options *types.QueryOptions, accountIds []string) ([]*types.Transaction, error) {
|
||||
var gormTransactions []models.Transaction
|
||||
query := r.db.Where("org_id = ?", []byte(orgId))
|
||||
|
||||
if len(accountIds) > 0 {
|
||||
// Convert string IDs to byte arrays
|
||||
byteAccountIds := make([][]byte, len(accountIds))
|
||||
for i, id := range accountIds {
|
||||
byteAccountIds[i] = []byte(id)
|
||||
}
|
||||
query = query.Where("id IN (SELECT DISTINCT transaction_id FROM splits WHERE account_id IN ?)", byteAccountIds)
|
||||
}
|
||||
|
||||
if options != nil {
|
||||
if options.Limit > 0 {
|
||||
query = query.Limit(int(options.Limit))
|
||||
}
|
||||
if options.Skip > 0 {
|
||||
query = query.Offset(int(options.Skip))
|
||||
}
|
||||
}
|
||||
|
||||
result := query.Find(&gormTransactions)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
transactions := make([]*types.Transaction, len(gormTransactions))
|
||||
for i, gormTx := range gormTransactions {
|
||||
transactions[i] = r.convertGormTransactionToTypesTransaction(&gormTx)
|
||||
}
|
||||
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteTransaction(id string) error {
|
||||
return r.db.Where("id = ?", []byte(id)).Delete(&models.Transaction{}).Error
|
||||
}
|
||||
|
||||
func (r *GormRepository) DeleteAndInsertTransaction(id string, transaction *types.Transaction) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete the old transaction
|
||||
if err := tx.Where("id = ?", []byte(id)).Delete(&models.Transaction{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert the new transaction
|
||||
gormTransaction := &models.Transaction{
|
||||
ID: []byte(transaction.Id),
|
||||
OrgID: []byte(transaction.OrgId),
|
||||
UserID: []byte(transaction.UserId),
|
||||
Date: uint64(transaction.Date.Unix()),
|
||||
Inserted: uint64(util.TimeToMs(time.Now())),
|
||||
Updated: uint64(util.TimeToMs(time.Now())),
|
||||
Description: transaction.Description,
|
||||
Data: transaction.Data,
|
||||
}
|
||||
|
||||
return tx.Create(gormTransaction).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *GormRepository) convertGormTransactionToTypesTransaction(gormTx *models.Transaction) *types.Transaction {
|
||||
return &types.Transaction{
|
||||
Id: string(gormTx.ID),
|
||||
OrgId: string(gormTx.OrgID),
|
||||
UserId: string(gormTx.UserID),
|
||||
Date: time.Unix(int64(gormTx.Date), 0),
|
||||
Inserted: util.MsToTime(int64(gormTx.Inserted)),
|
||||
Updated: util.MsToTime(int64(gormTx.Updated)),
|
||||
Description: gormTx.Description,
|
||||
Data: gormTx.Data,
|
||||
Deleted: gormTx.Deleted,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper conversion functions
|
||||
func (r *GormRepository) convertGormOrgToTypesOrg(gormOrg *models.Org) *types.Org {
|
||||
return &types.Org{
|
||||
Id: string(gormOrg.ID),
|
||||
Inserted: util.MsToTime(int64(gormOrg.Inserted)),
|
||||
Updated: util.MsToTime(int64(gormOrg.Updated)),
|
||||
Name: gormOrg.Name,
|
||||
Currency: gormOrg.Currency,
|
||||
Precision: gormOrg.Precision,
|
||||
Timezone: gormOrg.Timezone,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormRepository) convertGormInviteToTypesInvite(gormInvite *models.Invite) *types.Invite {
|
||||
return &types.Invite{
|
||||
Id: string(gormInvite.ID),
|
||||
OrgId: string(gormInvite.OrgID),
|
||||
Inserted: util.MsToTime(int64(gormInvite.Inserted)),
|
||||
Updated: util.MsToTime(int64(gormInvite.Updated)),
|
||||
Email: gormInvite.Email,
|
||||
Accepted: gormInvite.Accepted,
|
||||
}
|
||||
}
|
||||
|
||||
// Stub implementations for remaining interfaces that aren't fully used yet
|
||||
func (r *GormRepository) GetPrices(orgId string, date time.Time) ([]*types.Price, error) {
|
||||
// Stub implementation - add as needed
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) InsertPrice(price *types.Price) error {
|
||||
// Stub implementation - add as needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) Ping() error {
|
||||
// Check if the database connection is alive
|
||||
sqlDB, err := r.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Ping()
|
||||
}
|
||||
|
||||
func (r *GormRepository) InsertBudget(budget *types.Budget) error {
|
||||
// Stub implementation - add as needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GormRepository) GetBudgets(orgId string) ([]*types.Budget, error) {
|
||||
// Stub implementation - add as needed
|
||||
return nil, nil
|
||||
}
|
||||
129
core/server.go
129
core/server.go
@@ -1,48 +1,145 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/api"
|
||||
"github.com/openaccounting/oa-server/core/auth"
|
||||
"github.com/openaccounting/oa-server/core/model"
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"github.com/openaccounting/oa-server/core/repository"
|
||||
"github.com/openaccounting/oa-server/core/util"
|
||||
"github.com/openaccounting/oa-server/database"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
//filename is the path to the json config file
|
||||
// Initialize Viper configuration
|
||||
var config types.Config
|
||||
file, err := os.Open("./config.json")
|
||||
|
||||
|
||||
// Set config file properties
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("json")
|
||||
viper.AddConfigPath(".")
|
||||
viper.AddConfigPath("/etc/openaccounting/")
|
||||
viper.AddConfigPath("$HOME/.openaccounting")
|
||||
|
||||
// Enable environment variables
|
||||
viper.AutomaticEnv()
|
||||
viper.SetEnvPrefix("OA") // will look for OA_DATABASE_PASSWORD, etc.
|
||||
|
||||
// Configure Viper to handle nested config with environment variables
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
|
||||
// Bind specific storage environment variables for better support
|
||||
// Using mapstructure field names (snake_case)
|
||||
viper.BindEnv("Storage.backend", "OA_STORAGE_BACKEND")
|
||||
viper.BindEnv("Storage.local.root_dir", "OA_STORAGE_LOCAL_ROOTDIR")
|
||||
viper.BindEnv("Storage.local.base_url", "OA_STORAGE_LOCAL_BASEURL")
|
||||
viper.BindEnv("Storage.s3.region", "OA_STORAGE_S3_REGION")
|
||||
viper.BindEnv("Storage.s3.bucket", "OA_STORAGE_S3_BUCKET")
|
||||
viper.BindEnv("Storage.s3.prefix", "OA_STORAGE_S3_PREFIX")
|
||||
viper.BindEnv("Storage.s3.access_key_id", "OA_STORAGE_S3_ACCESSKEYID")
|
||||
viper.BindEnv("Storage.s3.secret_access_key", "OA_STORAGE_S3_SECRETACCESSKEY")
|
||||
viper.BindEnv("Storage.s3.endpoint", "OA_STORAGE_S3_ENDPOINT")
|
||||
viper.BindEnv("Storage.s3.path_style", "OA_STORAGE_S3_PATHSTYLE")
|
||||
|
||||
// Set default values
|
||||
viper.SetDefault("Address", "localhost")
|
||||
viper.SetDefault("Port", 8080)
|
||||
viper.SetDefault("DatabaseDriver", "sqlite")
|
||||
viper.SetDefault("DatabaseFile", "./openaccounting.db")
|
||||
viper.SetDefault("ApiPrefix", "/api/v1")
|
||||
|
||||
// Set storage defaults (using mapstructure field names)
|
||||
viper.SetDefault("Storage.backend", "local")
|
||||
viper.SetDefault("Storage.local.root_dir", "./uploads")
|
||||
viper.SetDefault("Storage.local.base_url", "")
|
||||
|
||||
// Read configuration
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to open ./config.json with: %s", err.Error()))
|
||||
log.Printf("Warning: Could not read config file: %v", err)
|
||||
log.Println("Using environment variables and defaults")
|
||||
}
|
||||
|
||||
// Unmarshal config into struct
|
||||
err = viper.Unmarshal(&config)
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to unmarshal config: %s", err.Error()))
|
||||
}
|
||||
|
||||
// Set storage defaults if not configured (Viper doesn't handle nested defaults well)
|
||||
if config.Storage.Backend == "" {
|
||||
config.Storage.Backend = "local"
|
||||
}
|
||||
if config.Storage.Local.RootDir == "" {
|
||||
config.Storage.Local.RootDir = "./uploads"
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
err = decoder.Decode(&config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to decode ./config.json with: %s", err.Error()))
|
||||
// Parse database address (assuming format host:port for MySQL)
|
||||
host := config.DatabaseAddress
|
||||
port := "3306"
|
||||
if len(config.DatabaseAddress) > 0 {
|
||||
// If there's a colon, split host and port
|
||||
if colonIndex := len(config.DatabaseAddress); colonIndex > 0 {
|
||||
host = config.DatabaseAddress
|
||||
}
|
||||
}
|
||||
|
||||
connectionString := config.User + ":" + config.Password + "@" + config.DatabaseAddress + "/" + config.Database
|
||||
// Default to SQLite if no driver specified
|
||||
driver := config.DatabaseDriver
|
||||
if driver == "" {
|
||||
driver = "sqlite"
|
||||
}
|
||||
|
||||
db, err := db.NewDB(connectionString)
|
||||
// Initialize GORM database
|
||||
dbConfig := &database.Config{
|
||||
Driver: driver,
|
||||
Host: host,
|
||||
Port: port,
|
||||
User: config.User,
|
||||
Password: config.Password,
|
||||
DBName: config.Database,
|
||||
File: config.DatabaseFile,
|
||||
SSLMode: "disable", // Adjust as needed
|
||||
}
|
||||
|
||||
err = database.Connect(dbConfig)
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to connect to database with: %s", err.Error()))
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
err = database.AutoMigrate()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to run migrations: %s", err.Error()))
|
||||
}
|
||||
|
||||
err = database.Migrate()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to run custom migrations: %s", err.Error()))
|
||||
}
|
||||
|
||||
bc := &util.StandardBcrypt{}
|
||||
|
||||
model.NewModel(db, bc, config)
|
||||
auth.NewAuthService(db, bc)
|
||||
// Create GORM repository and models
|
||||
gormRepo := repository.NewGormRepository(database.DB)
|
||||
gormModel := model.NewGormModel(database.DB, bc, config)
|
||||
auth.NewGormAuthService(gormRepo, bc)
|
||||
|
||||
// Set the global model instance
|
||||
model.Instance = gormModel
|
||||
|
||||
// Initialize storage backend for attachments
|
||||
err = api.InitializeAttachmentHandler(config.Storage)
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to initialize storage backend: %s", err.Error()))
|
||||
}
|
||||
|
||||
app, err := api.Init(config.ApiPrefix)
|
||||
if err != nil {
|
||||
|
||||
106
core/storage/interface.go
Normal file
106
core/storage/interface.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Storage defines the interface for file storage backends
|
||||
type Storage interface {
|
||||
// Store saves a file and returns the storage path/key
|
||||
Store(filename string, content io.Reader, contentType string) (string, error)
|
||||
|
||||
// Retrieve gets a file by its storage path/key
|
||||
Retrieve(path string) (io.ReadCloser, error)
|
||||
|
||||
// Delete removes a file by its storage path/key
|
||||
Delete(path string) error
|
||||
|
||||
// GetURL returns a URL for accessing the file (may be signed/temporary)
|
||||
GetURL(path string, expiry time.Duration) (string, error)
|
||||
|
||||
// Exists checks if a file exists at the given path
|
||||
Exists(path string) (bool, error)
|
||||
|
||||
// GetMetadata returns file metadata (size, last modified, etc.)
|
||||
GetMetadata(path string) (*FileMetadata, error)
|
||||
}
|
||||
|
||||
// FileMetadata contains information about a stored file
|
||||
type FileMetadata struct {
|
||||
Size int64
|
||||
LastModified time.Time
|
||||
ContentType string
|
||||
ETag string
|
||||
}
|
||||
|
||||
// Config holds configuration for storage backends
|
||||
type Config struct {
|
||||
// Storage backend type: "local", "s3"
|
||||
Backend string `mapstructure:"backend"`
|
||||
|
||||
// Local filesystem configuration
|
||||
Local LocalConfig `mapstructure:"local"`
|
||||
|
||||
// S3-compatible storage configuration (S3, B2, R2, etc.)
|
||||
S3 S3Config `mapstructure:"s3"`
|
||||
}
|
||||
|
||||
// LocalConfig configures local filesystem storage
|
||||
type LocalConfig struct {
|
||||
// Root directory for file storage
|
||||
RootDir string `mapstructure:"root_dir"`
|
||||
|
||||
// Base URL for serving files (optional)
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
}
|
||||
|
||||
// S3Config configures S3-compatible storage (AWS S3, Backblaze B2, Cloudflare R2, etc.)
|
||||
type S3Config struct {
|
||||
// AWS Region (use "auto" for Cloudflare R2)
|
||||
Region string `mapstructure:"region"`
|
||||
|
||||
// S3 Bucket name
|
||||
Bucket string `mapstructure:"bucket"`
|
||||
|
||||
// Optional prefix for all objects
|
||||
Prefix string `mapstructure:"prefix"`
|
||||
|
||||
// Access Key ID
|
||||
AccessKeyID string `mapstructure:"access_key_id"`
|
||||
|
||||
// Secret Access Key
|
||||
SecretAccessKey string `mapstructure:"secret_access_key"`
|
||||
|
||||
// Custom endpoint URL for S3-compatible services:
|
||||
// - Backblaze B2: https://s3.us-west-004.backblazeb2.com
|
||||
// - Cloudflare R2: https://<account-id>.r2.cloudflarestorage.com
|
||||
// - MinIO: http://localhost:9000
|
||||
// Leave empty for AWS S3
|
||||
Endpoint string `mapstructure:"endpoint"`
|
||||
|
||||
// Use path-style addressing (required for some S3-compatible services)
|
||||
PathStyle bool `mapstructure:"path_style"`
|
||||
}
|
||||
|
||||
|
||||
// NewStorage creates a new storage backend based on configuration
|
||||
func NewStorage(config Config) (Storage, error) {
|
||||
switch config.Backend {
|
||||
case "local", "":
|
||||
return NewLocalStorage(config.Local)
|
||||
case "s3":
|
||||
return NewS3Storage(config.S3)
|
||||
default:
|
||||
return nil, &UnsupportedBackendError{Backend: config.Backend}
|
||||
}
|
||||
}
|
||||
|
||||
// UnsupportedBackendError is returned when an unknown storage backend is requested
|
||||
type UnsupportedBackendError struct {
|
||||
Backend string
|
||||
}
|
||||
|
||||
func (e *UnsupportedBackendError) Error() string {
|
||||
return "unsupported storage backend: " + e.Backend
|
||||
}
|
||||
101
core/storage/interface_test.go
Normal file
101
core/storage/interface_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewStorage(t *testing.T) {
|
||||
t.Run("Local Storage", func(t *testing.T) {
|
||||
config := Config{
|
||||
Backend: "local",
|
||||
Local: LocalConfig{
|
||||
RootDir: t.TempDir(),
|
||||
},
|
||||
}
|
||||
|
||||
storage, err := NewStorage(config)
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, &LocalStorage{}, storage)
|
||||
})
|
||||
|
||||
t.Run("Default to Local Storage", func(t *testing.T) {
|
||||
config := Config{
|
||||
// No backend specified
|
||||
Local: LocalConfig{
|
||||
RootDir: t.TempDir(),
|
||||
},
|
||||
}
|
||||
|
||||
storage, err := NewStorage(config)
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, &LocalStorage{}, storage)
|
||||
})
|
||||
|
||||
t.Run("S3 Storage", func(t *testing.T) {
|
||||
config := Config{
|
||||
Backend: "s3",
|
||||
S3: S3Config{
|
||||
Region: "us-east-1",
|
||||
Bucket: "test-bucket",
|
||||
},
|
||||
}
|
||||
|
||||
// This might succeed if AWS credentials are available via IAM roles or env vars
|
||||
// Let's just check that we get an S3Storage instance or an error
|
||||
storage, err := NewStorage(config)
|
||||
if err != nil {
|
||||
// If it fails, that's expected in test environments without AWS access
|
||||
assert.Nil(t, storage)
|
||||
} else {
|
||||
// If it succeeds, we should get an S3Storage instance
|
||||
assert.IsType(t, &S3Storage{}, storage)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("B2 Storage", func(t *testing.T) {
|
||||
config := Config{
|
||||
Backend: "b2",
|
||||
B2: B2Config{
|
||||
AccountID: "test-account",
|
||||
ApplicationKey: "test-key",
|
||||
Bucket: "test-bucket",
|
||||
},
|
||||
}
|
||||
|
||||
// This will fail because we don't have real B2 credentials
|
||||
storage, err := NewStorage(config)
|
||||
assert.Error(t, err) // Expected to fail without credentials
|
||||
assert.Nil(t, storage)
|
||||
})
|
||||
|
||||
t.Run("Unsupported Backend", func(t *testing.T) {
|
||||
config := Config{
|
||||
Backend: "unsupported",
|
||||
}
|
||||
|
||||
storage, err := NewStorage(config)
|
||||
assert.Error(t, err)
|
||||
assert.IsType(t, &UnsupportedBackendError{}, err)
|
||||
assert.Nil(t, storage)
|
||||
assert.Contains(t, err.Error(), "unsupported")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorageErrors(t *testing.T) {
|
||||
t.Run("UnsupportedBackendError", func(t *testing.T) {
|
||||
err := &UnsupportedBackendError{Backend: "ftp"}
|
||||
assert.Equal(t, "unsupported storage backend: ftp", err.Error())
|
||||
})
|
||||
|
||||
t.Run("FileNotFoundError", func(t *testing.T) {
|
||||
err := &FileNotFoundError{Path: "missing.txt"}
|
||||
assert.Equal(t, "file not found: missing.txt", err.Error())
|
||||
})
|
||||
|
||||
t.Run("InvalidPathError", func(t *testing.T) {
|
||||
err := &InvalidPathError{Path: "../../../etc/passwd"}
|
||||
assert.Equal(t, "invalid path: ../../../etc/passwd", err.Error())
|
||||
})
|
||||
}
|
||||
243
core/storage/local.go
Normal file
243
core/storage/local.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
)
|
||||
|
||||
// LocalStorage implements the Storage interface for local filesystem
|
||||
type LocalStorage struct {
|
||||
rootDir string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewLocalStorage creates a new local filesystem storage backend
|
||||
func NewLocalStorage(config LocalConfig) (*LocalStorage, error) {
|
||||
rootDir := config.RootDir
|
||||
if rootDir == "" {
|
||||
rootDir = "./uploads"
|
||||
}
|
||||
|
||||
// Ensure the root directory exists
|
||||
if err := os.MkdirAll(rootDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage directory: %w", err)
|
||||
}
|
||||
|
||||
return &LocalStorage{
|
||||
rootDir: rootDir,
|
||||
baseURL: config.BaseURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Store saves a file to the local filesystem
|
||||
func (l *LocalStorage) Store(filename string, content io.Reader, contentType string) (string, error) {
|
||||
// Generate a unique storage path
|
||||
storagePath := l.generateStoragePath(filename)
|
||||
fullPath := filepath.Join(l.rootDir, storagePath)
|
||||
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Create and write the file
|
||||
file, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, content)
|
||||
if err != nil {
|
||||
// Clean up the file if write failed
|
||||
os.Remove(fullPath)
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
return storagePath, nil
|
||||
}
|
||||
|
||||
// Retrieve gets a file from the local filesystem
|
||||
func (l *LocalStorage) Retrieve(path string) (io.ReadCloser, error) {
|
||||
// Validate path to prevent directory traversal
|
||||
if err := l.validatePath(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(l.rootDir, path)
|
||||
file, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, &FileNotFoundError{Path: path}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// Delete removes a file from the local filesystem
|
||||
func (l *LocalStorage) Delete(path string) error {
|
||||
// Validate path to prevent directory traversal
|
||||
if err := l.validatePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(l.rootDir, path)
|
||||
err := os.Remove(fullPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
}
|
||||
|
||||
// Try to remove empty parent directories
|
||||
l.cleanupEmptyDirs(filepath.Dir(fullPath))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetURL returns a URL for accessing the file
|
||||
func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error) {
|
||||
// Validate path to prevent directory traversal
|
||||
if err := l.validatePath(path); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
exists, err := l.Exists(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", &FileNotFoundError{Path: path}
|
||||
}
|
||||
|
||||
if l.baseURL != "" {
|
||||
// Return a public URL if base URL is configured
|
||||
return l.baseURL + "/" + path, nil
|
||||
}
|
||||
|
||||
// For local storage without a base URL, return the file path
|
||||
// In a real application, you might serve these through an endpoint
|
||||
return "/files/" + path, nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists at the given path
|
||||
func (l *LocalStorage) Exists(path string) (bool, error) {
|
||||
// Validate path to prevent directory traversal
|
||||
if err := l.validatePath(path); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(l.rootDir, path)
|
||||
_, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check file existence: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns file metadata
|
||||
func (l *LocalStorage) GetMetadata(path string) (*FileMetadata, error) {
|
||||
// Validate path to prevent directory traversal
|
||||
if err := l.validatePath(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(l.rootDir, path)
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, &FileNotFoundError{Path: path}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get file metadata: %w", err)
|
||||
}
|
||||
|
||||
return &FileMetadata{
|
||||
Size: info.Size(),
|
||||
LastModified: info.ModTime(),
|
||||
ContentType: "", // Local storage doesn't store content type
|
||||
ETag: "", // Local storage doesn't have ETags
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateStoragePath creates a unique storage path for a file
|
||||
func (l *LocalStorage) generateStoragePath(filename string) string {
|
||||
// Generate a unique ID for the file
|
||||
fileID := id.String(id.New())
|
||||
|
||||
// Extract file extension
|
||||
ext := filepath.Ext(filename)
|
||||
|
||||
// Create a path structure: YYYY/MM/DD/uuid.ext
|
||||
now := time.Now()
|
||||
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
|
||||
|
||||
return filepath.Join(datePath, fileID+ext)
|
||||
}
|
||||
|
||||
// validatePath ensures the path doesn't contain directory traversal attempts
|
||||
func (l *LocalStorage) validatePath(path string) error {
|
||||
// Clean the path and check for traversal attempts
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Reject paths that try to go up directories
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return &InvalidPathError{Path: path}
|
||||
}
|
||||
|
||||
// Reject absolute paths
|
||||
if filepath.IsAbs(cleanPath) {
|
||||
return &InvalidPathError{Path: path}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupEmptyDirs removes empty parent directories up to the root
|
||||
func (l *LocalStorage) cleanupEmptyDirs(dir string) {
|
||||
// Don't remove the root directory
|
||||
if dir == l.rootDir {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if directory is empty
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil || len(entries) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Remove empty directory
|
||||
if err := os.Remove(dir); err == nil {
|
||||
// Recursively clean parent directories
|
||||
l.cleanupEmptyDirs(filepath.Dir(dir))
|
||||
}
|
||||
}
|
||||
|
||||
// FileNotFoundError is returned when a file doesn't exist
|
||||
type FileNotFoundError struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
func (e *FileNotFoundError) Error() string {
|
||||
return "file not found: " + e.Path
|
||||
}
|
||||
|
||||
// InvalidPathError is returned when a path is invalid or contains traversal attempts
|
||||
type InvalidPathError struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
func (e *InvalidPathError) Error() string {
|
||||
return "invalid path: " + e.Path
|
||||
}
|
||||
202
core/storage/local_test.go
Normal file
202
core/storage/local_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLocalStorage(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
config := LocalConfig{
|
||||
RootDir: tmpDir,
|
||||
BaseURL: "http://localhost:8080/files",
|
||||
}
|
||||
|
||||
storage, err := NewLocalStorage(config)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, storage)
|
||||
|
||||
t.Run("Store and Retrieve File", func(t *testing.T) {
|
||||
content := []byte("test file content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Store file
|
||||
path, err := storage.Store("test.txt", reader, "text/plain")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, path)
|
||||
|
||||
// Verify file exists
|
||||
exists, err := storage.Exists(path)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Retrieve file
|
||||
retrievedReader, err := storage.Retrieve(path)
|
||||
assert.NoError(t, err)
|
||||
defer retrievedReader.Close()
|
||||
|
||||
retrievedContent, err := io.ReadAll(retrievedReader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, content, retrievedContent)
|
||||
})
|
||||
|
||||
t.Run("Get File Metadata", func(t *testing.T) {
|
||||
content := []byte("metadata test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
path, err := storage.Store("metadata.txt", reader, "text/plain")
|
||||
assert.NoError(t, err)
|
||||
|
||||
metadata, err := storage.GetMetadata(path)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(len(content)), metadata.Size)
|
||||
assert.False(t, metadata.LastModified.IsZero())
|
||||
})
|
||||
|
||||
t.Run("Get File URL", func(t *testing.T) {
|
||||
content := []byte("url test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
path, err := storage.Store("url.txt", reader, "text/plain")
|
||||
assert.NoError(t, err)
|
||||
|
||||
url, err := storage.GetURL(path, time.Hour)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, url, path)
|
||||
assert.Contains(t, url, config.BaseURL)
|
||||
})
|
||||
|
||||
t.Run("Delete File", func(t *testing.T) {
|
||||
content := []byte("delete test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
path, err := storage.Store("delete.txt", reader, "text/plain")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify file exists
|
||||
exists, err := storage.Exists(path)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete file
|
||||
err = storage.Delete(path)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify file no longer exists
|
||||
exists, err = storage.Exists(path)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Path Validation", func(t *testing.T) {
|
||||
// Test directory traversal prevention
|
||||
_, err := storage.Retrieve("../../../etc/passwd")
|
||||
assert.Error(t, err)
|
||||
assert.IsType(t, &InvalidPathError{}, err)
|
||||
|
||||
// Test absolute path rejection
|
||||
_, err = storage.Retrieve("/etc/passwd")
|
||||
assert.Error(t, err)
|
||||
assert.IsType(t, &InvalidPathError{}, err)
|
||||
})
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
_, err := storage.Retrieve("nonexistent.txt")
|
||||
assert.Error(t, err)
|
||||
assert.IsType(t, &FileNotFoundError{}, err)
|
||||
|
||||
_, err = storage.GetMetadata("nonexistent.txt")
|
||||
assert.Error(t, err)
|
||||
assert.IsType(t, &FileNotFoundError{}, err)
|
||||
|
||||
_, err = storage.GetURL("nonexistent.txt", time.Hour)
|
||||
assert.Error(t, err)
|
||||
assert.IsType(t, &FileNotFoundError{}, err)
|
||||
})
|
||||
|
||||
t.Run("Storage Path Generation", func(t *testing.T) {
|
||||
content := []byte("path test content")
|
||||
reader1 := bytes.NewReader(content)
|
||||
reader2 := bytes.NewReader(content)
|
||||
|
||||
// Store two files with same name
|
||||
path1, err := storage.Store("same.txt", reader1, "text/plain")
|
||||
assert.NoError(t, err)
|
||||
|
||||
path2, err := storage.Store("same.txt", reader2, "text/plain")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Paths should be different (unique)
|
||||
assert.NotEqual(t, path1, path2)
|
||||
|
||||
// Both should exist
|
||||
exists1, err := storage.Exists(path1)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists1)
|
||||
|
||||
exists2, err := storage.Exists(path2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists2)
|
||||
|
||||
// Both should have correct extension
|
||||
assert.True(t, strings.HasSuffix(path1, ".txt"))
|
||||
assert.True(t, strings.HasSuffix(path2, ".txt"))
|
||||
|
||||
// Should be organized by date
|
||||
now := time.Now()
|
||||
expectedPrefix := filepath.Join(
|
||||
fmt.Sprintf("%04d", now.Year()),
|
||||
fmt.Sprintf("%02d", now.Month()),
|
||||
fmt.Sprintf("%02d", now.Day()),
|
||||
)
|
||||
assert.True(t, strings.HasPrefix(path1, expectedPrefix))
|
||||
assert.True(t, strings.HasPrefix(path2, expectedPrefix))
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalStorageConfig(t *testing.T) {
|
||||
t.Run("Default Root Directory", func(t *testing.T) {
|
||||
config := LocalConfig{} // Empty config
|
||||
|
||||
storage, err := NewLocalStorage(config)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, storage)
|
||||
|
||||
// Should create default uploads directory
|
||||
assert.Equal(t, "./uploads", storage.rootDir)
|
||||
|
||||
// Verify directory was created
|
||||
_, err = os.Stat("./uploads")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Clean up
|
||||
os.RemoveAll("./uploads")
|
||||
})
|
||||
|
||||
t.Run("Custom Root Directory", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
customDir := filepath.Join(tmpDir, "custom", "storage")
|
||||
|
||||
config := LocalConfig{
|
||||
RootDir: customDir,
|
||||
}
|
||||
|
||||
storage, err := NewLocalStorage(config)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, customDir, storage.rootDir)
|
||||
|
||||
// Verify custom directory was created
|
||||
_, err = os.Stat(customDir)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
236
core/storage/s3.go
Normal file
236
core/storage/s3.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
)
|
||||
|
||||
// S3Storage implements the Storage interface for Amazon S3
|
||||
type S3Storage struct {
|
||||
client *s3.S3
|
||||
uploader *s3manager.Uploader
|
||||
bucket string
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewS3Storage creates a new S3 storage backend
|
||||
func NewS3Storage(config S3Config) (*S3Storage, error) {
|
||||
if config.Bucket == "" {
|
||||
return nil, fmt.Errorf("S3 bucket name is required")
|
||||
}
|
||||
|
||||
// Create AWS config
|
||||
awsConfig := &aws.Config{
|
||||
Region: aws.String(config.Region),
|
||||
}
|
||||
|
||||
// Set custom endpoint if provided (for S3-compatible services)
|
||||
if config.Endpoint != "" {
|
||||
awsConfig.Endpoint = aws.String(config.Endpoint)
|
||||
awsConfig.S3ForcePathStyle = aws.Bool(config.PathStyle)
|
||||
}
|
||||
|
||||
// Set credentials if provided
|
||||
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
|
||||
awsConfig.Credentials = credentials.NewStaticCredentials(
|
||||
config.AccessKeyID,
|
||||
config.SecretAccessKey,
|
||||
"",
|
||||
)
|
||||
}
|
||||
|
||||
// Create session
|
||||
sess, err := session.NewSession(awsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AWS session: %w", err)
|
||||
}
|
||||
|
||||
// Create S3 client
|
||||
client := s3.New(sess)
|
||||
uploader := s3manager.NewUploader(sess)
|
||||
|
||||
return &S3Storage{
|
||||
client: client,
|
||||
uploader: uploader,
|
||||
bucket: config.Bucket,
|
||||
prefix: config.Prefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Store saves a file to S3
|
||||
func (s *S3Storage) Store(filename string, content io.Reader, contentType string) (string, error) {
|
||||
// Generate a unique storage key
|
||||
storageKey := s.generateStorageKey(filename)
|
||||
|
||||
// Prepare upload input
|
||||
input := &s3manager.UploadInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(storageKey),
|
||||
Body: content,
|
||||
}
|
||||
|
||||
// Set content type if provided
|
||||
if contentType != "" {
|
||||
input.ContentType = aws.String(contentType)
|
||||
}
|
||||
|
||||
// Upload the file
|
||||
_, err := s.uploader.Upload(input)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload file to S3: %w", err)
|
||||
}
|
||||
|
||||
return storageKey, nil
|
||||
}
|
||||
|
||||
// Retrieve gets a file from S3
|
||||
func (s *S3Storage) Retrieve(path string) (io.ReadCloser, error) {
|
||||
input := &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
}
|
||||
|
||||
result, err := s.client.GetObject(input)
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
switch aerr.Code() {
|
||||
case s3.ErrCodeNoSuchKey:
|
||||
return nil, &FileNotFoundError{Path: path}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to retrieve file from S3: %w", err)
|
||||
}
|
||||
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
// Delete removes a file from S3
|
||||
func (s *S3Storage) Delete(path string) error {
|
||||
input := &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
}
|
||||
|
||||
_, err := s.client.DeleteObject(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete file from S3: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetURL returns a presigned URL for accessing the file
|
||||
func (s *S3Storage) GetURL(path string, expiry time.Duration) (string, error) {
|
||||
// Check if file exists first
|
||||
exists, err := s.Exists(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", &FileNotFoundError{Path: path}
|
||||
}
|
||||
|
||||
// Generate presigned URL
|
||||
req, _ := s.client.GetObjectRequest(&s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
})
|
||||
|
||||
url, err := req.Presign(expiry)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate presigned URL: %w", err)
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists in S3
|
||||
func (s *S3Storage) Exists(path string) (bool, error) {
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
}
|
||||
|
||||
_, err := s.client.HeadObject(input)
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
switch aerr.Code() {
|
||||
case s3.ErrCodeNoSuchKey, "NotFound":
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns file metadata from S3
|
||||
func (s *S3Storage) GetMetadata(path string) (*FileMetadata, error) {
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
}
|
||||
|
||||
result, err := s.client.HeadObject(input)
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
switch aerr.Code() {
|
||||
case s3.ErrCodeNoSuchKey, "NotFound":
|
||||
return nil, &FileNotFoundError{Path: path}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get file metadata from S3: %w", err)
|
||||
}
|
||||
|
||||
metadata := &FileMetadata{
|
||||
Size: aws.Int64Value(result.ContentLength),
|
||||
}
|
||||
|
||||
if result.LastModified != nil {
|
||||
metadata.LastModified = *result.LastModified
|
||||
}
|
||||
|
||||
if result.ContentType != nil {
|
||||
metadata.ContentType = *result.ContentType
|
||||
}
|
||||
|
||||
if result.ETag != nil {
|
||||
metadata.ETag = strings.Trim(*result.ETag, "\"")
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// generateStorageKey creates a unique storage key for a file
|
||||
func (s *S3Storage) generateStorageKey(filename string) string {
|
||||
// Generate a unique ID for the file
|
||||
fileID := id.String(id.New())
|
||||
|
||||
// Extract file extension
|
||||
ext := path.Ext(filename)
|
||||
|
||||
// Create a key structure: prefix/YYYY/MM/DD/uuid.ext
|
||||
now := time.Now()
|
||||
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
|
||||
|
||||
key := path.Join(datePath, fileID+ext)
|
||||
|
||||
// Add prefix if configured
|
||||
if s.prefix != "" {
|
||||
key = path.Join(s.prefix, key)
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
48
core/util/id/id.go
Normal file
48
core/util/id/id.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// New creates a new binary ID (16 bytes) using UUID v4
|
||||
func New() []byte {
|
||||
u := uuid.New()
|
||||
return u[:]
|
||||
}
|
||||
|
||||
// ToUUID converts a binary ID back to UUID
|
||||
func ToUUID(b []byte) (uuid.UUID, error) {
|
||||
return uuid.FromBytes(b)
|
||||
}
|
||||
|
||||
// String returns the string representation of a binary ID
|
||||
func String(b []byte) string {
|
||||
u, err := uuid.FromBytes(b)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// FromString parses a string UUID into binary format
|
||||
func FromString(s string) ([]byte, error) {
|
||||
u, err := uuid.Parse(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return u[:], nil
|
||||
}
|
||||
|
||||
// Uint64ToBytes converts a uint64 to 8-byte slice
|
||||
func Uint64ToBytes(v uint64) []byte {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, v)
|
||||
return b
|
||||
}
|
||||
|
||||
// BytesToUint64 converts 8-byte slice to uint64
|
||||
func BytesToUint64(b []byte) uint64 {
|
||||
return binary.BigEndian.Uint64(b)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package util
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -44,3 +45,23 @@ func NewInviteId() (string, error) {
|
||||
|
||||
return hex.EncodeToString(byteArray), nil
|
||||
}
|
||||
|
||||
func NewUUID() string {
|
||||
guid, err := NewGuid()
|
||||
if err != nil {
|
||||
// Fallback to timestamp-based UUID if random generation fails
|
||||
return hex.EncodeToString([]byte(time.Now().Format("20060102150405")))
|
||||
}
|
||||
return guid
|
||||
}
|
||||
|
||||
func IsValidUUID(uuid string) bool {
|
||||
// Check if the string is a valid 32-character hex string (16 bytes * 2 hex chars)
|
||||
if len(uuid) != 32 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if all characters are valid hex characters
|
||||
matched, _ := regexp.MatchString("^[0-9a-f]{32}$", uuid)
|
||||
return matched
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func Handler(w rest.ResponseWriter, r *rest.Request) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("recv: %s", message)
|
||||
log.Printf("recv: %+v", message)
|
||||
|
||||
// check version
|
||||
err = checkVersion(message.Version)
|
||||
|
||||
337
database/database.go
Normal file
337
database/database.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
"github.com/openaccounting/oa-server/models"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
type Config struct {
|
||||
Driver string // "mysql" or "sqlite"
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
// SQLite specific
|
||||
File string // SQLite database file path
|
||||
}
|
||||
|
||||
func Connect(config *Config) error {
|
||||
// Configure GORM logger
|
||||
newLogger := logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: true,
|
||||
},
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
|
||||
// Choose driver based on config
|
||||
switch config.Driver {
|
||||
case "sqlite":
|
||||
// Use SQLite
|
||||
dbFile := config.File
|
||||
if dbFile == "" {
|
||||
dbFile = "./openaccounting.db" // Default SQLite file
|
||||
}
|
||||
db, err = gorm.Open(sqlite.Open(dbFile), &gorm.Config{
|
||||
Logger: newLogger,
|
||||
})
|
||||
case "mysql":
|
||||
// Use MySQL (existing logic)
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, config.Password, config.Host, config.Port, config.DBName)
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: newLogger,
|
||||
})
|
||||
default:
|
||||
return fmt.Errorf("unsupported database driver: %s (supported: mysql, sqlite)", config.Driver)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool (only for MySQL, SQLite handles this internally)
|
||||
if config.Driver == "mysql" {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database instance: %w", err)
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(25)
|
||||
sqlDB.SetMaxIdleConns(25)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
}
|
||||
|
||||
DB = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// AutoMigrate runs automatic migrations for all models
|
||||
func AutoMigrate() error {
|
||||
return DB.AutoMigrate(
|
||||
&models.Org{},
|
||||
&models.User{},
|
||||
&models.UserOrg{},
|
||||
&models.Token{},
|
||||
&models.Account{},
|
||||
&models.Transaction{},
|
||||
&models.Split{},
|
||||
&models.Balance{},
|
||||
&models.Permission{},
|
||||
&models.Price{},
|
||||
&models.Session{},
|
||||
&models.APIKey{},
|
||||
&models.Invite{},
|
||||
&models.BudgetItem{},
|
||||
&models.Attachment{},
|
||||
)
|
||||
}
|
||||
|
||||
// Migrate runs custom migrations
|
||||
func Migrate() error {
|
||||
// Create indexes
|
||||
if err := createIndexes(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert default data - temporarily disabled for testing
|
||||
// if err := seedDefaultData(); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createIndexes() error {
|
||||
// Create custom indexes that GORM doesn't handle automatically
|
||||
// Based on original indexes.sql file
|
||||
indexes := []string{
|
||||
// Original indexes from indexes.sql
|
||||
"CREATE INDEX IF NOT EXISTS account_orgId_index ON accounts(orgId)",
|
||||
"CREATE INDEX IF NOT EXISTS split_accountId_index ON splits(accountId)",
|
||||
"CREATE INDEX IF NOT EXISTS split_transactionId_index ON splits(transactionId)",
|
||||
"CREATE INDEX IF NOT EXISTS split_date_index ON splits(date)",
|
||||
"CREATE INDEX IF NOT EXISTS split_updated_index ON splits(updated)",
|
||||
"CREATE INDEX IF NOT EXISTS budgetitem_orgId_index ON budget_items(orgId)",
|
||||
"CREATE INDEX IF NOT EXISTS attachment_transactionId_index ON attachment(transactionId)",
|
||||
"CREATE INDEX IF NOT EXISTS attachment_orgId_index ON attachment(orgId)",
|
||||
"CREATE INDEX IF NOT EXISTS attachment_userId_index ON attachment(userId)",
|
||||
"CREATE INDEX IF NOT EXISTS attachment_uploaded_index ON attachment(uploaded)",
|
||||
|
||||
// Additional useful indexes for performance
|
||||
"CREATE INDEX IF NOT EXISTS idx_transaction_date ON transactions(date)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_transaction_org ON transactions(orgId)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_account_parent ON accounts(parent)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_userorg_user ON user_orgs(userId)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_userorg_org ON user_orgs(orgId)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_balance_account_date ON balances(accountId, date)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_permission_org_account ON permissions(orgId, accountId)",
|
||||
}
|
||||
|
||||
for _, idx := range indexes {
|
||||
if err := DB.Exec(idx).Error; err != nil {
|
||||
return fmt.Errorf("failed to create index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedDefaultData() error {
|
||||
// Check if we already have data
|
||||
var count int64
|
||||
if err := DB.Model(&models.Org{}).Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil // Data already exists
|
||||
}
|
||||
|
||||
// Create a default organization
|
||||
defaultOrg := models.Org{
|
||||
ID: id.New(), // You'll need to implement this
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Default Organization",
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
Timezone: "UTC",
|
||||
}
|
||||
|
||||
if err := DB.Create(&defaultOrg).Error; err != nil {
|
||||
return fmt.Errorf("failed to create default organization: %w", err)
|
||||
}
|
||||
|
||||
// Create default accounts for the organization
|
||||
defaultAccounts := []models.Account{
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Assets",
|
||||
Parent: []byte{0}, // Root account has zero parent
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: true,
|
||||
},
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Liabilities",
|
||||
Parent: []byte{0},
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: false,
|
||||
},
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Equity",
|
||||
Parent: []byte{0},
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: false,
|
||||
},
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Revenue",
|
||||
Parent: []byte{0},
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: false,
|
||||
},
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Expenses",
|
||||
Parent: []byte{0},
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Create accounts and store their IDs for parent-child relationships
|
||||
accountMap := make(map[string]*models.Account)
|
||||
|
||||
for _, acc := range defaultAccounts {
|
||||
account := acc
|
||||
if err := DB.Create(&account).Error; err != nil {
|
||||
return fmt.Errorf("failed to create account %s: %w", acc.Name, err)
|
||||
}
|
||||
accountMap[acc.Name] = &account
|
||||
}
|
||||
|
||||
// Create Current Assets first
|
||||
var assetsParent []byte
|
||||
if assetsAccount, exists := accountMap["Assets"]; exists {
|
||||
assetsParent = assetsAccount.ID
|
||||
} else {
|
||||
return fmt.Errorf("Assets account not found in accountMap")
|
||||
}
|
||||
|
||||
currentAssets := models.Account{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Current Assets",
|
||||
Parent: assetsParent,
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: true,
|
||||
}
|
||||
|
||||
if err := DB.Create(¤tAssets).Error; err != nil {
|
||||
return fmt.Errorf("failed to create Current Assets: %w", err)
|
||||
}
|
||||
accountMap["Current Assets"] = ¤tAssets
|
||||
|
||||
// Create Accounts Payable
|
||||
var liabilitiesParent []byte
|
||||
if liabilitiesAccount, exists := accountMap["Liabilities"]; exists {
|
||||
liabilitiesParent = liabilitiesAccount.ID
|
||||
} else {
|
||||
return fmt.Errorf("Liabilities account not found in accountMap")
|
||||
}
|
||||
|
||||
accountsPayable := models.Account{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Accounts Payable",
|
||||
Parent: liabilitiesParent,
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: false,
|
||||
}
|
||||
|
||||
if err := DB.Create(&accountsPayable).Error; err != nil {
|
||||
return fmt.Errorf("failed to create Accounts Payable: %w", err)
|
||||
}
|
||||
|
||||
// Now create sub-accounts under Current Assets
|
||||
subAccounts := []models.Account{
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Cash",
|
||||
Parent: currentAssets.ID,
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: true,
|
||||
},
|
||||
{
|
||||
ID: id.New(),
|
||||
OrgID: defaultOrg.ID,
|
||||
Inserted: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Updated: uint64(time.Now().UnixNano() / int64(time.Millisecond)),
|
||||
Name: "Accounts Receivable",
|
||||
Parent: currentAssets.ID,
|
||||
Currency: "USD",
|
||||
Precision: 2,
|
||||
DebitBalance: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, acc := range subAccounts {
|
||||
if err := DB.Create(&acc).Error; err != nil {
|
||||
return fmt.Errorf("failed to create sub-account %s: %w", acc.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
51
go.mod
51
go.mod
@@ -1,16 +1,53 @@
|
||||
module github.com/openaccounting/oa-server
|
||||
|
||||
go 1.24.2
|
||||
|
||||
require (
|
||||
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf
|
||||
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01
|
||||
github.com/go-sql-driver/mysql v1.4.1
|
||||
github.com/aws/aws-sdk-go v1.44.0
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da
|
||||
github.com/mailgun/mailgun-go/v4 v4.3.0
|
||||
github.com/mitchellh/mapstructure v0.0.0-20180511142126-bb74f1db0675
|
||||
github.com/sendgrid/rest v0.0.0-20180905234047-875828e14d98 // indirect
|
||||
github.com/sendgrid/sendgrid-go v0.0.0-20180905233524-8cb43f4ca4f5 // indirect
|
||||
github.com/stretchr/objx v0.3.0 // indirect
|
||||
github.com/stretchr/testify v1.3.0
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
github.com/spf13/viper v1.20.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/crypto v0.32.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/sys v0.29.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-chi/chi v4.0.0+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.10 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/gorm v1.30.0
|
||||
)
|
||||
|
||||
113
go.sum
113
go.sum
@@ -1,7 +1,11 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf h1:BMUJnVJI5J506LOcyGHEvbCocMHAmKTRcG6CMAwGFYU=
|
||||
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
||||
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01 h1:oYAjCHMjyRaNBo3nUEepDce4LC+Kuh+6jU6y+AllvnU=
|
||||
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01/go.mod h1:q6aCt0GfU6LhpBsnZ/2U+mwe+0XB5WStbmwyoPfc+sk=
|
||||
github.com/aws/aws-sdk-go v1.44.0 h1:jwtHuNqfnJxL4DKHBUVUmQlfueQqBW7oXP6yebZR/R0=
|
||||
github.com/aws/aws-sdk-go v1.44.0/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -11,53 +15,104 @@ github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 h1:JWuenKqqX8nojt
|
||||
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052/go.mod h1:UbMTZqLaRiH3MsBH8va0n7s1pQYcu3uTb8G4tygF4Zg=
|
||||
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQDg5gKsWoLBOB0n+ZW8s599zru8FJ2/Y=
|
||||
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-chi/chi v4.0.0+incompatible h1:SiLLEDyAkqNnw+T/uDTf3aFB9T4FTrwMpuYrgaRcnW4=
|
||||
github.com/go-chi/chi v4.0.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
|
||||
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
|
||||
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da h1:b5fma7aUP2fn6+tdKKCJ0TxXYzY/5wDiqUxNdyi5VF4=
|
||||
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
|
||||
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
|
||||
github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68=
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mailgun/mailgun-go/v4 v4.3.0 h1:9nAF7LI3k6bfDPbMZQMMl63Q8/vs+dr1FUN8eR1XMhk=
|
||||
github.com/mailgun/mailgun-go/v4 v4.3.0/go.mod h1:fWuBI2iaS/pSSyo6+EBpHjatQO3lV8onwqcRy7joSJI=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mitchellh/mapstructure v0.0.0-20180511142126-bb74f1db0675 h1:/rdJjIiKG5rRdwG5yxHmSE/7ZREjpyC0kL7GxGT/qJw=
|
||||
github.com/mitchellh/mapstructure v0.0.0-20180511142126-bb74f1db0675/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sendgrid/rest v0.0.0-20180905234047-875828e14d98 h1:wpBZ5DAYLNl+2v4E4WP8k/y8tM5OjIf1FezJS1qX8sU=
|
||||
github.com/sendgrid/rest v0.0.0-20180905234047-875828e14d98/go.mod h1:kXX7q3jZtJXK5c5qK83bSGMdV6tsOE70KbHoqJls4lE=
|
||||
github.com/sendgrid/sendgrid-go v0.0.0-20180905233524-8cb43f4ca4f5 h1:V18LU+jSbihmDiWfLSzs9FV1d3KVB1gRTkNxgVHmcvg=
|
||||
github.com/sendgrid/sendgrid-go v0.0.0-20180905233524-8cb43f4ca4f5/go.mod h1:QRQt+LX/NmgVEvmdRw0VT/QgUn499+iza2FnDca9fg8=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
|
||||
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
|
||||
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
|
||||
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As=
|
||||
github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
golang.org/x/crypto v0.0.0-20171231215028-0fcca4842a8d h1:GrqEEc3+MtHKTsZrdIGVoYDgLpbSRzW1EF+nLu0PcHE=
|
||||
golang.org/x/crypto v0.0.0-20171231215028-0fcca4842a8d/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
|
||||
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
|
||||
@@ -3,4 +3,8 @@ CREATE INDEX split_accountId_index ON split (accountId);
|
||||
CREATE INDEX split_transactionId_index ON split (transactionId);
|
||||
CREATE INDEX split_date_index ON split (date);
|
||||
CREATE INDEX split_updated_index ON split (updated);
|
||||
CREATE INDEX budgetitem_orgId_index ON budgetitem (orgId);
|
||||
CREATE INDEX budgetitem_orgId_index ON budgetitem (orgId);
|
||||
CREATE INDEX attachment_transactionId_index ON attachment (transactionId);
|
||||
CREATE INDEX attachment_orgId_index ON attachment (orgId);
|
||||
CREATE INDEX attachment_userId_index ON attachment (userId);
|
||||
CREATE INDEX attachment_uploaded_index ON attachment (uploaded);
|
||||
166
justfile
Normal file
166
justfile
Normal file
@@ -0,0 +1,166 @@
|
||||
# OpenAccounting Server - Just recipes
|
||||
# https://github.com/casey/just
|
||||
|
||||
# Default recipe
|
||||
default:
|
||||
@just --list
|
||||
|
||||
# Variables
|
||||
image_name := "openaccounting-server"
|
||||
tag := "latest"
|
||||
|
||||
# Build the Go application
|
||||
build:
|
||||
@echo "Building OpenAccounting Server..."
|
||||
go build -o server ./core/
|
||||
|
||||
# Run the server locally
|
||||
run: build
|
||||
@echo "Starting server locally..."
|
||||
./server
|
||||
|
||||
# Run with custom environment
|
||||
run-dev: build
|
||||
@echo "Starting server in development mode..."
|
||||
OA_DATABASEDRIVER=sqlite OA_DATABASEFILE=./dev.db OA_PORT=8080 ./server
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
go test ./...
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
rm -f server
|
||||
rm -f *.db
|
||||
|
||||
# Docker recipes
|
||||
|
||||
# Build Docker image
|
||||
docker-build:
|
||||
@echo "Building Docker image: {{image_name}}:{{tag}}"
|
||||
docker build -t {{image_name}}:{{tag}} .
|
||||
|
||||
# Run container with SQLite (development)
|
||||
docker-run: docker-build
|
||||
@echo "Running container with SQLite..."
|
||||
docker run --rm -p 8080:8080 \
|
||||
-e OA_DATABASE_DRIVER=sqlite \
|
||||
-e OA_DATABASE_FILE=/app/data/openaccounting.db \
|
||||
-v $(pwd)/data:/app/data \
|
||||
{{image_name}}:{{tag}}
|
||||
|
||||
# Run container with MySQL (production example)
|
||||
docker-run-mysql: docker-build
|
||||
@echo "Running container with MySQL (requires external MySQL)..."
|
||||
docker run --rm -p 8080:8080 \
|
||||
-e OA_DATABASE_DRIVER=mysql \
|
||||
-e OA_DATABASE_ADDRESS=mysql:3306 \
|
||||
-e OA_DATABASE=openaccounting \
|
||||
-e OA_USER=openaccounting \
|
||||
-e OA_PASSWORD=secret \
|
||||
{{image_name}}:{{tag}}
|
||||
|
||||
# Run with docker-compose (if you create one)
|
||||
docker-compose-up:
|
||||
@echo "Starting with docker-compose..."
|
||||
docker-compose up -d
|
||||
|
||||
docker-compose-down:
|
||||
@echo "Stopping docker-compose..."
|
||||
docker-compose down
|
||||
|
||||
# Development utilities
|
||||
|
||||
# Format code
|
||||
fmt:
|
||||
@echo "Formatting code..."
|
||||
go fmt ./...
|
||||
|
||||
# Lint code (requires golangci-lint)
|
||||
lint:
|
||||
@echo "Linting code..."
|
||||
golangci-lint run
|
||||
|
||||
# Install development dependencies
|
||||
install-deps:
|
||||
@echo "Installing development dependencies..."
|
||||
go mod download
|
||||
go mod vendor
|
||||
|
||||
# Update dependencies
|
||||
update-deps:
|
||||
@echo "Updating dependencies..."
|
||||
go get -u ./...
|
||||
go mod tidy
|
||||
go mod vendor
|
||||
|
||||
# Database utilities
|
||||
|
||||
# Create SQLite database directory
|
||||
init-db:
|
||||
@echo "Creating database directory..."
|
||||
mkdir -p data
|
||||
|
||||
# Reset SQLite database
|
||||
reset-db:
|
||||
@echo "Resetting SQLite database..."
|
||||
rm -f *.db data/*.db
|
||||
|
||||
# Migration recipes
|
||||
|
||||
# Run database migrations manually (if needed)
|
||||
migrate:
|
||||
@echo "Running database migrations..."
|
||||
go run ./core/ --migrate-only || echo "Migration command not implemented yet"
|
||||
|
||||
# Production utilities
|
||||
|
||||
# Build for production
|
||||
build-prod:
|
||||
@echo "Building for production..."
|
||||
CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -ldflags="-w -s" -o server ./core/
|
||||
|
||||
# Create release tarball
|
||||
release: build-prod
|
||||
@echo "Creating release package..."
|
||||
tar -czf openaccounting-server-$(date +%Y%m%d).tar.gz server config.json.sample README.md
|
||||
|
||||
# Security scan (requires trivy)
|
||||
security-scan:
|
||||
@echo "Scanning Docker image for vulnerabilities..."
|
||||
trivy image {{image_name}}:{{tag}}
|
||||
|
||||
# Show configuration help
|
||||
config-help:
|
||||
@echo "OpenAccounting Server Configuration:"
|
||||
@echo ""
|
||||
@echo "Environment Variables (prefix with OA_):"
|
||||
@echo " OA_ADDRESS Server address (default: localhost)"
|
||||
@echo " OA_PORT Server port (default: 8080)"
|
||||
@echo " OA_API_PREFIX API prefix (default: /api/v1)"
|
||||
@echo " OA_DATABASE_DRIVER Database driver: sqlite or mysql (default: sqlite)"
|
||||
@echo " OA_DATABASE_FILE SQLite database file (default: ./openaccounting.db)"
|
||||
@echo " OA_DATABASE_ADDRESS MySQL address (e.g., localhost:3306)"
|
||||
@echo " OA_DATABASE MySQL database name"
|
||||
@echo " OA_USER Database username"
|
||||
@echo " OA_PASSWORD Database password (recommended for security)"
|
||||
@echo " OA_MAILGUN_DOMAIN Mailgun domain"
|
||||
@echo " OA_MAILGUN_KEY Mailgun API key (recommended for security)"
|
||||
@echo " OA_MAILGUN_EMAIL Mailgun email"
|
||||
@echo " OA_MAILGUN_SENDER Mailgun sender name"
|
||||
@echo ""
|
||||
@echo "Examples:"
|
||||
@echo " Development: OA_DATABASE_DRIVER=sqlite OA_PORT=8080 ./server"
|
||||
@echo " Production: OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret ./server"
|
||||
|
||||
# All-in-one development setup
|
||||
dev-setup: install-deps init-db build
|
||||
@echo "Development setup complete!"
|
||||
@echo "Run 'just run-dev' to start the server"
|
||||
|
||||
# All-in-one production build
|
||||
prod-build: clean build-prod docker-build
|
||||
@echo "Production build complete!"
|
||||
@echo "Run 'just docker-run' to test the container"
|
||||
@@ -1,105 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 2 {
|
||||
log.Fatal("Usage: migrate1.go <upgrade/downgrade>")
|
||||
}
|
||||
|
||||
command := os.Args[1]
|
||||
|
||||
if command != "upgrade" && command != "downgrade" {
|
||||
log.Fatal("Usage: migrate1.go <upgrade/downgrade>")
|
||||
}
|
||||
|
||||
//filename is the path to the json config file
|
||||
var config types.Config
|
||||
file, err := os.Open("./config.json")
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
err = decoder.Decode(&config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
connectionString := config.User + ":" + config.Password + "@/" + config.Database
|
||||
db, err := db.NewDB(connectionString)
|
||||
|
||||
if command == "upgrade" {
|
||||
err = upgrade(db)
|
||||
} else {
|
||||
err = downgrade(db)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Println("done")
|
||||
}
|
||||
|
||||
func upgrade(db *db.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p) // re-throw panic after Rollback
|
||||
} else if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
query1 := "ALTER TABLE user ADD COLUMN signupSource VARCHAR(100) NOT NULL AFTER emailVerifyCode"
|
||||
|
||||
if _, err = tx.Exec(query1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func downgrade(db *db.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p) // re-throw panic after Rollback
|
||||
} else if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
query1 := "ALTER TABLE user DROP COLUMN signupSource"
|
||||
|
||||
if _, err = tx.Exec(query1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 2 {
|
||||
log.Fatal("Usage: migrate2.go <upgrade/downgrade>")
|
||||
}
|
||||
|
||||
command := os.Args[1]
|
||||
|
||||
if command != "upgrade" && command != "downgrade" {
|
||||
log.Fatal("Usage: migrate2.go <upgrade/downgrade>")
|
||||
}
|
||||
|
||||
//filename is the path to the json config file
|
||||
var config types.Config
|
||||
file, err := os.Open("./config.json")
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
err = decoder.Decode(&config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
connectionString := config.User + ":" + config.Password + "@/" + config.Database
|
||||
db, err := db.NewDB(connectionString)
|
||||
|
||||
if command == "upgrade" {
|
||||
err = upgrade(db)
|
||||
} else {
|
||||
err = downgrade(db)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Println("done")
|
||||
}
|
||||
|
||||
func upgrade(db *db.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p) // re-throw panic after Rollback
|
||||
} else if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
query1 := "ALTER TABLE org ADD COLUMN timezone VARCHAR(100) NOT NULL AFTER `precision`"
|
||||
|
||||
if _, err = tx.Exec(query1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func downgrade(db *db.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p) // re-throw panic after Rollback
|
||||
} else if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
query1 := "ALTER TABLE org DROP COLUMN timezone"
|
||||
|
||||
if _, err = tx.Exec(query1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/openaccounting/oa-server/core/model/db"
|
||||
"github.com/openaccounting/oa-server/core/model/types"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 2 {
|
||||
log.Fatal("Usage: migrate3.go <upgrade/downgrade>")
|
||||
}
|
||||
|
||||
command := os.Args[1]
|
||||
|
||||
if command != "upgrade" && command != "downgrade" {
|
||||
log.Fatal("Usage: migrate3.go <upgrade/downgrade>")
|
||||
}
|
||||
|
||||
//filename is the path to the json config file
|
||||
var config types.Config
|
||||
file, err := os.Open("./config.json")
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
err = decoder.Decode(&config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
connectionString := config.User + ":" + config.Password + "@/" + config.Database
|
||||
db, err := db.NewDB(connectionString)
|
||||
|
||||
if command == "upgrade" {
|
||||
err = upgrade(db)
|
||||
} else {
|
||||
err = downgrade(db)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Println("done")
|
||||
}
|
||||
|
||||
func upgrade(db *db.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p) // re-throw panic after Rollback
|
||||
} else if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
query1 := "CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;"
|
||||
|
||||
if _, err = tx.Exec(query1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func downgrade(db *db.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p) // re-throw panic after Rollback
|
||||
} else if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
query1 := "DROP TABLE budgetitem"
|
||||
|
||||
if _, err = tx.Exec(query1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
18
models/account.go
Normal file
18
models/account.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package models
|
||||
|
||||
// Account represents a financial account
|
||||
type Account struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
Name string `gorm:"column:name;size:100;not null"`
|
||||
Parent []byte `gorm:"column:parent;type:BINARY(16);not null"`
|
||||
Currency string `gorm:"column:currency;size:10;not null"`
|
||||
Precision int `gorm:"column:precision;not null"`
|
||||
DebitBalance bool `gorm:"column:debitBalance;not null"`
|
||||
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
Splits []Split `gorm:"foreignKey:AccountID"`
|
||||
Balances []Balance `gorm:"foreignKey:AccountID"`
|
||||
}
|
||||
13
models/api_key.go
Normal file
13
models/api_key.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package models
|
||||
|
||||
// APIKey represents API keys for users
|
||||
type APIKey struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
|
||||
Label string `gorm:"column:label;size:300;not null"`
|
||||
Deleted uint64 `gorm:"column:deleted"`
|
||||
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
}
|
||||
28
models/attachment.go
Normal file
28
models/attachment.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Attachment struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
TransactionID []byte `gorm:"column:transactionId;type:BINARY(16);not null"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
|
||||
FileName string `gorm:"column:fileName;size:255;not null"`
|
||||
OriginalName string `gorm:"column:originalName;size:255;not null"`
|
||||
ContentType string `gorm:"column:contentType;size:100;not null"`
|
||||
FileSize int64 `gorm:"column:fileSize;not null"`
|
||||
FilePath string `gorm:"column:filePath;size:500;not null"`
|
||||
Description string `gorm:"column:description;size:500"`
|
||||
Uploaded time.Time `gorm:"column:uploaded;not null"`
|
||||
Deleted bool `gorm:"column:deleted;default:false"`
|
||||
|
||||
Transaction Transaction `gorm:"foreignKey:TransactionID"`
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
func (Attachment) TableName() string {
|
||||
return "attachment"
|
||||
}
|
||||
11
models/balance.go
Normal file
11
models/balance.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package models
|
||||
|
||||
// Balance represents an account balance at a point in time
|
||||
type Balance struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
Date uint64 `gorm:"column:date;not null"`
|
||||
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
|
||||
Amount int64 `gorm:"column:amount;not null"`
|
||||
|
||||
Account Account `gorm:"foreignKey:AccountID"`
|
||||
}
|
||||
45
models/base.go
Normal file
45
models/base.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/openaccounting/oa-server/core/util/id"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Base struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
}
|
||||
|
||||
// GetUUID converts binary ID to UUID
|
||||
func (b *Base) GetUUID() (uuid.UUID, error) {
|
||||
return id.ToUUID(b.ID)
|
||||
}
|
||||
|
||||
// GetIDString returns string representation of the ID
|
||||
func (b *Base) GetIDString() string {
|
||||
return id.String(b.ID)
|
||||
}
|
||||
|
||||
// SetIDFromString parses string UUID into binary ID
|
||||
func (b *Base) SetIDFromString(s string) error {
|
||||
binID, err := id.FromString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.ID = binID
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateID checks if the ID is a valid UUID
|
||||
func (b *Base) ValidateID() error {
|
||||
_, err := uuid.FromBytes(b.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
// BeforeCreate GORM hook to set ID if empty
|
||||
func (b *Base) BeforeCreate(tx *gorm.DB) error {
|
||||
if len(b.ID) == 0 {
|
||||
b.ID = id.New()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
13
models/budget_item.go
Normal file
13
models/budget_item.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package models
|
||||
|
||||
// BudgetItem represents budget items
|
||||
type BudgetItem struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Amount int64 `gorm:"column:amount;not null"`
|
||||
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
Account Account `gorm:"foreignKey:AccountID"`
|
||||
}
|
||||
13
models/invite.go
Normal file
13
models/invite.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package models
|
||||
|
||||
// Invite represents organization invitations
|
||||
type Invite struct {
|
||||
ID string `gorm:"size:32;primaryKey"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
Email string `gorm:"column:email;size:100;not null"`
|
||||
Accepted bool `gorm:"column:accepted;not null"`
|
||||
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
}
|
||||
15
models/org.go
Normal file
15
models/org.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package models
|
||||
|
||||
// Org represents an organization
|
||||
type Org struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
Name string `gorm:"column:name;size:100;not null"`
|
||||
Currency string `gorm:"column:currency;size:10;not null"`
|
||||
Precision int `gorm:"column:precision;not null"`
|
||||
Timezone string `gorm:"column:timezone;size:100;not null"`
|
||||
|
||||
Accounts []Account `gorm:"foreignKey:OrgID"`
|
||||
UserOrgs []UserOrg `gorm:"foreignKey:OrgID"`
|
||||
}
|
||||
18
models/permission.go
Normal file
18
models/permission.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package models
|
||||
|
||||
// Permission represents access control rules
|
||||
type Permission struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
UserID []byte `gorm:"column:userId;type:BINARY(16)"`
|
||||
TokenID []byte `gorm:"column:tokenId;type:BINARY(16)"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
|
||||
Type uint `gorm:"column:type;not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
Token Token `gorm:"foreignKey:TokenID"`
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
Account Account `gorm:"foreignKey:AccountID"`
|
||||
}
|
||||
14
models/price.go
Normal file
14
models/price.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package models
|
||||
|
||||
// Price represents currency exchange rates
|
||||
type Price struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
Currency string `gorm:"column:currency;size:10;not null"`
|
||||
Date uint64 `gorm:"column:date;not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
Price float64 `gorm:"column:price;not null"`
|
||||
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
}
|
||||
12
models/session.go
Normal file
12
models/session.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package models
|
||||
|
||||
// Session represents user sessions
|
||||
type Session struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
|
||||
Terminated uint64 `gorm:"column:terminated"`
|
||||
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
}
|
||||
17
models/split.go
Normal file
17
models/split.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package models
|
||||
|
||||
// Split represents a single entry in a transaction
|
||||
type Split struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
TransactionID []byte `gorm:"column:transactionId;type:BINARY(16);not null"`
|
||||
AccountID []byte `gorm:"column:accountId;type:BINARY(16);not null"`
|
||||
Date uint64 `gorm:"column:date;not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
Amount int64 `gorm:"column:amount;not null"`
|
||||
NativeAmount int64 `gorm:"column:nativeAmount;not null"`
|
||||
Deleted bool `gorm:"column:deleted;default:false"`
|
||||
|
||||
Transaction Transaction `gorm:"foreignKey:TransactionID"`
|
||||
Account Account `gorm:"foreignKey:AccountID"`
|
||||
}
|
||||
10
models/token.go
Normal file
10
models/token.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package models
|
||||
|
||||
// Token represents an API token
|
||||
type Token struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
Name string `gorm:"column:name;size:100"`
|
||||
UserOrgID uint `gorm:"column:userOrgId;not null"`
|
||||
|
||||
UserOrg UserOrg `gorm:"foreignKey:UserOrgID"`
|
||||
}
|
||||
18
models/transaction.go
Normal file
18
models/transaction.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package models
|
||||
|
||||
// Transaction represents a financial transaction
|
||||
type Transaction struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
|
||||
Date uint64 `gorm:"column:date;not null"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
Description string `gorm:"column:description;size:300;not null"`
|
||||
Data string `gorm:"column:data;type:TEXT;not null"`
|
||||
Deleted bool `gorm:"column:deleted;default:false"`
|
||||
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
Splits []Split `gorm:"foreignKey:TransactionID"`
|
||||
}
|
||||
21
models/user.go
Normal file
21
models/user.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package models
|
||||
|
||||
// User represents a user account
|
||||
type User struct {
|
||||
ID []byte `gorm:"type:BINARY(16);primaryKey"`
|
||||
Inserted uint64 `gorm:"column:inserted;not null"`
|
||||
Updated uint64 `gorm:"column:updated;not null"`
|
||||
FirstName string `gorm:"column:firstName;size:50;not null"`
|
||||
LastName string `gorm:"column:lastName;size:50;not null"`
|
||||
Email string `gorm:"column:email;size:100;not null;unique"`
|
||||
PasswordHash string `gorm:"column:passwordHash;size:100;not null"`
|
||||
AgreeToTerms bool `gorm:"column:agreeToTerms;not null"`
|
||||
PasswordReset string `gorm:"column:passwordReset;size:32;not null"`
|
||||
EmailVerified bool `gorm:"column:emailVerified;not null"`
|
||||
EmailVerifyCode string `gorm:"column:emailVerifyCode;size:32;not null"`
|
||||
SignupSource string `gorm:"column:signupSource;size:100;not null"`
|
||||
|
||||
UserOrgs []UserOrg `gorm:"foreignKey:UserID"`
|
||||
Sessions []Session `gorm:"foreignKey:UserID"`
|
||||
APIKeys []APIKey `gorm:"foreignKey:UserID"`
|
||||
}
|
||||
12
models/user_org.go
Normal file
12
models/user_org.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package models
|
||||
|
||||
// UserOrg represents the relationship between users and organizations
|
||||
type UserOrg struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
|
||||
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
|
||||
Admin bool `gorm:"column:admin;default:false"`
|
||||
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
Org Org `gorm:"foreignKey:OrgID"`
|
||||
}
|
||||
@@ -30,4 +30,6 @@ CREATE TABLE apikey (id BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL,
|
||||
|
||||
CREATE TABLE invite (id VARCHAR(32) NOT NULL, orgId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, updated BIGINT UNSIGNED NOT NULL, email VARCHAR(100) NOT NULL, accepted BOOLEAN NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
|
||||
|
||||
CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
|
||||
CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
|
||||
|
||||
CREATE TABLE attachment (id BINARY(16) NOT NULL, transactionId BINARY(16) NOT NULL, orgId BINARY(16) NOT NULL, userId BINARY(16) NOT NULL, fileName VARCHAR(255) NOT NULL, originalName VARCHAR(255) NOT NULL, contentType VARCHAR(100) NOT NULL, fileSize BIGINT NOT NULL, filePath VARCHAR(500) NOT NULL, description VARCHAR(500), uploaded BIGINT UNSIGNED NOT NULL, deleted BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY(id)) ENGINE=InnoDB;
|
||||
27
vendor/filippo.io/edwards25519/LICENSE
generated
vendored
Normal file
27
vendor/filippo.io/edwards25519/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
14
vendor/filippo.io/edwards25519/README.md
generated
vendored
Normal file
14
vendor/filippo.io/edwards25519/README.md
generated
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
# filippo.io/edwards25519
|
||||
|
||||
```
|
||||
import "filippo.io/edwards25519"
|
||||
```
|
||||
|
||||
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
|
||||
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
|
||||
|
||||
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
|
||||
|
||||
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
|
||||
|
||||
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.
|
||||
20
vendor/filippo.io/edwards25519/doc.go
generated
vendored
Normal file
20
vendor/filippo.io/edwards25519/doc.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright (c) 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package edwards25519 implements group logic for the twisted Edwards curve
|
||||
//
|
||||
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
|
||||
//
|
||||
// This is better known as the Edwards curve equivalent to Curve25519, and is
|
||||
// the curve used by the Ed25519 signature scheme.
|
||||
//
|
||||
// Most users don't need this package, and should instead use crypto/ed25519 for
|
||||
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
|
||||
// github.com/gtank/ristretto255 for prime order group logic.
|
||||
//
|
||||
// However, developers who do need to interact with low-level edwards25519
|
||||
// operations can use this package, which is an extended version of
|
||||
// crypto/internal/edwards25519 from the standard library repackaged as
|
||||
// an importable module.
|
||||
package edwards25519
|
||||
427
vendor/filippo.io/edwards25519/edwards25519.go
generated
vendored
Normal file
427
vendor/filippo.io/edwards25519/edwards25519.go
generated
vendored
Normal file
@@ -0,0 +1,427 @@
|
||||
// Copyright (c) 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package edwards25519
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"filippo.io/edwards25519/field"
|
||||
)
|
||||
|
||||
// Point types.
|
||||
|
||||
type projP1xP1 struct {
|
||||
X, Y, Z, T field.Element
|
||||
}
|
||||
|
||||
type projP2 struct {
|
||||
X, Y, Z field.Element
|
||||
}
|
||||
|
||||
// Point represents a point on the edwards25519 curve.
|
||||
//
|
||||
// This type works similarly to math/big.Int, and all arguments and receivers
|
||||
// are allowed to alias.
|
||||
//
|
||||
// The zero value is NOT valid, and it may be used only as a receiver.
|
||||
type Point struct {
|
||||
// Make the type not comparable (i.e. used with == or as a map key), as
|
||||
// equivalent points can be represented by different Go values.
|
||||
_ incomparable
|
||||
|
||||
// The point is internally represented in extended coordinates (X, Y, Z, T)
|
||||
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
|
||||
x, y, z, t field.Element
|
||||
}
|
||||
|
||||
type incomparable [0]func()
|
||||
|
||||
func checkInitialized(points ...*Point) {
|
||||
for _, p := range points {
|
||||
if p.x == (field.Element{}) && p.y == (field.Element{}) {
|
||||
panic("edwards25519: use of uninitialized Point")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type projCached struct {
|
||||
YplusX, YminusX, Z, T2d field.Element
|
||||
}
|
||||
|
||||
type affineCached struct {
|
||||
YplusX, YminusX, T2d field.Element
|
||||
}
|
||||
|
||||
// Constructors.
|
||||
|
||||
func (v *projP2) Zero() *projP2 {
|
||||
v.X.Zero()
|
||||
v.Y.One()
|
||||
v.Z.One()
|
||||
return v
|
||||
}
|
||||
|
||||
// identity is the point at infinity.
|
||||
var identity, _ = new(Point).SetBytes([]byte{
|
||||
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
|
||||
|
||||
// NewIdentityPoint returns a new Point set to the identity.
|
||||
func NewIdentityPoint() *Point {
|
||||
return new(Point).Set(identity)
|
||||
}
|
||||
|
||||
// generator is the canonical curve basepoint. See TestGenerator for the
|
||||
// correspondence of this encoding with the values in RFC 8032.
|
||||
var generator, _ = new(Point).SetBytes([]byte{
|
||||
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
|
||||
|
||||
// NewGeneratorPoint returns a new Point set to the canonical generator.
|
||||
func NewGeneratorPoint() *Point {
|
||||
return new(Point).Set(generator)
|
||||
}
|
||||
|
||||
func (v *projCached) Zero() *projCached {
|
||||
v.YplusX.One()
|
||||
v.YminusX.One()
|
||||
v.Z.One()
|
||||
v.T2d.Zero()
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *affineCached) Zero() *affineCached {
|
||||
v.YplusX.One()
|
||||
v.YminusX.One()
|
||||
v.T2d.Zero()
|
||||
return v
|
||||
}
|
||||
|
||||
// Assignments.
|
||||
|
||||
// Set sets v = u, and returns v.
|
||||
func (v *Point) Set(u *Point) *Point {
|
||||
*v = *u
|
||||
return v
|
||||
}
|
||||
|
||||
// Encoding.
|
||||
|
||||
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
|
||||
// Section 5.1.2.
|
||||
func (v *Point) Bytes() []byte {
|
||||
// This function is outlined to make the allocations inline in the caller
|
||||
// rather than happen on the heap.
|
||||
var buf [32]byte
|
||||
return v.bytes(&buf)
|
||||
}
|
||||
|
||||
func (v *Point) bytes(buf *[32]byte) []byte {
|
||||
checkInitialized(v)
|
||||
|
||||
var zInv, x, y field.Element
|
||||
zInv.Invert(&v.z) // zInv = 1 / Z
|
||||
x.Multiply(&v.x, &zInv) // x = X / Z
|
||||
y.Multiply(&v.y, &zInv) // y = Y / Z
|
||||
|
||||
out := copyFieldElement(buf, &y)
|
||||
out[31] |= byte(x.IsNegative() << 7)
|
||||
return out
|
||||
}
|
||||
|
||||
var feOne = new(field.Element).One()
|
||||
|
||||
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
|
||||
// represent a valid point on the curve, SetBytes returns nil and an error and
|
||||
// the receiver is unchanged. Otherwise, SetBytes returns v.
|
||||
//
|
||||
// Note that SetBytes accepts all non-canonical encodings of valid points.
|
||||
// That is, it follows decoding rules that match most implementations in
|
||||
// the ecosystem rather than RFC 8032.
|
||||
func (v *Point) SetBytes(x []byte) (*Point, error) {
|
||||
// Specifically, the non-canonical encodings that are accepted are
|
||||
// 1) the ones where the field element is not reduced (see the
|
||||
// (*field.Element).SetBytes docs) and
|
||||
// 2) the ones where the x-coordinate is zero and the sign bit is set.
|
||||
//
|
||||
// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
|
||||
// specifically the "Canonical A, R" section.
|
||||
|
||||
y, err := new(field.Element).SetBytes(x)
|
||||
if err != nil {
|
||||
return nil, errors.New("edwards25519: invalid point encoding length")
|
||||
}
|
||||
|
||||
// -x² + y² = 1 + dx²y²
|
||||
// x² + dx²y² = x²(dy² + 1) = y² - 1
|
||||
// x² = (y² - 1) / (dy² + 1)
|
||||
|
||||
// u = y² - 1
|
||||
y2 := new(field.Element).Square(y)
|
||||
u := new(field.Element).Subtract(y2, feOne)
|
||||
|
||||
// v = dy² + 1
|
||||
vv := new(field.Element).Multiply(y2, d)
|
||||
vv = vv.Add(vv, feOne)
|
||||
|
||||
// x = +√(u/v)
|
||||
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
|
||||
if wasSquare == 0 {
|
||||
return nil, errors.New("edwards25519: invalid point encoding")
|
||||
}
|
||||
|
||||
// Select the negative square root if the sign bit is set.
|
||||
xxNeg := new(field.Element).Negate(xx)
|
||||
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
|
||||
|
||||
v.x.Set(xx)
|
||||
v.y.Set(y)
|
||||
v.z.One()
|
||||
v.t.Multiply(xx, y) // xy = T / Z
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
|
||||
copy(buf[:], v.Bytes())
|
||||
return buf[:]
|
||||
}
|
||||
|
||||
// Conversions.
|
||||
|
||||
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
|
||||
v.X.Multiply(&p.X, &p.T)
|
||||
v.Y.Multiply(&p.Y, &p.Z)
|
||||
v.Z.Multiply(&p.Z, &p.T)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *projP2) FromP3(p *Point) *projP2 {
|
||||
v.X.Set(&p.x)
|
||||
v.Y.Set(&p.y)
|
||||
v.Z.Set(&p.z)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
|
||||
v.x.Multiply(&p.X, &p.T)
|
||||
v.y.Multiply(&p.Y, &p.Z)
|
||||
v.z.Multiply(&p.Z, &p.T)
|
||||
v.t.Multiply(&p.X, &p.Y)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *Point) fromP2(p *projP2) *Point {
|
||||
v.x.Multiply(&p.X, &p.Z)
|
||||
v.y.Multiply(&p.Y, &p.Z)
|
||||
v.z.Square(&p.Z)
|
||||
v.t.Multiply(&p.X, &p.Y)
|
||||
return v
|
||||
}
|
||||
|
||||
// d is a constant in the curve equation.
|
||||
var d, _ = new(field.Element).SetBytes([]byte{
|
||||
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
|
||||
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
|
||||
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
|
||||
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
|
||||
var d2 = new(field.Element).Add(d, d)
|
||||
|
||||
func (v *projCached) FromP3(p *Point) *projCached {
|
||||
v.YplusX.Add(&p.y, &p.x)
|
||||
v.YminusX.Subtract(&p.y, &p.x)
|
||||
v.Z.Set(&p.z)
|
||||
v.T2d.Multiply(&p.t, d2)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *affineCached) FromP3(p *Point) *affineCached {
|
||||
v.YplusX.Add(&p.y, &p.x)
|
||||
v.YminusX.Subtract(&p.y, &p.x)
|
||||
v.T2d.Multiply(&p.t, d2)
|
||||
|
||||
var invZ field.Element
|
||||
invZ.Invert(&p.z)
|
||||
v.YplusX.Multiply(&v.YplusX, &invZ)
|
||||
v.YminusX.Multiply(&v.YminusX, &invZ)
|
||||
v.T2d.Multiply(&v.T2d, &invZ)
|
||||
return v
|
||||
}
|
||||
|
||||
// (Re)addition and subtraction.
|
||||
|
||||
// Add sets v = p + q, and returns v.
|
||||
func (v *Point) Add(p, q *Point) *Point {
|
||||
checkInitialized(p, q)
|
||||
qCached := new(projCached).FromP3(q)
|
||||
result := new(projP1xP1).Add(p, qCached)
|
||||
return v.fromP1xP1(result)
|
||||
}
|
||||
|
||||
// Subtract sets v = p - q, and returns v.
|
||||
func (v *Point) Subtract(p, q *Point) *Point {
|
||||
checkInitialized(p, q)
|
||||
qCached := new(projCached).FromP3(q)
|
||||
result := new(projP1xP1).Sub(p, qCached)
|
||||
return v.fromP1xP1(result)
|
||||
}
|
||||
|
||||
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
|
||||
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
|
||||
|
||||
YplusX.Add(&p.y, &p.x)
|
||||
YminusX.Subtract(&p.y, &p.x)
|
||||
|
||||
PP.Multiply(&YplusX, &q.YplusX)
|
||||
MM.Multiply(&YminusX, &q.YminusX)
|
||||
TT2d.Multiply(&p.t, &q.T2d)
|
||||
ZZ2.Multiply(&p.z, &q.Z)
|
||||
|
||||
ZZ2.Add(&ZZ2, &ZZ2)
|
||||
|
||||
v.X.Subtract(&PP, &MM)
|
||||
v.Y.Add(&PP, &MM)
|
||||
v.Z.Add(&ZZ2, &TT2d)
|
||||
v.T.Subtract(&ZZ2, &TT2d)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
|
||||
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
|
||||
|
||||
YplusX.Add(&p.y, &p.x)
|
||||
YminusX.Subtract(&p.y, &p.x)
|
||||
|
||||
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
|
||||
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
|
||||
TT2d.Multiply(&p.t, &q.T2d)
|
||||
ZZ2.Multiply(&p.z, &q.Z)
|
||||
|
||||
ZZ2.Add(&ZZ2, &ZZ2)
|
||||
|
||||
v.X.Subtract(&PP, &MM)
|
||||
v.Y.Add(&PP, &MM)
|
||||
v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
|
||||
v.T.Add(&ZZ2, &TT2d) // flipped sign
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
|
||||
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
|
||||
|
||||
YplusX.Add(&p.y, &p.x)
|
||||
YminusX.Subtract(&p.y, &p.x)
|
||||
|
||||
PP.Multiply(&YplusX, &q.YplusX)
|
||||
MM.Multiply(&YminusX, &q.YminusX)
|
||||
TT2d.Multiply(&p.t, &q.T2d)
|
||||
|
||||
Z2.Add(&p.z, &p.z)
|
||||
|
||||
v.X.Subtract(&PP, &MM)
|
||||
v.Y.Add(&PP, &MM)
|
||||
v.Z.Add(&Z2, &TT2d)
|
||||
v.T.Subtract(&Z2, &TT2d)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
|
||||
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
|
||||
|
||||
YplusX.Add(&p.y, &p.x)
|
||||
YminusX.Subtract(&p.y, &p.x)
|
||||
|
||||
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
|
||||
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
|
||||
TT2d.Multiply(&p.t, &q.T2d)
|
||||
|
||||
Z2.Add(&p.z, &p.z)
|
||||
|
||||
v.X.Subtract(&PP, &MM)
|
||||
v.Y.Add(&PP, &MM)
|
||||
v.Z.Subtract(&Z2, &TT2d) // flipped sign
|
||||
v.T.Add(&Z2, &TT2d) // flipped sign
|
||||
return v
|
||||
}
|
||||
|
||||
// Doubling.
|
||||
|
||||
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
|
||||
var XX, YY, ZZ2, XplusYsq field.Element
|
||||
|
||||
XX.Square(&p.X)
|
||||
YY.Square(&p.Y)
|
||||
ZZ2.Square(&p.Z)
|
||||
ZZ2.Add(&ZZ2, &ZZ2)
|
||||
XplusYsq.Add(&p.X, &p.Y)
|
||||
XplusYsq.Square(&XplusYsq)
|
||||
|
||||
v.Y.Add(&YY, &XX)
|
||||
v.Z.Subtract(&YY, &XX)
|
||||
|
||||
v.X.Subtract(&XplusYsq, &v.Y)
|
||||
v.T.Subtract(&ZZ2, &v.Z)
|
||||
return v
|
||||
}
|
||||
|
||||
// Negation.
|
||||
|
||||
// Negate sets v = -p, and returns v.
|
||||
func (v *Point) Negate(p *Point) *Point {
|
||||
checkInitialized(p)
|
||||
v.x.Negate(&p.x)
|
||||
v.y.Set(&p.y)
|
||||
v.z.Set(&p.z)
|
||||
v.t.Negate(&p.t)
|
||||
return v
|
||||
}
|
||||
|
||||
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
|
||||
func (v *Point) Equal(u *Point) int {
|
||||
checkInitialized(v, u)
|
||||
|
||||
var t1, t2, t3, t4 field.Element
|
||||
t1.Multiply(&v.x, &u.z)
|
||||
t2.Multiply(&u.x, &v.z)
|
||||
t3.Multiply(&v.y, &u.z)
|
||||
t4.Multiply(&u.y, &v.z)
|
||||
|
||||
return t1.Equal(&t2) & t3.Equal(&t4)
|
||||
}
|
||||
|
||||
// Constant-time operations
|
||||
|
||||
// Select sets v to a if cond == 1 and to b if cond == 0.
|
||||
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
|
||||
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
|
||||
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
|
||||
v.Z.Select(&a.Z, &b.Z, cond)
|
||||
v.T2d.Select(&a.T2d, &b.T2d, cond)
|
||||
return v
|
||||
}
|
||||
|
||||
// Select sets v to a if cond == 1 and to b if cond == 0.
|
||||
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
|
||||
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
|
||||
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
|
||||
v.T2d.Select(&a.T2d, &b.T2d, cond)
|
||||
return v
|
||||
}
|
||||
|
||||
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
|
||||
func (v *projCached) CondNeg(cond int) *projCached {
|
||||
v.YplusX.Swap(&v.YminusX, cond)
|
||||
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
|
||||
return v
|
||||
}
|
||||
|
||||
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
|
||||
func (v *affineCached) CondNeg(cond int) *affineCached {
|
||||
v.YplusX.Swap(&v.YminusX, cond)
|
||||
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
|
||||
return v
|
||||
}
|
||||
349
vendor/filippo.io/edwards25519/extra.go
generated
vendored
Normal file
349
vendor/filippo.io/edwards25519/extra.go
generated
vendored
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright (c) 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package edwards25519
|
||||
|
||||
// This file contains additional functionality that is not included in the
|
||||
// upstream crypto/internal/edwards25519 package.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"filippo.io/edwards25519/field"
|
||||
)
|
||||
|
||||
// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
|
||||
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
|
||||
func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
|
||||
// This function is outlined to make the allocations inline in the caller
|
||||
// rather than happen on the heap. Don't change the style without making
|
||||
// sure it doesn't increase the inliner cost.
|
||||
var e [4]field.Element
|
||||
X, Y, Z, T = v.extendedCoordinates(&e)
|
||||
return
|
||||
}
|
||||
|
||||
func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
|
||||
checkInitialized(v)
|
||||
X = e[0].Set(&v.x)
|
||||
Y = e[1].Set(&v.y)
|
||||
Z = e[2].Set(&v.z)
|
||||
T = e[3].Set(&v.t)
|
||||
return
|
||||
}
|
||||
|
||||
// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
|
||||
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
|
||||
//
|
||||
// If the coordinates are invalid or don't represent a valid point on the curve,
|
||||
// SetExtendedCoordinates returns nil and an error and the receiver is
|
||||
// unchanged. Otherwise, SetExtendedCoordinates returns v.
|
||||
func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
|
||||
if !isOnCurve(X, Y, Z, T) {
|
||||
return nil, errors.New("edwards25519: invalid point coordinates")
|
||||
}
|
||||
v.x.Set(X)
|
||||
v.y.Set(Y)
|
||||
v.z.Set(Z)
|
||||
v.t.Set(T)
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func isOnCurve(X, Y, Z, T *field.Element) bool {
|
||||
var lhs, rhs field.Element
|
||||
XX := new(field.Element).Square(X)
|
||||
YY := new(field.Element).Square(Y)
|
||||
ZZ := new(field.Element).Square(Z)
|
||||
TT := new(field.Element).Square(T)
|
||||
// -x² + y² = 1 + dx²y²
|
||||
// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
|
||||
// -X² + Y² = Z² + dT²
|
||||
lhs.Subtract(YY, XX)
|
||||
rhs.Multiply(d, TT).Add(&rhs, ZZ)
|
||||
if lhs.Equal(&rhs) != 1 {
|
||||
return false
|
||||
}
|
||||
// xy = T/Z
|
||||
// XY/Z² = T/Z
|
||||
// XY = TZ
|
||||
lhs.Multiply(X, Y)
|
||||
rhs.Multiply(T, Z)
|
||||
return lhs.Equal(&rhs) == 1
|
||||
}
|
||||
|
||||
// BytesMontgomery converts v to a point on the birationally-equivalent
|
||||
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
|
||||
// according to RFC 7748.
|
||||
//
|
||||
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
|
||||
// to the same value. If v is the identity point, BytesMontgomery returns 32
|
||||
// zero bytes, analogously to the X25519 function.
|
||||
//
|
||||
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
|
||||
// while every valid edwards25519 point has a unique u-coordinate Montgomery
|
||||
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
|
||||
// to any edwards25519 point, and every other X25519 input corresponds to two
|
||||
// edwards25519 points.
|
||||
func (v *Point) BytesMontgomery() []byte {
|
||||
// This function is outlined to make the allocations inline in the caller
|
||||
// rather than happen on the heap.
|
||||
var buf [32]byte
|
||||
return v.bytesMontgomery(&buf)
|
||||
}
|
||||
|
||||
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
|
||||
checkInitialized(v)
|
||||
|
||||
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
|
||||
// Montgomery u-coordinate
|
||||
//
|
||||
// u = (1 + y) / (1 - y)
|
||||
//
|
||||
// where y = Y / Z.
|
||||
|
||||
var y, recip, u field.Element
|
||||
|
||||
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
|
||||
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
|
||||
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
|
||||
|
||||
return copyFieldElement(buf, &u)
|
||||
}
|
||||
|
||||
// MultByCofactor sets v = 8 * p, and returns v.
|
||||
func (v *Point) MultByCofactor(p *Point) *Point {
|
||||
checkInitialized(p)
|
||||
result := projP1xP1{}
|
||||
pp := (&projP2{}).FromP3(p)
|
||||
result.Double(pp)
|
||||
pp.FromP1xP1(&result)
|
||||
result.Double(pp)
|
||||
pp.FromP1xP1(&result)
|
||||
result.Double(pp)
|
||||
return v.fromP1xP1(&result)
|
||||
}
|
||||
|
||||
// Given k > 0, set s = s**(2*i).
|
||||
func (s *Scalar) pow2k(k int) {
|
||||
for i := 0; i < k; i++ {
|
||||
s.Multiply(s, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
|
||||
//
|
||||
// If t is zero, Invert returns zero.
|
||||
func (s *Scalar) Invert(t *Scalar) *Scalar {
|
||||
// Uses a hardcoded sliding window of width 4.
|
||||
var table [8]Scalar
|
||||
var tt Scalar
|
||||
tt.Multiply(t, t)
|
||||
table[0] = *t
|
||||
for i := 0; i < 7; i++ {
|
||||
table[i+1].Multiply(&table[i], &tt)
|
||||
}
|
||||
// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
|
||||
// so t**k = t[k/2] for odd k
|
||||
|
||||
// To compute the sliding window digits, use the following Sage script:
|
||||
|
||||
// sage: import itertools
|
||||
// sage: def sliding_window(w,k):
|
||||
// ....: digits = []
|
||||
// ....: while k > 0:
|
||||
// ....: if k % 2 == 1:
|
||||
// ....: kmod = k % (2**w)
|
||||
// ....: digits.append(kmod)
|
||||
// ....: k = k - kmod
|
||||
// ....: else:
|
||||
// ....: digits.append(0)
|
||||
// ....: k = k // 2
|
||||
// ....: return digits
|
||||
|
||||
// Now we can compute s roughly as follows:
|
||||
|
||||
// sage: s = 1
|
||||
// sage: for coeff in reversed(sliding_window(4,l-2)):
|
||||
// ....: s = s*s
|
||||
// ....: if coeff > 0 :
|
||||
// ....: s = s*t**coeff
|
||||
|
||||
// This works on one bit at a time, with many runs of zeros.
|
||||
// The digits can be collapsed into [(count, coeff)] as follows:
|
||||
|
||||
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
|
||||
|
||||
// Entries of the form (k, 0) turn into pow2k(k)
|
||||
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
|
||||
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
|
||||
|
||||
*s = table[1/2]
|
||||
s.pow2k(127 + 1)
|
||||
s.Multiply(s, &table[1/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[9/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[11/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[13/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[15/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[7/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[15/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[5/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[1/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[15/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[15/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[7/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[3/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[11/2])
|
||||
s.pow2k(5 + 1)
|
||||
s.Multiply(s, &table[11/2])
|
||||
s.pow2k(9 + 1)
|
||||
s.Multiply(s, &table[9/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[3/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[3/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[3/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[9/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[7/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[3/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[13/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[7/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[9/2])
|
||||
s.pow2k(3 + 1)
|
||||
s.Multiply(s, &table[15/2])
|
||||
s.pow2k(4 + 1)
|
||||
s.Multiply(s, &table[11/2])
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
|
||||
//
|
||||
// Execution time depends only on the lengths of the two slices, which must match.
|
||||
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
|
||||
if len(scalars) != len(points) {
|
||||
panic("edwards25519: called MultiScalarMult with different size inputs")
|
||||
}
|
||||
checkInitialized(points...)
|
||||
|
||||
// Proceed as in the single-base case, but share doublings
|
||||
// between each point in the multiscalar equation.
|
||||
|
||||
// Build lookup tables for each point
|
||||
tables := make([]projLookupTable, len(points))
|
||||
for i := range tables {
|
||||
tables[i].FromP3(points[i])
|
||||
}
|
||||
// Compute signed radix-16 digits for each scalar
|
||||
digits := make([][64]int8, len(scalars))
|
||||
for i := range digits {
|
||||
digits[i] = scalars[i].signedRadix16()
|
||||
}
|
||||
|
||||
// Unwrap first loop iteration to save computing 16*identity
|
||||
multiple := &projCached{}
|
||||
tmp1 := &projP1xP1{}
|
||||
tmp2 := &projP2{}
|
||||
// Lookup-and-add the appropriate multiple of each input point
|
||||
for j := range tables {
|
||||
tables[j].SelectInto(multiple, digits[j][63])
|
||||
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
|
||||
v.fromP1xP1(tmp1) // update v
|
||||
}
|
||||
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
|
||||
for i := 62; i >= 0; i-- {
|
||||
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
|
||||
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
|
||||
// Lookup-and-add the appropriate multiple of each input point
|
||||
for j := range tables {
|
||||
tables[j].SelectInto(multiple, digits[j][i])
|
||||
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
|
||||
v.fromP1xP1(tmp1) // update v
|
||||
}
|
||||
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
|
||||
//
|
||||
// Execution time depends on the inputs.
|
||||
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
|
||||
if len(scalars) != len(points) {
|
||||
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
|
||||
}
|
||||
checkInitialized(points...)
|
||||
|
||||
// Generalize double-base NAF computation to arbitrary sizes.
|
||||
// Here all the points are dynamic, so we only use the smaller
|
||||
// tables.
|
||||
|
||||
// Build lookup tables for each point
|
||||
tables := make([]nafLookupTable5, len(points))
|
||||
for i := range tables {
|
||||
tables[i].FromP3(points[i])
|
||||
}
|
||||
// Compute a NAF for each scalar
|
||||
nafs := make([][256]int8, len(scalars))
|
||||
for i := range nafs {
|
||||
nafs[i] = scalars[i].nonAdjacentForm(5)
|
||||
}
|
||||
|
||||
multiple := &projCached{}
|
||||
tmp1 := &projP1xP1{}
|
||||
tmp2 := &projP2{}
|
||||
tmp2.Zero()
|
||||
|
||||
// Move from high to low bits, doubling the accumulator
|
||||
// at each iteration and checking whether there is a nonzero
|
||||
// coefficient to look up a multiple of.
|
||||
//
|
||||
// Skip trying to find the first nonzero coefficent, because
|
||||
// searching might be more work than a few extra doublings.
|
||||
for i := 255; i >= 0; i-- {
|
||||
tmp1.Double(tmp2)
|
||||
|
||||
for j := range nafs {
|
||||
if nafs[j][i] > 0 {
|
||||
v.fromP1xP1(tmp1)
|
||||
tables[j].SelectInto(multiple, nafs[j][i])
|
||||
tmp1.Add(v, multiple)
|
||||
} else if nafs[j][i] < 0 {
|
||||
v.fromP1xP1(tmp1)
|
||||
tables[j].SelectInto(multiple, -nafs[j][i])
|
||||
tmp1.Sub(v, multiple)
|
||||
}
|
||||
}
|
||||
|
||||
tmp2.FromP1xP1(tmp1)
|
||||
}
|
||||
|
||||
v.fromP2(tmp2)
|
||||
return v
|
||||
}
|
||||
420
vendor/filippo.io/edwards25519/field/fe.go
generated
vendored
Normal file
420
vendor/filippo.io/edwards25519/field/fe.go
generated
vendored
Normal file
@@ -0,0 +1,420 @@
|
||||
// Copyright (c) 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package field implements fast arithmetic modulo 2^255-19.
|
||||
package field
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// Element represents an element of the field GF(2^255-19). Note that this
|
||||
// is not a cryptographically secure group, and should only be used to interact
|
||||
// with edwards25519.Point coordinates.
|
||||
//
|
||||
// This type works similarly to math/big.Int, and all arguments and receivers
|
||||
// are allowed to alias.
|
||||
//
|
||||
// The zero value is a valid zero element.
|
||||
type Element struct {
|
||||
// An element t represents the integer
|
||||
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
|
||||
//
|
||||
// Between operations, all limbs are expected to be lower than 2^52.
|
||||
l0 uint64
|
||||
l1 uint64
|
||||
l2 uint64
|
||||
l3 uint64
|
||||
l4 uint64
|
||||
}
|
||||
|
||||
const maskLow51Bits uint64 = (1 << 51) - 1
|
||||
|
||||
var feZero = &Element{0, 0, 0, 0, 0}
|
||||
|
||||
// Zero sets v = 0, and returns v.
|
||||
func (v *Element) Zero() *Element {
|
||||
*v = *feZero
|
||||
return v
|
||||
}
|
||||
|
||||
var feOne = &Element{1, 0, 0, 0, 0}
|
||||
|
||||
// One sets v = 1, and returns v.
|
||||
func (v *Element) One() *Element {
|
||||
*v = *feOne
|
||||
return v
|
||||
}
|
||||
|
||||
// reduce reduces v modulo 2^255 - 19 and returns it.
|
||||
func (v *Element) reduce() *Element {
|
||||
v.carryPropagate()
|
||||
|
||||
// After the light reduction we now have a field element representation
|
||||
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
|
||||
|
||||
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
|
||||
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
|
||||
c := (v.l0 + 19) >> 51
|
||||
c = (v.l1 + c) >> 51
|
||||
c = (v.l2 + c) >> 51
|
||||
c = (v.l3 + c) >> 51
|
||||
c = (v.l4 + c) >> 51
|
||||
|
||||
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
|
||||
// effectively applying the reduction identity to the carry.
|
||||
v.l0 += 19 * c
|
||||
|
||||
v.l1 += v.l0 >> 51
|
||||
v.l0 = v.l0 & maskLow51Bits
|
||||
v.l2 += v.l1 >> 51
|
||||
v.l1 = v.l1 & maskLow51Bits
|
||||
v.l3 += v.l2 >> 51
|
||||
v.l2 = v.l2 & maskLow51Bits
|
||||
v.l4 += v.l3 >> 51
|
||||
v.l3 = v.l3 & maskLow51Bits
|
||||
// no additional carry
|
||||
v.l4 = v.l4 & maskLow51Bits
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// Add sets v = a + b, and returns v.
|
||||
func (v *Element) Add(a, b *Element) *Element {
|
||||
v.l0 = a.l0 + b.l0
|
||||
v.l1 = a.l1 + b.l1
|
||||
v.l2 = a.l2 + b.l2
|
||||
v.l3 = a.l3 + b.l3
|
||||
v.l4 = a.l4 + b.l4
|
||||
// Using the generic implementation here is actually faster than the
|
||||
// assembly. Probably because the body of this function is so simple that
|
||||
// the compiler can figure out better optimizations by inlining the carry
|
||||
// propagation.
|
||||
return v.carryPropagateGeneric()
|
||||
}
|
||||
|
||||
// Subtract sets v = a - b, and returns v.
|
||||
func (v *Element) Subtract(a, b *Element) *Element {
|
||||
// We first add 2 * p, to guarantee the subtraction won't underflow, and
|
||||
// then subtract b (which can be up to 2^255 + 2^13 * 19).
|
||||
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
|
||||
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
|
||||
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
|
||||
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
|
||||
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
|
||||
return v.carryPropagate()
|
||||
}
|
||||
|
||||
// Negate sets v = -a, and returns v.
|
||||
func (v *Element) Negate(a *Element) *Element {
|
||||
return v.Subtract(feZero, a)
|
||||
}
|
||||
|
||||
// Invert sets v = 1/z mod p, and returns v.
|
||||
//
|
||||
// If z == 0, Invert returns v = 0.
|
||||
func (v *Element) Invert(z *Element) *Element {
|
||||
// Inversion is implemented as exponentiation with exponent p − 2. It uses the
|
||||
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
|
||||
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
|
||||
|
||||
z2.Square(z) // 2
|
||||
t.Square(&z2) // 4
|
||||
t.Square(&t) // 8
|
||||
z9.Multiply(&t, z) // 9
|
||||
z11.Multiply(&z9, &z2) // 11
|
||||
t.Square(&z11) // 22
|
||||
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
|
||||
|
||||
t.Square(&z2_5_0) // 2^6 - 2^1
|
||||
for i := 0; i < 4; i++ {
|
||||
t.Square(&t) // 2^10 - 2^5
|
||||
}
|
||||
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
|
||||
|
||||
t.Square(&z2_10_0) // 2^11 - 2^1
|
||||
for i := 0; i < 9; i++ {
|
||||
t.Square(&t) // 2^20 - 2^10
|
||||
}
|
||||
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
|
||||
|
||||
t.Square(&z2_20_0) // 2^21 - 2^1
|
||||
for i := 0; i < 19; i++ {
|
||||
t.Square(&t) // 2^40 - 2^20
|
||||
}
|
||||
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
|
||||
|
||||
t.Square(&t) // 2^41 - 2^1
|
||||
for i := 0; i < 9; i++ {
|
||||
t.Square(&t) // 2^50 - 2^10
|
||||
}
|
||||
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
|
||||
|
||||
t.Square(&z2_50_0) // 2^51 - 2^1
|
||||
for i := 0; i < 49; i++ {
|
||||
t.Square(&t) // 2^100 - 2^50
|
||||
}
|
||||
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
|
||||
|
||||
t.Square(&z2_100_0) // 2^101 - 2^1
|
||||
for i := 0; i < 99; i++ {
|
||||
t.Square(&t) // 2^200 - 2^100
|
||||
}
|
||||
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
|
||||
|
||||
t.Square(&t) // 2^201 - 2^1
|
||||
for i := 0; i < 49; i++ {
|
||||
t.Square(&t) // 2^250 - 2^50
|
||||
}
|
||||
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
|
||||
|
||||
t.Square(&t) // 2^251 - 2^1
|
||||
t.Square(&t) // 2^252 - 2^2
|
||||
t.Square(&t) // 2^253 - 2^3
|
||||
t.Square(&t) // 2^254 - 2^4
|
||||
t.Square(&t) // 2^255 - 2^5
|
||||
|
||||
return v.Multiply(&t, &z11) // 2^255 - 21
|
||||
}
|
||||
|
||||
// Set sets v = a, and returns v.
|
||||
func (v *Element) Set(a *Element) *Element {
|
||||
*v = *a
|
||||
return v
|
||||
}
|
||||
|
||||
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
|
||||
// not of the right length, SetBytes returns nil and an error, and the
|
||||
// receiver is unchanged.
|
||||
//
|
||||
// Consistent with RFC 7748, the most significant bit (the high bit of the
|
||||
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
|
||||
// are accepted. Note that this is laxer than specified by RFC 8032, but
|
||||
// consistent with most Ed25519 implementations.
|
||||
func (v *Element) SetBytes(x []byte) (*Element, error) {
|
||||
if len(x) != 32 {
|
||||
return nil, errors.New("edwards25519: invalid field element input size")
|
||||
}
|
||||
|
||||
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
|
||||
v.l0 = binary.LittleEndian.Uint64(x[0:8])
|
||||
v.l0 &= maskLow51Bits
|
||||
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
|
||||
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
|
||||
v.l1 &= maskLow51Bits
|
||||
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
|
||||
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
|
||||
v.l2 &= maskLow51Bits
|
||||
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
|
||||
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
|
||||
v.l3 &= maskLow51Bits
|
||||
// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
|
||||
// Note: not bytes 25:33, shift 4, to avoid overread.
|
||||
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
|
||||
v.l4 &= maskLow51Bits
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Bytes returns the canonical 32-byte little-endian encoding of v.
|
||||
func (v *Element) Bytes() []byte {
|
||||
// This function is outlined to make the allocations inline in the caller
|
||||
// rather than happen on the heap.
|
||||
var out [32]byte
|
||||
return v.bytes(&out)
|
||||
}
|
||||
|
||||
func (v *Element) bytes(out *[32]byte) []byte {
|
||||
t := *v
|
||||
t.reduce()
|
||||
|
||||
var buf [8]byte
|
||||
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
|
||||
bitsOffset := i * 51
|
||||
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
|
||||
for i, bb := range buf {
|
||||
off := bitsOffset/8 + i
|
||||
if off >= len(out) {
|
||||
break
|
||||
}
|
||||
out[off] |= bb
|
||||
}
|
||||
}
|
||||
|
||||
return out[:]
|
||||
}
|
||||
|
||||
// Equal returns 1 if v and u are equal, and 0 otherwise.
|
||||
func (v *Element) Equal(u *Element) int {
|
||||
sa, sv := u.Bytes(), v.Bytes()
|
||||
return subtle.ConstantTimeCompare(sa, sv)
|
||||
}
|
||||
|
||||
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
|
||||
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
|
||||
|
||||
// Select sets v to a if cond == 1, and to b if cond == 0.
|
||||
func (v *Element) Select(a, b *Element, cond int) *Element {
|
||||
m := mask64Bits(cond)
|
||||
v.l0 = (m & a.l0) | (^m & b.l0)
|
||||
v.l1 = (m & a.l1) | (^m & b.l1)
|
||||
v.l2 = (m & a.l2) | (^m & b.l2)
|
||||
v.l3 = (m & a.l3) | (^m & b.l3)
|
||||
v.l4 = (m & a.l4) | (^m & b.l4)
|
||||
return v
|
||||
}
|
||||
|
||||
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
|
||||
func (v *Element) Swap(u *Element, cond int) {
|
||||
m := mask64Bits(cond)
|
||||
t := m & (v.l0 ^ u.l0)
|
||||
v.l0 ^= t
|
||||
u.l0 ^= t
|
||||
t = m & (v.l1 ^ u.l1)
|
||||
v.l1 ^= t
|
||||
u.l1 ^= t
|
||||
t = m & (v.l2 ^ u.l2)
|
||||
v.l2 ^= t
|
||||
u.l2 ^= t
|
||||
t = m & (v.l3 ^ u.l3)
|
||||
v.l3 ^= t
|
||||
u.l3 ^= t
|
||||
t = m & (v.l4 ^ u.l4)
|
||||
v.l4 ^= t
|
||||
u.l4 ^= t
|
||||
}
|
||||
|
||||
// IsNegative returns 1 if v is negative, and 0 otherwise.
|
||||
func (v *Element) IsNegative() int {
|
||||
return int(v.Bytes()[0] & 1)
|
||||
}
|
||||
|
||||
// Absolute sets v to |u|, and returns v.
|
||||
func (v *Element) Absolute(u *Element) *Element {
|
||||
return v.Select(new(Element).Negate(u), u, u.IsNegative())
|
||||
}
|
||||
|
||||
// Multiply sets v = x * y, and returns v.
|
||||
func (v *Element) Multiply(x, y *Element) *Element {
|
||||
feMul(v, x, y)
|
||||
return v
|
||||
}
|
||||
|
||||
// Square sets v = x * x, and returns v.
|
||||
func (v *Element) Square(x *Element) *Element {
|
||||
feSquare(v, x)
|
||||
return v
|
||||
}
|
||||
|
||||
// Mult32 sets v = x * y, and returns v.
|
||||
func (v *Element) Mult32(x *Element, y uint32) *Element {
|
||||
x0lo, x0hi := mul51(x.l0, y)
|
||||
x1lo, x1hi := mul51(x.l1, y)
|
||||
x2lo, x2hi := mul51(x.l2, y)
|
||||
x3lo, x3hi := mul51(x.l3, y)
|
||||
x4lo, x4hi := mul51(x.l4, y)
|
||||
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
|
||||
v.l1 = x1lo + x0hi
|
||||
v.l2 = x2lo + x1hi
|
||||
v.l3 = x3lo + x2hi
|
||||
v.l4 = x4lo + x3hi
|
||||
// The hi portions are going to be only 32 bits, plus any previous excess,
|
||||
// so we can skip the carry propagation.
|
||||
return v
|
||||
}
|
||||
|
||||
// mul51 returns lo + hi * 2⁵¹ = a * b.
|
||||
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
|
||||
mh, ml := bits.Mul64(a, uint64(b))
|
||||
lo = ml & maskLow51Bits
|
||||
hi = (mh << 13) | (ml >> 51)
|
||||
return
|
||||
}
|
||||
|
||||
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
|
||||
func (v *Element) Pow22523(x *Element) *Element {
|
||||
var t0, t1, t2 Element
|
||||
|
||||
t0.Square(x) // x^2
|
||||
t1.Square(&t0) // x^4
|
||||
t1.Square(&t1) // x^8
|
||||
t1.Multiply(x, &t1) // x^9
|
||||
t0.Multiply(&t0, &t1) // x^11
|
||||
t0.Square(&t0) // x^22
|
||||
t0.Multiply(&t1, &t0) // x^31
|
||||
t1.Square(&t0) // x^62
|
||||
for i := 1; i < 5; i++ { // x^992
|
||||
t1.Square(&t1)
|
||||
}
|
||||
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
|
||||
t1.Square(&t0) // 2^11 - 2
|
||||
for i := 1; i < 10; i++ { // 2^20 - 2^10
|
||||
t1.Square(&t1)
|
||||
}
|
||||
t1.Multiply(&t1, &t0) // 2^20 - 1
|
||||
t2.Square(&t1) // 2^21 - 2
|
||||
for i := 1; i < 20; i++ { // 2^40 - 2^20
|
||||
t2.Square(&t2)
|
||||
}
|
||||
t1.Multiply(&t2, &t1) // 2^40 - 1
|
||||
t1.Square(&t1) // 2^41 - 2
|
||||
for i := 1; i < 10; i++ { // 2^50 - 2^10
|
||||
t1.Square(&t1)
|
||||
}
|
||||
t0.Multiply(&t1, &t0) // 2^50 - 1
|
||||
t1.Square(&t0) // 2^51 - 2
|
||||
for i := 1; i < 50; i++ { // 2^100 - 2^50
|
||||
t1.Square(&t1)
|
||||
}
|
||||
t1.Multiply(&t1, &t0) // 2^100 - 1
|
||||
t2.Square(&t1) // 2^101 - 2
|
||||
for i := 1; i < 100; i++ { // 2^200 - 2^100
|
||||
t2.Square(&t2)
|
||||
}
|
||||
t1.Multiply(&t2, &t1) // 2^200 - 1
|
||||
t1.Square(&t1) // 2^201 - 2
|
||||
for i := 1; i < 50; i++ { // 2^250 - 2^50
|
||||
t1.Square(&t1)
|
||||
}
|
||||
t0.Multiply(&t1, &t0) // 2^250 - 1
|
||||
t0.Square(&t0) // 2^251 - 2
|
||||
t0.Square(&t0) // 2^252 - 4
|
||||
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
|
||||
}
|
||||
|
||||
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
|
||||
var sqrtM1 = &Element{1718705420411056, 234908883556509,
|
||||
2233514472574048, 2117202627021982, 765476049583133}
|
||||
|
||||
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
|
||||
//
|
||||
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
|
||||
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
|
||||
// and returns r and 0.
|
||||
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
|
||||
t0 := new(Element)
|
||||
|
||||
// r = (u * v3) * (u * v7)^((p-5)/8)
|
||||
v2 := new(Element).Square(v)
|
||||
uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
|
||||
uv7 := new(Element).Multiply(uv3, t0.Square(v2))
|
||||
rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))
|
||||
|
||||
check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2
|
||||
|
||||
uNeg := new(Element).Negate(u)
|
||||
correctSignSqrt := check.Equal(u)
|
||||
flippedSignSqrt := check.Equal(uNeg)
|
||||
flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))
|
||||
|
||||
rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
|
||||
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
|
||||
rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)
|
||||
|
||||
r.Absolute(rr) // Choose the nonnegative square root.
|
||||
return r, correctSignSqrt | flippedSignSqrt
|
||||
}
|
||||
16
vendor/filippo.io/edwards25519/field/fe_amd64.go
generated
vendored
Normal file
16
vendor/filippo.io/edwards25519/field/fe_amd64.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
|
||||
|
||||
//go:build amd64 && gc && !purego
|
||||
// +build amd64,gc,!purego
|
||||
|
||||
package field
|
||||
|
||||
// feMul sets out = a * b. It works like feMulGeneric.
|
||||
//
|
||||
//go:noescape
|
||||
func feMul(out *Element, a *Element, b *Element)
|
||||
|
||||
// feSquare sets out = a * a. It works like feSquareGeneric.
|
||||
//
|
||||
//go:noescape
|
||||
func feSquare(out *Element, a *Element)
|
||||
379
vendor/filippo.io/edwards25519/field/fe_amd64.s
generated
vendored
Normal file
379
vendor/filippo.io/edwards25519/field/fe_amd64.s
generated
vendored
Normal file
@@ -0,0 +1,379 @@
|
||||
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
|
||||
|
||||
//go:build amd64 && gc && !purego
|
||||
// +build amd64,gc,!purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// func feMul(out *Element, a *Element, b *Element)
|
||||
TEXT ·feMul(SB), NOSPLIT, $0-24
|
||||
MOVQ a+8(FP), CX
|
||||
MOVQ b+16(FP), BX
|
||||
|
||||
// r0 = a0×b0
|
||||
MOVQ (CX), AX
|
||||
MULQ (BX)
|
||||
MOVQ AX, DI
|
||||
MOVQ DX, SI
|
||||
|
||||
// r0 += 19×a1×b4
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r0 += 19×a2×b3
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r0 += 19×a3×b2
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r0 += 19×a4×b1
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 8(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r1 = a0×b1
|
||||
MOVQ (CX), AX
|
||||
MULQ 8(BX)
|
||||
MOVQ AX, R9
|
||||
MOVQ DX, R8
|
||||
|
||||
// r1 += a1×b0
|
||||
MOVQ 8(CX), AX
|
||||
MULQ (BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r1 += 19×a2×b4
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r1 += 19×a3×b3
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r1 += 19×a4×b2
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r2 = a0×b2
|
||||
MOVQ (CX), AX
|
||||
MULQ 16(BX)
|
||||
MOVQ AX, R11
|
||||
MOVQ DX, R10
|
||||
|
||||
// r2 += a1×b1
|
||||
MOVQ 8(CX), AX
|
||||
MULQ 8(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
|
||||
// r2 += a2×b0
|
||||
MOVQ 16(CX), AX
|
||||
MULQ (BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
|
||||
// r2 += 19×a3×b4
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
|
||||
// r2 += 19×a4×b3
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
|
||||
// r3 = a0×b3
|
||||
MOVQ (CX), AX
|
||||
MULQ 24(BX)
|
||||
MOVQ AX, R13
|
||||
MOVQ DX, R12
|
||||
|
||||
// r3 += a1×b2
|
||||
MOVQ 8(CX), AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, R13
|
||||
ADCQ DX, R12
|
||||
|
||||
// r3 += a2×b1
|
||||
MOVQ 16(CX), AX
|
||||
MULQ 8(BX)
|
||||
ADDQ AX, R13
|
||||
ADCQ DX, R12
|
||||
|
||||
// r3 += a3×b0
|
||||
MOVQ 24(CX), AX
|
||||
MULQ (BX)
|
||||
ADDQ AX, R13
|
||||
ADCQ DX, R12
|
||||
|
||||
// r3 += 19×a4×b4
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R13
|
||||
ADCQ DX, R12
|
||||
|
||||
// r4 = a0×b4
|
||||
MOVQ (CX), AX
|
||||
MULQ 32(BX)
|
||||
MOVQ AX, R15
|
||||
MOVQ DX, R14
|
||||
|
||||
// r4 += a1×b3
|
||||
MOVQ 8(CX), AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R15
|
||||
ADCQ DX, R14
|
||||
|
||||
// r4 += a2×b2
|
||||
MOVQ 16(CX), AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, R15
|
||||
ADCQ DX, R14
|
||||
|
||||
// r4 += a3×b1
|
||||
MOVQ 24(CX), AX
|
||||
MULQ 8(BX)
|
||||
ADDQ AX, R15
|
||||
ADCQ DX, R14
|
||||
|
||||
// r4 += a4×b0
|
||||
MOVQ 32(CX), AX
|
||||
MULQ (BX)
|
||||
ADDQ AX, R15
|
||||
ADCQ DX, R14
|
||||
|
||||
// First reduction chain
|
||||
MOVQ $0x0007ffffffffffff, AX
|
||||
SHLQ $0x0d, DI, SI
|
||||
SHLQ $0x0d, R9, R8
|
||||
SHLQ $0x0d, R11, R10
|
||||
SHLQ $0x0d, R13, R12
|
||||
SHLQ $0x0d, R15, R14
|
||||
ANDQ AX, DI
|
||||
IMUL3Q $0x13, R14, R14
|
||||
ADDQ R14, DI
|
||||
ANDQ AX, R9
|
||||
ADDQ SI, R9
|
||||
ANDQ AX, R11
|
||||
ADDQ R8, R11
|
||||
ANDQ AX, R13
|
||||
ADDQ R10, R13
|
||||
ANDQ AX, R15
|
||||
ADDQ R12, R15
|
||||
|
||||
// Second reduction chain (carryPropagate)
|
||||
MOVQ DI, SI
|
||||
SHRQ $0x33, SI
|
||||
MOVQ R9, R8
|
||||
SHRQ $0x33, R8
|
||||
MOVQ R11, R10
|
||||
SHRQ $0x33, R10
|
||||
MOVQ R13, R12
|
||||
SHRQ $0x33, R12
|
||||
MOVQ R15, R14
|
||||
SHRQ $0x33, R14
|
||||
ANDQ AX, DI
|
||||
IMUL3Q $0x13, R14, R14
|
||||
ADDQ R14, DI
|
||||
ANDQ AX, R9
|
||||
ADDQ SI, R9
|
||||
ANDQ AX, R11
|
||||
ADDQ R8, R11
|
||||
ANDQ AX, R13
|
||||
ADDQ R10, R13
|
||||
ANDQ AX, R15
|
||||
ADDQ R12, R15
|
||||
|
||||
// Store output
|
||||
MOVQ out+0(FP), AX
|
||||
MOVQ DI, (AX)
|
||||
MOVQ R9, 8(AX)
|
||||
MOVQ R11, 16(AX)
|
||||
MOVQ R13, 24(AX)
|
||||
MOVQ R15, 32(AX)
|
||||
RET
|
||||
|
||||
// func feSquare(out *Element, a *Element)
|
||||
TEXT ·feSquare(SB), NOSPLIT, $0-16
|
||||
MOVQ a+8(FP), CX
|
||||
|
||||
// r0 = l0×l0
|
||||
MOVQ (CX), AX
|
||||
MULQ (CX)
|
||||
MOVQ AX, SI
|
||||
MOVQ DX, BX
|
||||
|
||||
// r0 += 38×l1×l4
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, SI
|
||||
ADCQ DX, BX
|
||||
|
||||
// r0 += 38×l2×l3
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, SI
|
||||
ADCQ DX, BX
|
||||
|
||||
// r1 = 2×l0×l1
|
||||
MOVQ (CX), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 8(CX)
|
||||
MOVQ AX, R8
|
||||
MOVQ DX, DI
|
||||
|
||||
// r1 += 38×l2×l4
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R8
|
||||
ADCQ DX, DI
|
||||
|
||||
// r1 += 19×l3×l3
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, R8
|
||||
ADCQ DX, DI
|
||||
|
||||
// r2 = 2×l0×l2
|
||||
MOVQ (CX), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 16(CX)
|
||||
MOVQ AX, R10
|
||||
MOVQ DX, R9
|
||||
|
||||
// r2 += l1×l1
|
||||
MOVQ 8(CX), AX
|
||||
MULQ 8(CX)
|
||||
ADDQ AX, R10
|
||||
ADCQ DX, R9
|
||||
|
||||
// r2 += 38×l3×l4
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R10
|
||||
ADCQ DX, R9
|
||||
|
||||
// r3 = 2×l0×l3
|
||||
MOVQ (CX), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 24(CX)
|
||||
MOVQ AX, R12
|
||||
MOVQ DX, R11
|
||||
|
||||
// r3 += 2×l1×l2
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x02, AX, AX
|
||||
MULQ 16(CX)
|
||||
ADDQ AX, R12
|
||||
ADCQ DX, R11
|
||||
|
||||
// r3 += 19×l4×l4
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R12
|
||||
ADCQ DX, R11
|
||||
|
||||
// r4 = 2×l0×l4
|
||||
MOVQ (CX), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 32(CX)
|
||||
MOVQ AX, R14
|
||||
MOVQ DX, R13
|
||||
|
||||
// r4 += 2×l1×l3
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x02, AX, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, R14
|
||||
ADCQ DX, R13
|
||||
|
||||
// r4 += l2×l2
|
||||
MOVQ 16(CX), AX
|
||||
MULQ 16(CX)
|
||||
ADDQ AX, R14
|
||||
ADCQ DX, R13
|
||||
|
||||
// First reduction chain
|
||||
MOVQ $0x0007ffffffffffff, AX
|
||||
SHLQ $0x0d, SI, BX
|
||||
SHLQ $0x0d, R8, DI
|
||||
SHLQ $0x0d, R10, R9
|
||||
SHLQ $0x0d, R12, R11
|
||||
SHLQ $0x0d, R14, R13
|
||||
ANDQ AX, SI
|
||||
IMUL3Q $0x13, R13, R13
|
||||
ADDQ R13, SI
|
||||
ANDQ AX, R8
|
||||
ADDQ BX, R8
|
||||
ANDQ AX, R10
|
||||
ADDQ DI, R10
|
||||
ANDQ AX, R12
|
||||
ADDQ R9, R12
|
||||
ANDQ AX, R14
|
||||
ADDQ R11, R14
|
||||
|
||||
// Second reduction chain (carryPropagate)
|
||||
MOVQ SI, BX
|
||||
SHRQ $0x33, BX
|
||||
MOVQ R8, DI
|
||||
SHRQ $0x33, DI
|
||||
MOVQ R10, R9
|
||||
SHRQ $0x33, R9
|
||||
MOVQ R12, R11
|
||||
SHRQ $0x33, R11
|
||||
MOVQ R14, R13
|
||||
SHRQ $0x33, R13
|
||||
ANDQ AX, SI
|
||||
IMUL3Q $0x13, R13, R13
|
||||
ADDQ R13, SI
|
||||
ANDQ AX, R8
|
||||
ADDQ BX, R8
|
||||
ANDQ AX, R10
|
||||
ADDQ DI, R10
|
||||
ANDQ AX, R12
|
||||
ADDQ R9, R12
|
||||
ANDQ AX, R14
|
||||
ADDQ R11, R14
|
||||
|
||||
// Store output
|
||||
MOVQ out+0(FP), AX
|
||||
MOVQ SI, (AX)
|
||||
MOVQ R8, 8(AX)
|
||||
MOVQ R10, 16(AX)
|
||||
MOVQ R12, 24(AX)
|
||||
MOVQ R14, 32(AX)
|
||||
RET
|
||||
12
vendor/filippo.io/edwards25519/field/fe_amd64_noasm.go
generated
vendored
Normal file
12
vendor/filippo.io/edwards25519/field/fe_amd64_noasm.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright (c) 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !amd64 || !gc || purego
|
||||
// +build !amd64 !gc purego
|
||||
|
||||
package field
|
||||
|
||||
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
|
||||
|
||||
func feSquare(v, x *Element) { feSquareGeneric(v, x) }
|
||||
16
vendor/filippo.io/edwards25519/field/fe_arm64.go
generated
vendored
Normal file
16
vendor/filippo.io/edwards25519/field/fe_arm64.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2020 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build arm64 && gc && !purego
|
||||
// +build arm64,gc,!purego
|
||||
|
||||
package field
|
||||
|
||||
//go:noescape
|
||||
func carryPropagate(v *Element)
|
||||
|
||||
func (v *Element) carryPropagate() *Element {
|
||||
carryPropagate(v)
|
||||
return v
|
||||
}
|
||||
42
vendor/filippo.io/edwards25519/field/fe_arm64.s
generated
vendored
Normal file
42
vendor/filippo.io/edwards25519/field/fe_arm64.s
generated
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2020 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build arm64 && gc && !purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// carryPropagate works exactly like carryPropagateGeneric and uses the
|
||||
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
|
||||
// avoids loading R0-R4 twice and uses LDP and STP.
|
||||
//
|
||||
// See https://golang.org/issues/43145 for the main compiler issue.
|
||||
//
|
||||
// func carryPropagate(v *Element)
|
||||
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
|
||||
MOVD v+0(FP), R20
|
||||
|
||||
LDP 0(R20), (R0, R1)
|
||||
LDP 16(R20), (R2, R3)
|
||||
MOVD 32(R20), R4
|
||||
|
||||
AND $0x7ffffffffffff, R0, R10
|
||||
AND $0x7ffffffffffff, R1, R11
|
||||
AND $0x7ffffffffffff, R2, R12
|
||||
AND $0x7ffffffffffff, R3, R13
|
||||
AND $0x7ffffffffffff, R4, R14
|
||||
|
||||
ADD R0>>51, R11, R11
|
||||
ADD R1>>51, R12, R12
|
||||
ADD R2>>51, R13, R13
|
||||
ADD R3>>51, R14, R14
|
||||
// R4>>51 * 19 + R10 -> R10
|
||||
LSR $51, R4, R21
|
||||
MOVD $19, R22
|
||||
MADD R22, R10, R21, R10
|
||||
|
||||
STP (R10, R11), 0(R20)
|
||||
STP (R12, R13), 16(R20)
|
||||
MOVD R14, 32(R20)
|
||||
|
||||
RET
|
||||
12
vendor/filippo.io/edwards25519/field/fe_arm64_noasm.go
generated
vendored
Normal file
12
vendor/filippo.io/edwards25519/field/fe_arm64_noasm.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright (c) 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !arm64 || !gc || purego
|
||||
// +build !arm64 !gc purego
|
||||
|
||||
package field
|
||||
|
||||
func (v *Element) carryPropagate() *Element {
|
||||
return v.carryPropagateGeneric()
|
||||
}
|
||||
50
vendor/filippo.io/edwards25519/field/fe_extra.go
generated
vendored
Normal file
50
vendor/filippo.io/edwards25519/field/fe_extra.go
generated
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package field
|
||||
|
||||
import "errors"
|
||||
|
||||
// This file contains additional functionality that is not included in the
|
||||
// upstream crypto/ed25519/edwards25519/field package.
|
||||
|
||||
// SetWideBytes sets v to x, where x is a 64-byte little-endian encoding, which
|
||||
// is reduced modulo the field order. If x is not of the right length,
|
||||
// SetWideBytes returns nil and an error, and the receiver is unchanged.
|
||||
//
|
||||
// SetWideBytes is not necessary to select a uniformly distributed value, and is
|
||||
// only provided for compatibility: SetBytes can be used instead as the chance
|
||||
// of bias is less than 2⁻²⁵⁰.
|
||||
func (v *Element) SetWideBytes(x []byte) (*Element, error) {
|
||||
if len(x) != 64 {
|
||||
return nil, errors.New("edwards25519: invalid SetWideBytes input size")
|
||||
}
|
||||
|
||||
// Split the 64 bytes into two elements, and extract the most significant
|
||||
// bit of each, which is ignored by SetBytes.
|
||||
lo, _ := new(Element).SetBytes(x[:32])
|
||||
loMSB := uint64(x[31] >> 7)
|
||||
hi, _ := new(Element).SetBytes(x[32:])
|
||||
hiMSB := uint64(x[63] >> 7)
|
||||
|
||||
// The output we want is
|
||||
//
|
||||
// v = lo + loMSB * 2²⁵⁵ + hi * 2²⁵⁶ + hiMSB * 2⁵¹¹
|
||||
//
|
||||
// which applying the reduction identity comes out to
|
||||
//
|
||||
// v = lo + loMSB * 19 + hi * 2 * 19 + hiMSB * 2 * 19²
|
||||
//
|
||||
// l0 will be the sum of a 52 bits value (lo.l0), plus a 5 bits value
|
||||
// (loMSB * 19), a 6 bits value (hi.l0 * 2 * 19), and a 10 bits value
|
||||
// (hiMSB * 2 * 19²), so it fits in a uint64.
|
||||
|
||||
v.l0 = lo.l0 + loMSB*19 + hi.l0*2*19 + hiMSB*2*19*19
|
||||
v.l1 = lo.l1 + hi.l1*2*19
|
||||
v.l2 = lo.l2 + hi.l2*2*19
|
||||
v.l3 = lo.l3 + hi.l3*2*19
|
||||
v.l4 = lo.l4 + hi.l4*2*19
|
||||
|
||||
return v.carryPropagate(), nil
|
||||
}
|
||||
266
vendor/filippo.io/edwards25519/field/fe_generic.go
generated
vendored
Normal file
266
vendor/filippo.io/edwards25519/field/fe_generic.go
generated
vendored
Normal file
@@ -0,0 +1,266 @@
|
||||
// Copyright (c) 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package field
|
||||
|
||||
import "math/bits"
|
||||
|
||||
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
|
||||
// bits.Mul64 and bits.Add64 intrinsics.
|
||||
type uint128 struct {
|
||||
lo, hi uint64
|
||||
}
|
||||
|
||||
// mul64 returns a * b.
|
||||
func mul64(a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(a, b)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
// addMul64 returns v + a * b.
|
||||
func addMul64(v uint128, a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(a, b)
|
||||
lo, c := bits.Add64(lo, v.lo, 0)
|
||||
hi, _ = bits.Add64(hi, v.hi, c)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
|
||||
func shiftRightBy51(a uint128) uint64 {
|
||||
return (a.hi << (64 - 51)) | (a.lo >> 51)
|
||||
}
|
||||
|
||||
func feMulGeneric(v, a, b *Element) {
|
||||
a0 := a.l0
|
||||
a1 := a.l1
|
||||
a2 := a.l2
|
||||
a3 := a.l3
|
||||
a4 := a.l4
|
||||
|
||||
b0 := b.l0
|
||||
b1 := b.l1
|
||||
b2 := b.l2
|
||||
b3 := b.l3
|
||||
b4 := b.l4
|
||||
|
||||
// Limb multiplication works like pen-and-paper columnar multiplication, but
|
||||
// with 51-bit limbs instead of digits.
|
||||
//
|
||||
// a4 a3 a2 a1 a0 x
|
||||
// b4 b3 b2 b1 b0 =
|
||||
// ------------------------
|
||||
// a4b0 a3b0 a2b0 a1b0 a0b0 +
|
||||
// a4b1 a3b1 a2b1 a1b1 a0b1 +
|
||||
// a4b2 a3b2 a2b2 a1b2 a0b2 +
|
||||
// a4b3 a3b3 a2b3 a1b3 a0b3 +
|
||||
// a4b4 a3b4 a2b4 a1b4 a0b4 =
|
||||
// ----------------------------------------------
|
||||
// r8 r7 r6 r5 r4 r3 r2 r1 r0
|
||||
//
|
||||
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
|
||||
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
|
||||
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
|
||||
//
|
||||
// Reduction can be carried out simultaneously to multiplication. For
|
||||
// example, we do not compute r5: whenever the result of a multiplication
|
||||
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
|
||||
//
|
||||
// a4b0 a3b0 a2b0 a1b0 a0b0 +
|
||||
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
|
||||
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
|
||||
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
|
||||
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
|
||||
// --------------------------------------
|
||||
// r4 r3 r2 r1 r0
|
||||
//
|
||||
// Finally we add up the columns into wide, overlapping limbs.
|
||||
|
||||
a1_19 := a1 * 19
|
||||
a2_19 := a2 * 19
|
||||
a3_19 := a3 * 19
|
||||
a4_19 := a4 * 19
|
||||
|
||||
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
|
||||
r0 := mul64(a0, b0)
|
||||
r0 = addMul64(r0, a1_19, b4)
|
||||
r0 = addMul64(r0, a2_19, b3)
|
||||
r0 = addMul64(r0, a3_19, b2)
|
||||
r0 = addMul64(r0, a4_19, b1)
|
||||
|
||||
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
|
||||
r1 := mul64(a0, b1)
|
||||
r1 = addMul64(r1, a1, b0)
|
||||
r1 = addMul64(r1, a2_19, b4)
|
||||
r1 = addMul64(r1, a3_19, b3)
|
||||
r1 = addMul64(r1, a4_19, b2)
|
||||
|
||||
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
|
||||
r2 := mul64(a0, b2)
|
||||
r2 = addMul64(r2, a1, b1)
|
||||
r2 = addMul64(r2, a2, b0)
|
||||
r2 = addMul64(r2, a3_19, b4)
|
||||
r2 = addMul64(r2, a4_19, b3)
|
||||
|
||||
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
|
||||
r3 := mul64(a0, b3)
|
||||
r3 = addMul64(r3, a1, b2)
|
||||
r3 = addMul64(r3, a2, b1)
|
||||
r3 = addMul64(r3, a3, b0)
|
||||
r3 = addMul64(r3, a4_19, b4)
|
||||
|
||||
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
|
||||
r4 := mul64(a0, b4)
|
||||
r4 = addMul64(r4, a1, b3)
|
||||
r4 = addMul64(r4, a2, b2)
|
||||
r4 = addMul64(r4, a3, b1)
|
||||
r4 = addMul64(r4, a4, b0)
|
||||
|
||||
// After the multiplication, we need to reduce (carry) the five coefficients
|
||||
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
|
||||
// to respect the Element invariant.
|
||||
//
|
||||
// Overall, the reduction works the same as carryPropagate, except with
|
||||
// wider inputs: we take the carry for each coefficient by shifting it right
|
||||
// by 51, and add it to the limb above it. The top carry is multiplied by 19
|
||||
// according to the reduction identity and added to the lowest limb.
|
||||
//
|
||||
// The largest coefficient (r0) will be at most 111 bits, which guarantees
|
||||
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
|
||||
//
|
||||
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
|
||||
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
|
||||
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
|
||||
// r0 < 2⁷ × 2⁵² × 2⁵²
|
||||
// r0 < 2¹¹¹
|
||||
//
|
||||
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
|
||||
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
|
||||
// allows us to easily apply the reduction identity.
|
||||
//
|
||||
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
|
||||
// r4 < 5 × 2⁵² × 2⁵²
|
||||
// r4 < 2¹⁰⁷
|
||||
//
|
||||
|
||||
c0 := shiftRightBy51(r0)
|
||||
c1 := shiftRightBy51(r1)
|
||||
c2 := shiftRightBy51(r2)
|
||||
c3 := shiftRightBy51(r3)
|
||||
c4 := shiftRightBy51(r4)
|
||||
|
||||
rr0 := r0.lo&maskLow51Bits + c4*19
|
||||
rr1 := r1.lo&maskLow51Bits + c0
|
||||
rr2 := r2.lo&maskLow51Bits + c1
|
||||
rr3 := r3.lo&maskLow51Bits + c2
|
||||
rr4 := r4.lo&maskLow51Bits + c3
|
||||
|
||||
// Now all coefficients fit into 64-bit registers but are still too large to
|
||||
// be passed around as an Element. We therefore do one last carry chain,
|
||||
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
|
||||
*v = Element{rr0, rr1, rr2, rr3, rr4}
|
||||
v.carryPropagate()
|
||||
}
|
||||
|
||||
func feSquareGeneric(v, a *Element) {
|
||||
l0 := a.l0
|
||||
l1 := a.l1
|
||||
l2 := a.l2
|
||||
l3 := a.l3
|
||||
l4 := a.l4
|
||||
|
||||
// Squaring works precisely like multiplication above, but thanks to its
|
||||
// symmetry we get to group a few terms together.
|
||||
//
|
||||
// l4 l3 l2 l1 l0 x
|
||||
// l4 l3 l2 l1 l0 =
|
||||
// ------------------------
|
||||
// l4l0 l3l0 l2l0 l1l0 l0l0 +
|
||||
// l4l1 l3l1 l2l1 l1l1 l0l1 +
|
||||
// l4l2 l3l2 l2l2 l1l2 l0l2 +
|
||||
// l4l3 l3l3 l2l3 l1l3 l0l3 +
|
||||
// l4l4 l3l4 l2l4 l1l4 l0l4 =
|
||||
// ----------------------------------------------
|
||||
// r8 r7 r6 r5 r4 r3 r2 r1 r0
|
||||
//
|
||||
// l4l0 l3l0 l2l0 l1l0 l0l0 +
|
||||
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
|
||||
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
|
||||
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
|
||||
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
|
||||
// --------------------------------------
|
||||
// r4 r3 r2 r1 r0
|
||||
//
|
||||
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
|
||||
// only three Mul64 and four Add64, instead of five and eight.
|
||||
|
||||
l0_2 := l0 * 2
|
||||
l1_2 := l1 * 2
|
||||
|
||||
l1_38 := l1 * 38
|
||||
l2_38 := l2 * 38
|
||||
l3_38 := l3 * 38
|
||||
|
||||
l3_19 := l3 * 19
|
||||
l4_19 := l4 * 19
|
||||
|
||||
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
|
||||
r0 := mul64(l0, l0)
|
||||
r0 = addMul64(r0, l1_38, l4)
|
||||
r0 = addMul64(r0, l2_38, l3)
|
||||
|
||||
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
|
||||
r1 := mul64(l0_2, l1)
|
||||
r1 = addMul64(r1, l2_38, l4)
|
||||
r1 = addMul64(r1, l3_19, l3)
|
||||
|
||||
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
|
||||
r2 := mul64(l0_2, l2)
|
||||
r2 = addMul64(r2, l1, l1)
|
||||
r2 = addMul64(r2, l3_38, l4)
|
||||
|
||||
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
|
||||
r3 := mul64(l0_2, l3)
|
||||
r3 = addMul64(r3, l1_2, l2)
|
||||
r3 = addMul64(r3, l4_19, l4)
|
||||
|
||||
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
|
||||
r4 := mul64(l0_2, l4)
|
||||
r4 = addMul64(r4, l1_2, l3)
|
||||
r4 = addMul64(r4, l2, l2)
|
||||
|
||||
c0 := shiftRightBy51(r0)
|
||||
c1 := shiftRightBy51(r1)
|
||||
c2 := shiftRightBy51(r2)
|
||||
c3 := shiftRightBy51(r3)
|
||||
c4 := shiftRightBy51(r4)
|
||||
|
||||
rr0 := r0.lo&maskLow51Bits + c4*19
|
||||
rr1 := r1.lo&maskLow51Bits + c0
|
||||
rr2 := r2.lo&maskLow51Bits + c1
|
||||
rr3 := r3.lo&maskLow51Bits + c2
|
||||
rr4 := r4.lo&maskLow51Bits + c3
|
||||
|
||||
*v = Element{rr0, rr1, rr2, rr3, rr4}
|
||||
v.carryPropagate()
|
||||
}
|
||||
|
||||
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
|
||||
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
|
||||
func (v *Element) carryPropagateGeneric() *Element {
|
||||
c0 := v.l0 >> 51
|
||||
c1 := v.l1 >> 51
|
||||
c2 := v.l2 >> 51
|
||||
c3 := v.l3 >> 51
|
||||
c4 := v.l4 >> 51
|
||||
|
||||
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
|
||||
// the final l0 will be at most 52 bits. Similarly for the rest.
|
||||
v.l0 = v.l0&maskLow51Bits + c4*19
|
||||
v.l1 = v.l1&maskLow51Bits + c0
|
||||
v.l2 = v.l2&maskLow51Bits + c1
|
||||
v.l3 = v.l3&maskLow51Bits + c2
|
||||
v.l4 = v.l4&maskLow51Bits + c3
|
||||
|
||||
return v
|
||||
}
|
||||
343
vendor/filippo.io/edwards25519/scalar.go
generated
vendored
Normal file
343
vendor/filippo.io/edwards25519/scalar.go
generated
vendored
Normal file
@@ -0,0 +1,343 @@
|
||||
// Copyright (c) 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package edwards25519
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// A Scalar is an integer modulo
|
||||
//
|
||||
// l = 2^252 + 27742317777372353535851937790883648493
|
||||
//
|
||||
// which is the prime order of the edwards25519 group.
|
||||
//
|
||||
// This type works similarly to math/big.Int, and all arguments and
|
||||
// receivers are allowed to alias.
|
||||
//
|
||||
// The zero value is a valid zero element.
|
||||
type Scalar struct {
|
||||
// s is the scalar in the Montgomery domain, in the format of the
|
||||
// fiat-crypto implementation.
|
||||
s fiatScalarMontgomeryDomainFieldElement
|
||||
}
|
||||
|
||||
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
|
||||
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
|
||||
// from a formally verified model.
|
||||
//
|
||||
// fiat-crypto code comes under the following license.
|
||||
//
|
||||
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
|
||||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
|
||||
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
|
||||
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
//
|
||||
|
||||
// NewScalar returns a new zero Scalar.
|
||||
func NewScalar() *Scalar {
|
||||
return &Scalar{}
|
||||
}
|
||||
|
||||
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
|
||||
// using Multiply and then Add.
|
||||
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
|
||||
// Make a copy of z in case it aliases s.
|
||||
zCopy := new(Scalar).Set(z)
|
||||
return s.Multiply(x, y).Add(s, zCopy)
|
||||
}
|
||||
|
||||
// Add sets s = x + y mod l, and returns s.
|
||||
func (s *Scalar) Add(x, y *Scalar) *Scalar {
|
||||
// s = 1 * x + y mod l
|
||||
fiatScalarAdd(&s.s, &x.s, &y.s)
|
||||
return s
|
||||
}
|
||||
|
||||
// Subtract sets s = x - y mod l, and returns s.
|
||||
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
|
||||
// s = -1 * y + x mod l
|
||||
fiatScalarSub(&s.s, &x.s, &y.s)
|
||||
return s
|
||||
}
|
||||
|
||||
// Negate sets s = -x mod l, and returns s.
|
||||
func (s *Scalar) Negate(x *Scalar) *Scalar {
|
||||
// s = -1 * x + 0 mod l
|
||||
fiatScalarOpp(&s.s, &x.s)
|
||||
return s
|
||||
}
|
||||
|
||||
// Multiply sets s = x * y mod l, and returns s.
|
||||
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
|
||||
// s = x * y + 0 mod l
|
||||
fiatScalarMul(&s.s, &x.s, &y.s)
|
||||
return s
|
||||
}
|
||||
|
||||
// Set sets s = x, and returns s.
|
||||
func (s *Scalar) Set(x *Scalar) *Scalar {
|
||||
*s = *x
|
||||
return s
|
||||
}
|
||||
|
||||
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
|
||||
// If x is not of the right length, SetUniformBytes returns nil and an error,
|
||||
// and the receiver is unchanged.
|
||||
//
|
||||
// SetUniformBytes can be used to set s to a uniformly distributed value given
|
||||
// 64 uniformly distributed random bytes.
|
||||
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
|
||||
if len(x) != 64 {
|
||||
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
|
||||
}
|
||||
|
||||
// We have a value x of 512 bits, but our fiatScalarFromBytes function
|
||||
// expects an input lower than l, which is a little over 252 bits.
|
||||
//
|
||||
// Instead of writing a reduction function that operates on wider inputs, we
|
||||
// can interpret x as the sum of three shorter values a, b, and c.
|
||||
//
|
||||
// x = a + b * 2^168 + c * 2^336 mod l
|
||||
//
|
||||
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
|
||||
// with two multiplications and two additions.
|
||||
|
||||
s.setShortBytes(x[:21])
|
||||
t := new(Scalar).setShortBytes(x[21:42])
|
||||
s.Add(s, t.Multiply(t, scalarTwo168))
|
||||
t.setShortBytes(x[42:])
|
||||
s.Add(s, t.Multiply(t, scalarTwo336))
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
|
||||
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
|
||||
// in the 2^256 Montgomery domain.
|
||||
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
|
||||
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
|
||||
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
|
||||
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
|
||||
|
||||
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
|
||||
// than 32 bytes.
|
||||
func (s *Scalar) setShortBytes(x []byte) *Scalar {
|
||||
if len(x) >= 32 {
|
||||
panic("edwards25519: internal error: setShortBytes called with a long string")
|
||||
}
|
||||
var buf [32]byte
|
||||
copy(buf[:], x)
|
||||
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
|
||||
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
|
||||
return s
|
||||
}
|
||||
|
||||
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
|
||||
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
|
||||
// returns nil and an error, and the receiver is unchanged.
|
||||
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
|
||||
if len(x) != 32 {
|
||||
return nil, errors.New("invalid scalar length")
|
||||
}
|
||||
if !isReduced(x) {
|
||||
return nil, errors.New("invalid scalar encoding")
|
||||
}
|
||||
|
||||
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
|
||||
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// scalarMinusOneBytes is l - 1 in little endian.
|
||||
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
|
||||
|
||||
// isReduced returns whether the given scalar in 32-byte little endian encoded
|
||||
// form is reduced modulo l.
|
||||
func isReduced(s []byte) bool {
|
||||
if len(s) != 32 {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
switch {
|
||||
case s[i] > scalarMinusOneBytes[i]:
|
||||
return false
|
||||
case s[i] < scalarMinusOneBytes[i]:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
|
||||
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
|
||||
// must be 32 bytes, and it is not modified. If x is not of the right length,
|
||||
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
|
||||
//
|
||||
// Note that since Scalar values are always reduced modulo the prime order of
|
||||
// the curve, the resulting value will not preserve any of the cofactor-clearing
|
||||
// properties that clamping is meant to provide. It will however work as
|
||||
// expected as long as it is applied to points on the prime order subgroup, like
|
||||
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
|
||||
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
|
||||
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
|
||||
// The description above omits the purpose of the high bits of the clamping
|
||||
// for brevity, but those are also lost to reductions, and are also
|
||||
// irrelevant to edwards25519 as they protect against a specific
|
||||
// implementation bug that was once observed in a generic Montgomery ladder.
|
||||
if len(x) != 32 {
|
||||
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
|
||||
}
|
||||
|
||||
// We need to use the wide reduction from SetUniformBytes, since clamping
|
||||
// sets the 2^254 bit, making the value higher than the order.
|
||||
var wideBytes [64]byte
|
||||
copy(wideBytes[:], x[:])
|
||||
wideBytes[0] &= 248
|
||||
wideBytes[31] &= 63
|
||||
wideBytes[31] |= 64
|
||||
return s.SetUniformBytes(wideBytes[:])
|
||||
}
|
||||
|
||||
// Bytes returns the canonical 32-byte little-endian encoding of s.
|
||||
func (s *Scalar) Bytes() []byte {
|
||||
// This function is outlined to make the allocations inline in the caller
|
||||
// rather than happen on the heap.
|
||||
var encoded [32]byte
|
||||
return s.bytes(&encoded)
|
||||
}
|
||||
|
||||
func (s *Scalar) bytes(out *[32]byte) []byte {
|
||||
var ss fiatScalarNonMontgomeryDomainFieldElement
|
||||
fiatScalarFromMontgomery(&ss, &s.s)
|
||||
fiatScalarToBytes(out, (*[4]uint64)(&ss))
|
||||
return out[:]
|
||||
}
|
||||
|
||||
// Equal returns 1 if s and t are equal, and 0 otherwise.
|
||||
func (s *Scalar) Equal(t *Scalar) int {
|
||||
var diff fiatScalarMontgomeryDomainFieldElement
|
||||
fiatScalarSub(&diff, &s.s, &t.s)
|
||||
var nonzero uint64
|
||||
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
|
||||
nonzero |= nonzero >> 32
|
||||
nonzero |= nonzero >> 16
|
||||
nonzero |= nonzero >> 8
|
||||
nonzero |= nonzero >> 4
|
||||
nonzero |= nonzero >> 2
|
||||
nonzero |= nonzero >> 1
|
||||
return int(^nonzero) & 1
|
||||
}
|
||||
|
||||
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
|
||||
//
|
||||
// w must be between 2 and 8, or nonAdjacentForm will panic.
|
||||
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
|
||||
// This implementation is adapted from the one
|
||||
// in curve25519-dalek and is documented there:
|
||||
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
|
||||
b := s.Bytes()
|
||||
if b[31] > 127 {
|
||||
panic("scalar has high bit set illegally")
|
||||
}
|
||||
if w < 2 {
|
||||
panic("w must be at least 2 by the definition of NAF")
|
||||
} else if w > 8 {
|
||||
panic("NAF digits must fit in int8")
|
||||
}
|
||||
|
||||
var naf [256]int8
|
||||
var digits [5]uint64
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
|
||||
}
|
||||
|
||||
width := uint64(1 << w)
|
||||
windowMask := uint64(width - 1)
|
||||
|
||||
pos := uint(0)
|
||||
carry := uint64(0)
|
||||
for pos < 256 {
|
||||
indexU64 := pos / 64
|
||||
indexBit := pos % 64
|
||||
var bitBuf uint64
|
||||
if indexBit < 64-w {
|
||||
// This window's bits are contained in a single u64
|
||||
bitBuf = digits[indexU64] >> indexBit
|
||||
} else {
|
||||
// Combine the current 64 bits with bits from the next 64
|
||||
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
|
||||
}
|
||||
|
||||
// Add carry into the current window
|
||||
window := carry + (bitBuf & windowMask)
|
||||
|
||||
if window&1 == 0 {
|
||||
// If the window value is even, preserve the carry and continue.
|
||||
// Why is the carry preserved?
|
||||
// If carry == 0 and window & 1 == 0,
|
||||
// then the next carry should be 0
|
||||
// If carry == 1 and window & 1 == 0,
|
||||
// then bit_buf & 1 == 1 so the next carry should be 1
|
||||
pos += 1
|
||||
continue
|
||||
}
|
||||
|
||||
if window < width/2 {
|
||||
carry = 0
|
||||
naf[pos] = int8(window)
|
||||
} else {
|
||||
carry = 1
|
||||
naf[pos] = int8(window) - int8(width)
|
||||
}
|
||||
|
||||
pos += w
|
||||
}
|
||||
return naf
|
||||
}
|
||||
|
||||
func (s *Scalar) signedRadix16() [64]int8 {
|
||||
b := s.Bytes()
|
||||
if b[31] > 127 {
|
||||
panic("scalar has high bit set illegally")
|
||||
}
|
||||
|
||||
var digits [64]int8
|
||||
|
||||
// Compute unsigned radix-16 digits:
|
||||
for i := 0; i < 32; i++ {
|
||||
digits[2*i] = int8(b[i] & 15)
|
||||
digits[2*i+1] = int8((b[i] >> 4) & 15)
|
||||
}
|
||||
|
||||
// Recenter coefficients:
|
||||
for i := 0; i < 63; i++ {
|
||||
carry := (digits[i] + 8) >> 4
|
||||
digits[i] -= carry << 4
|
||||
digits[i+1] += carry
|
||||
}
|
||||
|
||||
return digits
|
||||
}
|
||||
1147
vendor/filippo.io/edwards25519/scalar_fiat.go
generated
vendored
Normal file
1147
vendor/filippo.io/edwards25519/scalar_fiat.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
214
vendor/filippo.io/edwards25519/scalarmult.go
generated
vendored
Normal file
214
vendor/filippo.io/edwards25519/scalarmult.go
generated
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
// Copyright (c) 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package edwards25519
|
||||
|
||||
import "sync"
|
||||
|
||||
// basepointTable is a set of 32 affineLookupTables, where table i is generated
|
||||
// from 256i * basepoint. It is precomputed the first time it's used.
|
||||
func basepointTable() *[32]affineLookupTable {
|
||||
basepointTablePrecomp.initOnce.Do(func() {
|
||||
p := NewGeneratorPoint()
|
||||
for i := 0; i < 32; i++ {
|
||||
basepointTablePrecomp.table[i].FromP3(p)
|
||||
for j := 0; j < 8; j++ {
|
||||
p.Add(p, p)
|
||||
}
|
||||
}
|
||||
})
|
||||
return &basepointTablePrecomp.table
|
||||
}
|
||||
|
||||
var basepointTablePrecomp struct {
|
||||
table [32]affineLookupTable
|
||||
initOnce sync.Once
|
||||
}
|
||||
|
||||
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
|
||||
// returns v.
|
||||
//
|
||||
// The scalar multiplication is done in constant time.
|
||||
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
|
||||
basepointTable := basepointTable()
|
||||
|
||||
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
|
||||
// as described in the Ed25519 paper
|
||||
//
|
||||
// Group even and odd coefficients
|
||||
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
|
||||
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
|
||||
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
|
||||
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
|
||||
//
|
||||
// We use a lookup table for each i to get x_i*16^(2*i)*B
|
||||
// and do four doublings to multiply by 16.
|
||||
digits := x.signedRadix16()
|
||||
|
||||
multiple := &affineCached{}
|
||||
tmp1 := &projP1xP1{}
|
||||
tmp2 := &projP2{}
|
||||
|
||||
// Accumulate the odd components first
|
||||
v.Set(NewIdentityPoint())
|
||||
for i := 1; i < 64; i += 2 {
|
||||
basepointTable[i/2].SelectInto(multiple, digits[i])
|
||||
tmp1.AddAffine(v, multiple)
|
||||
v.fromP1xP1(tmp1)
|
||||
}
|
||||
|
||||
// Multiply by 16
|
||||
tmp2.FromP3(v) // tmp2 = v in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
|
||||
v.fromP1xP1(tmp1) // now v = 16*(odd components)
|
||||
|
||||
// Accumulate the even components
|
||||
for i := 0; i < 64; i += 2 {
|
||||
basepointTable[i/2].SelectInto(multiple, digits[i])
|
||||
tmp1.AddAffine(v, multiple)
|
||||
v.fromP1xP1(tmp1)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// ScalarMult sets v = x * q, and returns v.
|
||||
//
|
||||
// The scalar multiplication is done in constant time.
|
||||
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
|
||||
checkInitialized(q)
|
||||
|
||||
var table projLookupTable
|
||||
table.FromP3(q)
|
||||
|
||||
// Write x = sum(x_i * 16^i)
|
||||
// so x*Q = sum( Q*x_i*16^i )
|
||||
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
|
||||
// <------compute inside out---------
|
||||
//
|
||||
// We use the lookup table to get the x_i*Q values
|
||||
// and do four doublings to compute 16*Q
|
||||
digits := x.signedRadix16()
|
||||
|
||||
// Unwrap first loop iteration to save computing 16*identity
|
||||
multiple := &projCached{}
|
||||
tmp1 := &projP1xP1{}
|
||||
tmp2 := &projP2{}
|
||||
table.SelectInto(multiple, digits[63])
|
||||
|
||||
v.Set(NewIdentityPoint())
|
||||
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
|
||||
for i := 62; i >= 0; i-- {
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
|
||||
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
|
||||
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
|
||||
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
|
||||
table.SelectInto(multiple, digits[i])
|
||||
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
|
||||
}
|
||||
v.fromP1xP1(tmp1)
|
||||
return v
|
||||
}
|
||||
|
||||
// basepointNafTable is the nafLookupTable8 for the basepoint.
|
||||
// It is precomputed the first time it's used.
|
||||
func basepointNafTable() *nafLookupTable8 {
|
||||
basepointNafTablePrecomp.initOnce.Do(func() {
|
||||
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
|
||||
})
|
||||
return &basepointNafTablePrecomp.table
|
||||
}
|
||||
|
||||
var basepointNafTablePrecomp struct {
|
||||
table nafLookupTable8
|
||||
initOnce sync.Once
|
||||
}
|
||||
|
||||
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
|
||||
// generator, and returns v.
|
||||
//
|
||||
// Execution time depends on the inputs.
|
||||
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
|
||||
checkInitialized(A)
|
||||
|
||||
// Similarly to the single variable-base approach, we compute
|
||||
// digits and use them with a lookup table. However, because
|
||||
// we are allowed to do variable-time operations, we don't
|
||||
// need constant-time lookups or constant-time digit
|
||||
// computations.
|
||||
//
|
||||
// So we use a non-adjacent form of some width w instead of
|
||||
// radix 16. This is like a binary representation (one digit
|
||||
// for each binary place) but we allow the digits to grow in
|
||||
// magnitude up to 2^{w-1} so that the nonzero digits are as
|
||||
// sparse as possible. Intuitively, this "condenses" the
|
||||
// "mass" of the scalar onto sparse coefficients (meaning
|
||||
// fewer additions).
|
||||
|
||||
basepointNafTable := basepointNafTable()
|
||||
var aTable nafLookupTable5
|
||||
aTable.FromP3(A)
|
||||
// Because the basepoint is fixed, we can use a wider NAF
|
||||
// corresponding to a bigger table.
|
||||
aNaf := a.nonAdjacentForm(5)
|
||||
bNaf := b.nonAdjacentForm(8)
|
||||
|
||||
// Find the first nonzero coefficient.
|
||||
i := 255
|
||||
for j := i; j >= 0; j-- {
|
||||
if aNaf[j] != 0 || bNaf[j] != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
multA := &projCached{}
|
||||
multB := &affineCached{}
|
||||
tmp1 := &projP1xP1{}
|
||||
tmp2 := &projP2{}
|
||||
tmp2.Zero()
|
||||
|
||||
// Move from high to low bits, doubling the accumulator
|
||||
// at each iteration and checking whether there is a nonzero
|
||||
// coefficient to look up a multiple of.
|
||||
for ; i >= 0; i-- {
|
||||
tmp1.Double(tmp2)
|
||||
|
||||
// Only update v if we have a nonzero coeff to add in.
|
||||
if aNaf[i] > 0 {
|
||||
v.fromP1xP1(tmp1)
|
||||
aTable.SelectInto(multA, aNaf[i])
|
||||
tmp1.Add(v, multA)
|
||||
} else if aNaf[i] < 0 {
|
||||
v.fromP1xP1(tmp1)
|
||||
aTable.SelectInto(multA, -aNaf[i])
|
||||
tmp1.Sub(v, multA)
|
||||
}
|
||||
|
||||
if bNaf[i] > 0 {
|
||||
v.fromP1xP1(tmp1)
|
||||
basepointNafTable.SelectInto(multB, bNaf[i])
|
||||
tmp1.AddAffine(v, multB)
|
||||
} else if bNaf[i] < 0 {
|
||||
v.fromP1xP1(tmp1)
|
||||
basepointNafTable.SelectInto(multB, -bNaf[i])
|
||||
tmp1.SubAffine(v, multB)
|
||||
}
|
||||
|
||||
tmp2.FromP1xP1(tmp1)
|
||||
}
|
||||
|
||||
v.fromP2(tmp2)
|
||||
return v
|
||||
}
|
||||
129
vendor/filippo.io/edwards25519/tables.go
generated
vendored
Normal file
129
vendor/filippo.io/edwards25519/tables.go
generated
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package edwards25519
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
)
|
||||
|
||||
// A dynamic lookup table for variable-base, constant-time scalar muls.
|
||||
type projLookupTable struct {
|
||||
points [8]projCached
|
||||
}
|
||||
|
||||
// A precomputed lookup table for fixed-base, constant-time scalar muls.
|
||||
type affineLookupTable struct {
|
||||
points [8]affineCached
|
||||
}
|
||||
|
||||
// A dynamic lookup table for variable-base, variable-time scalar muls.
|
||||
type nafLookupTable5 struct {
|
||||
points [8]projCached
|
||||
}
|
||||
|
||||
// A precomputed lookup table for fixed-base, variable-time scalar muls.
|
||||
type nafLookupTable8 struct {
|
||||
points [64]affineCached
|
||||
}
|
||||
|
||||
// Constructors.
|
||||
|
||||
// Builds a lookup table at runtime. Fast.
|
||||
func (v *projLookupTable) FromP3(q *Point) {
|
||||
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
|
||||
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
|
||||
v.points[0].FromP3(q)
|
||||
tmpP3 := Point{}
|
||||
tmpP1xP1 := projP1xP1{}
|
||||
for i := 0; i < 7; i++ {
|
||||
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
|
||||
// This is needlessly complicated because the API has explicit
|
||||
// receivers instead of creating stack objects and relying on RVO
|
||||
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
|
||||
}
|
||||
}
|
||||
|
||||
// This is not optimised for speed; fixed-base tables should be precomputed.
|
||||
func (v *affineLookupTable) FromP3(q *Point) {
|
||||
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
|
||||
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
|
||||
v.points[0].FromP3(q)
|
||||
tmpP3 := Point{}
|
||||
tmpP1xP1 := projP1xP1{}
|
||||
for i := 0; i < 7; i++ {
|
||||
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
|
||||
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
|
||||
}
|
||||
}
|
||||
|
||||
// Builds a lookup table at runtime. Fast.
|
||||
func (v *nafLookupTable5) FromP3(q *Point) {
|
||||
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
|
||||
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
|
||||
v.points[0].FromP3(q)
|
||||
q2 := Point{}
|
||||
q2.Add(q, q)
|
||||
tmpP3 := Point{}
|
||||
tmpP1xP1 := projP1xP1{}
|
||||
for i := 0; i < 7; i++ {
|
||||
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
|
||||
}
|
||||
}
|
||||
|
||||
// This is not optimised for speed; fixed-base tables should be precomputed.
|
||||
func (v *nafLookupTable8) FromP3(q *Point) {
|
||||
v.points[0].FromP3(q)
|
||||
q2 := Point{}
|
||||
q2.Add(q, q)
|
||||
tmpP3 := Point{}
|
||||
tmpP1xP1 := projP1xP1{}
|
||||
for i := 0; i < 63; i++ {
|
||||
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
|
||||
}
|
||||
}
|
||||
|
||||
// Selectors.
|
||||
|
||||
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
|
||||
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
|
||||
// Compute xabs = |x|
|
||||
xmask := x >> 7
|
||||
xabs := uint8((x + xmask) ^ xmask)
|
||||
|
||||
dest.Zero()
|
||||
for j := 1; j <= 8; j++ {
|
||||
// Set dest = j*Q if |x| = j
|
||||
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
|
||||
dest.Select(&v.points[j-1], dest, cond)
|
||||
}
|
||||
// Now dest = |x|*Q, conditionally negate to get x*Q
|
||||
dest.CondNeg(int(xmask & 1))
|
||||
}
|
||||
|
||||
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
|
||||
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
|
||||
// Compute xabs = |x|
|
||||
xmask := x >> 7
|
||||
xabs := uint8((x + xmask) ^ xmask)
|
||||
|
||||
dest.Zero()
|
||||
for j := 1; j <= 8; j++ {
|
||||
// Set dest = j*Q if |x| = j
|
||||
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
|
||||
dest.Select(&v.points[j-1], dest, cond)
|
||||
}
|
||||
// Now dest = |x|*Q, conditionally negate to get x*Q
|
||||
dest.CondNeg(int(xmask & 1))
|
||||
}
|
||||
|
||||
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
|
||||
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
|
||||
*dest = v.points[x/2]
|
||||
}
|
||||
|
||||
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
|
||||
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
|
||||
*dest = v.points[x/2]
|
||||
}
|
||||
3
vendor/github.com/aws/aws-sdk-go/NOTICE.txt
generated
vendored
Normal file
3
vendor/github.com/aws/aws-sdk-go/NOTICE.txt
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
AWS SDK for Go
|
||||
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Copyright 2014-2015 Stripe, Inc.
|
||||
93
vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go
generated
vendored
Normal file
93
vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go
generated
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
// Package arn provides a parser for interacting with Amazon Resource Names.
|
||||
package arn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
arnDelimiter = ":"
|
||||
arnSections = 6
|
||||
arnPrefix = "arn:"
|
||||
|
||||
// zero-indexed
|
||||
sectionPartition = 1
|
||||
sectionService = 2
|
||||
sectionRegion = 3
|
||||
sectionAccountID = 4
|
||||
sectionResource = 5
|
||||
|
||||
// errors
|
||||
invalidPrefix = "arn: invalid prefix"
|
||||
invalidSections = "arn: not enough sections"
|
||||
)
|
||||
|
||||
// ARN captures the individual fields of an Amazon Resource Name.
|
||||
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
|
||||
type ARN struct {
|
||||
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
|
||||
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
|
||||
// (Beijing) region is "aws-cn".
|
||||
Partition string
|
||||
|
||||
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
|
||||
// namespaces, see
|
||||
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
|
||||
Service string
|
||||
|
||||
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
|
||||
// component might be omitted.
|
||||
Region string
|
||||
|
||||
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
|
||||
// ARNs for some resources don't require an account number, so this component might be omitted.
|
||||
AccountID string
|
||||
|
||||
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
|
||||
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
|
||||
// resource name itself. Some services allows paths for resource names, as described in
|
||||
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
|
||||
Resource string
|
||||
}
|
||||
|
||||
// Parse parses an ARN into its constituent parts.
|
||||
//
|
||||
// Some example ARNs:
|
||||
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
|
||||
// arn:aws:iam::123456789012:user/David
|
||||
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
|
||||
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
|
||||
func Parse(arn string) (ARN, error) {
|
||||
if !strings.HasPrefix(arn, arnPrefix) {
|
||||
return ARN{}, errors.New(invalidPrefix)
|
||||
}
|
||||
sections := strings.SplitN(arn, arnDelimiter, arnSections)
|
||||
if len(sections) != arnSections {
|
||||
return ARN{}, errors.New(invalidSections)
|
||||
}
|
||||
return ARN{
|
||||
Partition: sections[sectionPartition],
|
||||
Service: sections[sectionService],
|
||||
Region: sections[sectionRegion],
|
||||
AccountID: sections[sectionAccountID],
|
||||
Resource: sections[sectionResource],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsARN returns whether the given string is an ARN by looking for
|
||||
// whether the string starts with "arn:" and contains the correct number
|
||||
// of sections delimited by colons(:).
|
||||
func IsARN(arn string) bool {
|
||||
return strings.HasPrefix(arn, arnPrefix) && strings.Count(arn, ":") >= arnSections-1
|
||||
}
|
||||
|
||||
// String returns the canonical representation of the ARN
|
||||
func (arn ARN) String() string {
|
||||
return arnPrefix +
|
||||
arn.Partition + arnDelimiter +
|
||||
arn.Service + arnDelimiter +
|
||||
arn.Region + arnDelimiter +
|
||||
arn.AccountID + arnDelimiter +
|
||||
arn.Resource
|
||||
}
|
||||
164
vendor/github.com/aws/aws-sdk-go/aws/awserr/error.go
generated
vendored
Normal file
164
vendor/github.com/aws/aws-sdk-go/aws/awserr/error.go
generated
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
// Package awserr represents API error interface accessors for the SDK.
|
||||
package awserr
|
||||
|
||||
// An Error wraps lower level errors with code, message and an original error.
|
||||
// The underlying concrete error type may also satisfy other interfaces which
|
||||
// can be to used to obtain more specific information about the error.
|
||||
//
|
||||
// Calling Error() or String() will always include the full information about
|
||||
// an error based on its underlying type.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// output, err := s3manage.Upload(svc, input, opts)
|
||||
// if err != nil {
|
||||
// if awsErr, ok := err.(awserr.Error); ok {
|
||||
// // Get error details
|
||||
// log.Println("Error:", awsErr.Code(), awsErr.Message())
|
||||
//
|
||||
// // Prints out full error message, including original error if there was one.
|
||||
// log.Println("Error:", awsErr.Error())
|
||||
//
|
||||
// // Get original error
|
||||
// if origErr := awsErr.OrigErr(); origErr != nil {
|
||||
// // operate on original error.
|
||||
// }
|
||||
// } else {
|
||||
// fmt.Println(err.Error())
|
||||
// }
|
||||
// }
|
||||
//
|
||||
type Error interface {
|
||||
// Satisfy the generic error interface.
|
||||
error
|
||||
|
||||
// Returns the short phrase depicting the classification of the error.
|
||||
Code() string
|
||||
|
||||
// Returns the error details message.
|
||||
Message() string
|
||||
|
||||
// Returns the original error if one was set. Nil is returned if not set.
|
||||
OrigErr() error
|
||||
}
|
||||
|
||||
// BatchError is a batch of errors which also wraps lower level errors with
|
||||
// code, message, and original errors. Calling Error() will include all errors
|
||||
// that occurred in the batch.
|
||||
//
|
||||
// Deprecated: Replaced with BatchedErrors. Only defined for backwards
|
||||
// compatibility.
|
||||
type BatchError interface {
|
||||
// Satisfy the generic error interface.
|
||||
error
|
||||
|
||||
// Returns the short phrase depicting the classification of the error.
|
||||
Code() string
|
||||
|
||||
// Returns the error details message.
|
||||
Message() string
|
||||
|
||||
// Returns the original error if one was set. Nil is returned if not set.
|
||||
OrigErrs() []error
|
||||
}
|
||||
|
||||
// BatchedErrors is a batch of errors which also wraps lower level errors with
|
||||
// code, message, and original errors. Calling Error() will include all errors
|
||||
// that occurred in the batch.
|
||||
//
|
||||
// Replaces BatchError
|
||||
type BatchedErrors interface {
|
||||
// Satisfy the base Error interface.
|
||||
Error
|
||||
|
||||
// Returns the original error if one was set. Nil is returned if not set.
|
||||
OrigErrs() []error
|
||||
}
|
||||
|
||||
// New returns an Error object described by the code, message, and origErr.
|
||||
//
|
||||
// If origErr satisfies the Error interface it will not be wrapped within a new
|
||||
// Error object and will instead be returned.
|
||||
func New(code, message string, origErr error) Error {
|
||||
var errs []error
|
||||
if origErr != nil {
|
||||
errs = append(errs, origErr)
|
||||
}
|
||||
return newBaseError(code, message, errs)
|
||||
}
|
||||
|
||||
// NewBatchError returns an BatchedErrors with a collection of errors as an
|
||||
// array of errors.
|
||||
func NewBatchError(code, message string, errs []error) BatchedErrors {
|
||||
return newBaseError(code, message, errs)
|
||||
}
|
||||
|
||||
// A RequestFailure is an interface to extract request failure information from
|
||||
// an Error such as the request ID of the failed request returned by a service.
|
||||
// RequestFailures may not always have a requestID value if the request failed
|
||||
// prior to reaching the service such as a connection error.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// output, err := s3manage.Upload(svc, input, opts)
|
||||
// if err != nil {
|
||||
// if reqerr, ok := err.(RequestFailure); ok {
|
||||
// log.Println("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
|
||||
// } else {
|
||||
// log.Println("Error:", err.Error())
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Combined with awserr.Error:
|
||||
//
|
||||
// output, err := s3manage.Upload(svc, input, opts)
|
||||
// if err != nil {
|
||||
// if awsErr, ok := err.(awserr.Error); ok {
|
||||
// // Generic AWS Error with Code, Message, and original error (if any)
|
||||
// fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
|
||||
//
|
||||
// if reqErr, ok := err.(awserr.RequestFailure); ok {
|
||||
// // A service error occurred
|
||||
// fmt.Println(reqErr.StatusCode(), reqErr.RequestID())
|
||||
// }
|
||||
// } else {
|
||||
// fmt.Println(err.Error())
|
||||
// }
|
||||
// }
|
||||
//
|
||||
type RequestFailure interface {
|
||||
Error
|
||||
|
||||
// The status code of the HTTP response.
|
||||
StatusCode() int
|
||||
|
||||
// The request ID returned by the service for a request failure. This will
|
||||
// be empty if no request ID is available such as the request failed due
|
||||
// to a connection error.
|
||||
RequestID() string
|
||||
}
|
||||
|
||||
// NewRequestFailure returns a wrapped error with additional information for
|
||||
// request status code, and service requestID.
|
||||
//
|
||||
// Should be used to wrap all request which involve service requests. Even if
|
||||
// the request failed without a service response, but had an HTTP status code
|
||||
// that may be meaningful.
|
||||
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
|
||||
return newRequestError(err, statusCode, reqID)
|
||||
}
|
||||
|
||||
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
|
||||
type UnmarshalError interface {
|
||||
awsError
|
||||
Bytes() []byte
|
||||
}
|
||||
|
||||
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
|
||||
// the bytes that fail to unmarshal to the error.
|
||||
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
|
||||
return &unmarshalError{
|
||||
awsError: New("UnmarshalError", msg, err),
|
||||
bytes: bytes,
|
||||
}
|
||||
}
|
||||
221
vendor/github.com/aws/aws-sdk-go/aws/awserr/types.go
generated
vendored
Normal file
221
vendor/github.com/aws/aws-sdk-go/aws/awserr/types.go
generated
vendored
Normal file
@@ -0,0 +1,221 @@
|
||||
package awserr
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SprintError returns a string of the formatted error code.
|
||||
//
|
||||
// Both extra and origErr are optional. If they are included their lines
|
||||
// will be added, but if they are not included their lines will be ignored.
|
||||
func SprintError(code, message, extra string, origErr error) string {
|
||||
msg := fmt.Sprintf("%s: %s", code, message)
|
||||
if extra != "" {
|
||||
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
|
||||
}
|
||||
if origErr != nil {
|
||||
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// A baseError wraps the code and message which defines an error. It also
|
||||
// can be used to wrap an original error object.
|
||||
//
|
||||
// Should be used as the root for errors satisfying the awserr.Error. Also
|
||||
// for any error which does not fit into a specific error wrapper type.
|
||||
type baseError struct {
|
||||
// Classification of error
|
||||
code string
|
||||
|
||||
// Detailed information about error
|
||||
message string
|
||||
|
||||
// Optional original error this error is based off of. Allows building
|
||||
// chained errors.
|
||||
errs []error
|
||||
}
|
||||
|
||||
// newBaseError returns an error object for the code, message, and errors.
|
||||
//
|
||||
// code is a short no whitespace phrase depicting the classification of
|
||||
// the error that is being created.
|
||||
//
|
||||
// message is the free flow string containing detailed information about the
|
||||
// error.
|
||||
//
|
||||
// origErrs is the error objects which will be nested under the new errors to
|
||||
// be returned.
|
||||
func newBaseError(code, message string, origErrs []error) *baseError {
|
||||
b := &baseError{
|
||||
code: code,
|
||||
message: message,
|
||||
errs: origErrs,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
//
|
||||
// See ErrorWithExtra for formatting.
|
||||
//
|
||||
// Satisfies the error interface.
|
||||
func (b baseError) Error() string {
|
||||
size := len(b.errs)
|
||||
if size > 0 {
|
||||
return SprintError(b.code, b.message, "", errorList(b.errs))
|
||||
}
|
||||
|
||||
return SprintError(b.code, b.message, "", nil)
|
||||
}
|
||||
|
||||
// String returns the string representation of the error.
|
||||
// Alias for Error to satisfy the stringer interface.
|
||||
func (b baseError) String() string {
|
||||
return b.Error()
|
||||
}
|
||||
|
||||
// Code returns the short phrase depicting the classification of the error.
|
||||
func (b baseError) Code() string {
|
||||
return b.code
|
||||
}
|
||||
|
||||
// Message returns the error details message.
|
||||
func (b baseError) Message() string {
|
||||
return b.message
|
||||
}
|
||||
|
||||
// OrigErr returns the original error if one was set. Nil is returned if no
|
||||
// error was set. This only returns the first element in the list. If the full
|
||||
// list is needed, use BatchedErrors.
|
||||
func (b baseError) OrigErr() error {
|
||||
switch len(b.errs) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return b.errs[0]
|
||||
default:
|
||||
if err, ok := b.errs[0].(Error); ok {
|
||||
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
|
||||
}
|
||||
return NewBatchError("BatchedErrors",
|
||||
"multiple errors occurred", b.errs)
|
||||
}
|
||||
}
|
||||
|
||||
// OrigErrs returns the original errors if one was set. An empty slice is
|
||||
// returned if no error was set.
|
||||
func (b baseError) OrigErrs() []error {
|
||||
return b.errs
|
||||
}
|
||||
|
||||
// So that the Error interface type can be included as an anonymous field
|
||||
// in the requestError struct and not conflict with the error.Error() method.
|
||||
type awsError Error
|
||||
|
||||
// A requestError wraps a request or service error.
|
||||
//
|
||||
// Composed of baseError for code, message, and original error.
|
||||
type requestError struct {
|
||||
awsError
|
||||
statusCode int
|
||||
requestID string
|
||||
bytes []byte
|
||||
}
|
||||
|
||||
// newRequestError returns a wrapped error with additional information for
|
||||
// request status code, and service requestID.
|
||||
//
|
||||
// Should be used to wrap all request which involve service requests. Even if
|
||||
// the request failed without a service response, but had an HTTP status code
|
||||
// that may be meaningful.
|
||||
//
|
||||
// Also wraps original errors via the baseError.
|
||||
func newRequestError(err Error, statusCode int, requestID string) *requestError {
|
||||
return &requestError{
|
||||
awsError: err,
|
||||
statusCode: statusCode,
|
||||
requestID: requestID,
|
||||
}
|
||||
}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
// Satisfies the error interface.
|
||||
func (r requestError) Error() string {
|
||||
extra := fmt.Sprintf("status code: %d, request id: %s",
|
||||
r.statusCode, r.requestID)
|
||||
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
|
||||
}
|
||||
|
||||
// String returns the string representation of the error.
|
||||
// Alias for Error to satisfy the stringer interface.
|
||||
func (r requestError) String() string {
|
||||
return r.Error()
|
||||
}
|
||||
|
||||
// StatusCode returns the wrapped status code for the error
|
||||
func (r requestError) StatusCode() int {
|
||||
return r.statusCode
|
||||
}
|
||||
|
||||
// RequestID returns the wrapped requestID
|
||||
func (r requestError) RequestID() string {
|
||||
return r.requestID
|
||||
}
|
||||
|
||||
// OrigErrs returns the original errors if one was set. An empty slice is
|
||||
// returned if no error was set.
|
||||
func (r requestError) OrigErrs() []error {
|
||||
if b, ok := r.awsError.(BatchedErrors); ok {
|
||||
return b.OrigErrs()
|
||||
}
|
||||
return []error{r.OrigErr()}
|
||||
}
|
||||
|
||||
type unmarshalError struct {
|
||||
awsError
|
||||
bytes []byte
|
||||
}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
// Satisfies the error interface.
|
||||
func (e unmarshalError) Error() string {
|
||||
extra := hex.Dump(e.bytes)
|
||||
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
|
||||
}
|
||||
|
||||
// String returns the string representation of the error.
|
||||
// Alias for Error to satisfy the stringer interface.
|
||||
func (e unmarshalError) String() string {
|
||||
return e.Error()
|
||||
}
|
||||
|
||||
// Bytes returns the bytes that failed to unmarshal.
|
||||
func (e unmarshalError) Bytes() []byte {
|
||||
return e.bytes
|
||||
}
|
||||
|
||||
// An error list that satisfies the golang interface
|
||||
type errorList []error
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
//
|
||||
// Satisfies the error interface.
|
||||
func (e errorList) Error() string {
|
||||
msg := ""
|
||||
// How do we want to handle the array size being zero
|
||||
if size := len(e); size > 0 {
|
||||
for i := 0; i < size; i++ {
|
||||
msg += e[i].Error()
|
||||
// We check the next index to see if it is within the slice.
|
||||
// If it is, then we append a newline. We do this, because unit tests
|
||||
// could be broken with the additional '\n'
|
||||
if i+1 < size {
|
||||
msg += "\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
108
vendor/github.com/aws/aws-sdk-go/aws/awsutil/copy.go
generated
vendored
Normal file
108
vendor/github.com/aws/aws-sdk-go/aws/awsutil/copy.go
generated
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
package awsutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Copy deeply copies a src structure to dst. Useful for copying request and
|
||||
// response structures.
|
||||
//
|
||||
// Can copy between structs of different type, but will only copy fields which
|
||||
// are assignable, and exist in both structs. Fields which are not assignable,
|
||||
// or do not exist in both structs are ignored.
|
||||
func Copy(dst, src interface{}) {
|
||||
dstval := reflect.ValueOf(dst)
|
||||
if !dstval.IsValid() {
|
||||
panic("Copy dst cannot be nil")
|
||||
}
|
||||
|
||||
rcopy(dstval, reflect.ValueOf(src), true)
|
||||
}
|
||||
|
||||
// CopyOf returns a copy of src while also allocating the memory for dst.
|
||||
// src must be a pointer type or this operation will fail.
|
||||
func CopyOf(src interface{}) (dst interface{}) {
|
||||
dsti := reflect.New(reflect.TypeOf(src).Elem())
|
||||
dst = dsti.Interface()
|
||||
rcopy(dsti, reflect.ValueOf(src), true)
|
||||
return
|
||||
}
|
||||
|
||||
// rcopy performs a recursive copy of values from the source to destination.
|
||||
//
|
||||
// root is used to skip certain aspects of the copy which are not valid
|
||||
// for the root node of a object.
|
||||
func rcopy(dst, src reflect.Value, root bool) {
|
||||
if !src.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
switch src.Kind() {
|
||||
case reflect.Ptr:
|
||||
if _, ok := src.Interface().(io.Reader); ok {
|
||||
if dst.Kind() == reflect.Ptr && dst.Elem().CanSet() {
|
||||
dst.Elem().Set(src)
|
||||
} else if dst.CanSet() {
|
||||
dst.Set(src)
|
||||
}
|
||||
} else {
|
||||
e := src.Type().Elem()
|
||||
if dst.CanSet() && !src.IsNil() {
|
||||
if _, ok := src.Interface().(*time.Time); !ok {
|
||||
dst.Set(reflect.New(e))
|
||||
} else {
|
||||
tempValue := reflect.New(e)
|
||||
tempValue.Elem().Set(src.Elem())
|
||||
// Sets time.Time's unexported values
|
||||
dst.Set(tempValue)
|
||||
}
|
||||
}
|
||||
if src.Elem().IsValid() {
|
||||
// Keep the current root state since the depth hasn't changed
|
||||
rcopy(dst.Elem(), src.Elem(), root)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
t := dst.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
name := t.Field(i).Name
|
||||
srcVal := src.FieldByName(name)
|
||||
dstVal := dst.FieldByName(name)
|
||||
if srcVal.IsValid() && dstVal.CanSet() {
|
||||
rcopy(dstVal, srcVal, false)
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
if src.IsNil() {
|
||||
break
|
||||
}
|
||||
|
||||
s := reflect.MakeSlice(src.Type(), src.Len(), src.Cap())
|
||||
dst.Set(s)
|
||||
for i := 0; i < src.Len(); i++ {
|
||||
rcopy(dst.Index(i), src.Index(i), false)
|
||||
}
|
||||
case reflect.Map:
|
||||
if src.IsNil() {
|
||||
break
|
||||
}
|
||||
|
||||
s := reflect.MakeMap(src.Type())
|
||||
dst.Set(s)
|
||||
for _, k := range src.MapKeys() {
|
||||
v := src.MapIndex(k)
|
||||
v2 := reflect.New(v.Type()).Elem()
|
||||
rcopy(v2, v, false)
|
||||
dst.SetMapIndex(k, v2)
|
||||
}
|
||||
default:
|
||||
// Assign the value if possible. If its not assignable, the value would
|
||||
// need to be converted and the impact of that may be unexpected, or is
|
||||
// not compatible with the dst type.
|
||||
if src.Type().AssignableTo(dst.Type()) {
|
||||
dst.Set(src)
|
||||
}
|
||||
}
|
||||
}
|
||||
27
vendor/github.com/aws/aws-sdk-go/aws/awsutil/equal.go
generated
vendored
Normal file
27
vendor/github.com/aws/aws-sdk-go/aws/awsutil/equal.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
package awsutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// DeepEqual returns if the two values are deeply equal like reflect.DeepEqual.
|
||||
// In addition to this, this method will also dereference the input values if
|
||||
// possible so the DeepEqual performed will not fail if one parameter is a
|
||||
// pointer and the other is not.
|
||||
//
|
||||
// DeepEqual will not perform indirection of nested values of the input parameters.
|
||||
func DeepEqual(a, b interface{}) bool {
|
||||
ra := reflect.Indirect(reflect.ValueOf(a))
|
||||
rb := reflect.Indirect(reflect.ValueOf(b))
|
||||
|
||||
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
|
||||
// If the elements are both nil, and of the same type they are equal
|
||||
// If they are of different types they are not equal
|
||||
return reflect.TypeOf(a) == reflect.TypeOf(b)
|
||||
} else if raValid != rbValid {
|
||||
// Both values must be valid to be equal
|
||||
return false
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(ra.Interface(), rb.Interface())
|
||||
}
|
||||
221
vendor/github.com/aws/aws-sdk-go/aws/awsutil/path_value.go
generated
vendored
Normal file
221
vendor/github.com/aws/aws-sdk-go/aws/awsutil/path_value.go
generated
vendored
Normal file
@@ -0,0 +1,221 @@
|
||||
package awsutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmespath/go-jmespath"
|
||||
)
|
||||
|
||||
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
|
||||
|
||||
// rValuesAtPath returns a slice of values found in value v. The values
|
||||
// in v are explored recursively so all nested values are collected.
|
||||
func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTerm bool) []reflect.Value {
|
||||
pathparts := strings.Split(path, "||")
|
||||
if len(pathparts) > 1 {
|
||||
for _, pathpart := range pathparts {
|
||||
vals := rValuesAtPath(v, pathpart, createPath, caseSensitive, nilTerm)
|
||||
if len(vals) > 0 {
|
||||
return vals
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
values := []reflect.Value{reflect.Indirect(reflect.ValueOf(v))}
|
||||
components := strings.Split(path, ".")
|
||||
for len(values) > 0 && len(components) > 0 {
|
||||
var index *int64
|
||||
var indexStar bool
|
||||
c := strings.TrimSpace(components[0])
|
||||
if c == "" { // no actual component, illegal syntax
|
||||
return nil
|
||||
} else if caseSensitive && c != "*" && strings.ToLower(c[0:1]) == c[0:1] {
|
||||
// TODO normalize case for user
|
||||
return nil // don't support unexported fields
|
||||
}
|
||||
|
||||
// parse this component
|
||||
if m := indexRe.FindStringSubmatch(c); m != nil {
|
||||
c = m[1]
|
||||
if m[2] == "" {
|
||||
index = nil
|
||||
indexStar = true
|
||||
} else {
|
||||
i, _ := strconv.ParseInt(m[2], 10, 32)
|
||||
index = &i
|
||||
indexStar = false
|
||||
}
|
||||
}
|
||||
|
||||
nextvals := []reflect.Value{}
|
||||
for _, value := range values {
|
||||
// pull component name out of struct member
|
||||
if value.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
if c == "*" { // pull all members
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
if f := reflect.Indirect(value.Field(i)); f.IsValid() {
|
||||
nextvals = append(nextvals, f)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
value = value.FieldByNameFunc(func(name string) bool {
|
||||
if c == name {
|
||||
return true
|
||||
} else if !caseSensitive && strings.EqualFold(name, c) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
if nilTerm && value.Kind() == reflect.Ptr && len(components[1:]) == 0 {
|
||||
if !value.IsNil() {
|
||||
value.Set(reflect.Zero(value.Type()))
|
||||
}
|
||||
return []reflect.Value{value}
|
||||
}
|
||||
|
||||
if createPath && value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
// TODO if the value is the terminus it should not be created
|
||||
// if the value to be set to its position is nil.
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
value = value.Elem()
|
||||
} else {
|
||||
value = reflect.Indirect(value)
|
||||
}
|
||||
|
||||
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
|
||||
if !createPath && value.IsNil() {
|
||||
value = reflect.ValueOf(nil)
|
||||
}
|
||||
}
|
||||
|
||||
if value.IsValid() {
|
||||
nextvals = append(nextvals, value)
|
||||
}
|
||||
}
|
||||
values = nextvals
|
||||
|
||||
if indexStar || index != nil {
|
||||
nextvals = []reflect.Value{}
|
||||
for _, valItem := range values {
|
||||
value := reflect.Indirect(valItem)
|
||||
if value.Kind() != reflect.Slice {
|
||||
continue
|
||||
}
|
||||
|
||||
if indexStar { // grab all indices
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
idx := reflect.Indirect(value.Index(i))
|
||||
if idx.IsValid() {
|
||||
nextvals = append(nextvals, idx)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// pull out index
|
||||
i := int(*index)
|
||||
if i >= value.Len() { // check out of bounds
|
||||
if createPath {
|
||||
// TODO resize slice
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
} else if i < 0 { // support negative indexing
|
||||
i = value.Len() + i
|
||||
}
|
||||
value = reflect.Indirect(value.Index(i))
|
||||
|
||||
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
|
||||
if !createPath && value.IsNil() {
|
||||
value = reflect.ValueOf(nil)
|
||||
}
|
||||
}
|
||||
|
||||
if value.IsValid() {
|
||||
nextvals = append(nextvals, value)
|
||||
}
|
||||
}
|
||||
values = nextvals
|
||||
}
|
||||
|
||||
components = components[1:]
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// ValuesAtPath returns a list of values at the case insensitive lexical
|
||||
// path inside of a structure.
|
||||
func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
|
||||
result, err := jmespath.Search(path, i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(result)
|
||||
if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) {
|
||||
return nil, nil
|
||||
}
|
||||
if s, ok := result.([]interface{}); ok {
|
||||
return s, err
|
||||
}
|
||||
if v.Kind() == reflect.Map && v.Len() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if v.Kind() == reflect.Slice {
|
||||
out := make([]interface{}, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
out[i] = v.Index(i).Interface()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
return []interface{}{result}, nil
|
||||
}
|
||||
|
||||
// SetValueAtPath sets a value at the case insensitive lexical path inside
|
||||
// of a structure.
|
||||
func SetValueAtPath(i interface{}, path string, v interface{}) {
|
||||
rvals := rValuesAtPath(i, path, true, false, v == nil)
|
||||
for _, rval := range rvals {
|
||||
if rval.Kind() == reflect.Ptr && rval.IsNil() {
|
||||
continue
|
||||
}
|
||||
setValue(rval, v)
|
||||
}
|
||||
}
|
||||
|
||||
func setValue(dstVal reflect.Value, src interface{}) {
|
||||
if dstVal.Kind() == reflect.Ptr {
|
||||
dstVal = reflect.Indirect(dstVal)
|
||||
}
|
||||
srcVal := reflect.ValueOf(src)
|
||||
|
||||
if !srcVal.IsValid() { // src is literal nil
|
||||
if dstVal.CanAddr() {
|
||||
// Convert to pointer so that pointer's value can be nil'ed
|
||||
// dstVal = dstVal.Addr()
|
||||
}
|
||||
dstVal.Set(reflect.Zero(dstVal.Type()))
|
||||
|
||||
} else if srcVal.Kind() == reflect.Ptr {
|
||||
if srcVal.IsNil() {
|
||||
srcVal = reflect.Zero(dstVal.Type())
|
||||
} else {
|
||||
srcVal = reflect.ValueOf(src).Elem()
|
||||
}
|
||||
dstVal.Set(srcVal)
|
||||
} else {
|
||||
dstVal.Set(srcVal)
|
||||
}
|
||||
|
||||
}
|
||||
123
vendor/github.com/aws/aws-sdk-go/aws/awsutil/prettify.go
generated
vendored
Normal file
123
vendor/github.com/aws/aws-sdk-go/aws/awsutil/prettify.go
generated
vendored
Normal file
@@ -0,0 +1,123 @@
|
||||
package awsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Prettify returns the string representation of a value.
|
||||
func Prettify(i interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
prettify(reflect.ValueOf(i), 0, &buf)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// prettify will recursively walk value v to build a textual
|
||||
// representation of the value.
|
||||
func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
strtype := v.Type().String()
|
||||
if strtype == "time.Time" {
|
||||
fmt.Fprintf(buf, "%s", v.Interface())
|
||||
break
|
||||
} else if strings.HasPrefix(strtype, "io.") {
|
||||
buf.WriteString("<buffer>")
|
||||
break
|
||||
}
|
||||
|
||||
buf.WriteString("{\n")
|
||||
|
||||
names := []string{}
|
||||
for i := 0; i < v.Type().NumField(); i++ {
|
||||
name := v.Type().Field(i).Name
|
||||
f := v.Field(i)
|
||||
if name[0:1] == strings.ToLower(name[0:1]) {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice || f.Kind() == reflect.Map) && f.IsNil() {
|
||||
continue // ignore unset fields
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
for i, n := range names {
|
||||
val := v.FieldByName(n)
|
||||
ft, ok := v.Type().FieldByName(n)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("expected to find field %v on type %v, but was not found", n, v.Type()))
|
||||
}
|
||||
|
||||
buf.WriteString(strings.Repeat(" ", indent+2))
|
||||
buf.WriteString(n + ": ")
|
||||
|
||||
if tag := ft.Tag.Get("sensitive"); tag == "true" {
|
||||
buf.WriteString("<sensitive>")
|
||||
} else {
|
||||
prettify(val, indent+2, buf)
|
||||
}
|
||||
|
||||
if i < len(names)-1 {
|
||||
buf.WriteString(",\n")
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
|
||||
case reflect.Slice:
|
||||
strtype := v.Type().String()
|
||||
if strtype == "[]uint8" {
|
||||
fmt.Fprintf(buf, "<binary> len %d", v.Len())
|
||||
break
|
||||
}
|
||||
|
||||
nl, id, id2 := "", "", ""
|
||||
if v.Len() > 3 {
|
||||
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
|
||||
}
|
||||
buf.WriteString("[" + nl)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
buf.WriteString(id2)
|
||||
prettify(v.Index(i), indent+2, buf)
|
||||
|
||||
if i < v.Len()-1 {
|
||||
buf.WriteString("," + nl)
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString(nl + id + "]")
|
||||
case reflect.Map:
|
||||
buf.WriteString("{\n")
|
||||
|
||||
for i, k := range v.MapKeys() {
|
||||
buf.WriteString(strings.Repeat(" ", indent+2))
|
||||
buf.WriteString(k.String() + ": ")
|
||||
prettify(v.MapIndex(k), indent+2, buf)
|
||||
|
||||
if i < v.Len()-1 {
|
||||
buf.WriteString(",\n")
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
|
||||
default:
|
||||
if !v.IsValid() {
|
||||
fmt.Fprint(buf, "<invalid value>")
|
||||
return
|
||||
}
|
||||
format := "%v"
|
||||
switch v.Interface().(type) {
|
||||
case string:
|
||||
format = "%q"
|
||||
case io.ReadSeeker, io.Reader:
|
||||
format = "buffer(%p)"
|
||||
}
|
||||
fmt.Fprintf(buf, format, v.Interface())
|
||||
}
|
||||
}
|
||||
90
vendor/github.com/aws/aws-sdk-go/aws/awsutil/string_value.go
generated
vendored
Normal file
90
vendor/github.com/aws/aws-sdk-go/aws/awsutil/string_value.go
generated
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
package awsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StringValue returns the string representation of a value.
|
||||
//
|
||||
// Deprecated: Use Prettify instead.
|
||||
func StringValue(i interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
stringValue(reflect.ValueOf(i), 0, &buf)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
buf.WriteString("{\n")
|
||||
|
||||
for i := 0; i < v.Type().NumField(); i++ {
|
||||
ft := v.Type().Field(i)
|
||||
fv := v.Field(i)
|
||||
|
||||
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
|
||||
continue // ignore unset fields
|
||||
}
|
||||
|
||||
buf.WriteString(strings.Repeat(" ", indent+2))
|
||||
buf.WriteString(ft.Name + ": ")
|
||||
|
||||
if tag := ft.Tag.Get("sensitive"); tag == "true" {
|
||||
buf.WriteString("<sensitive>")
|
||||
} else {
|
||||
stringValue(fv, indent+2, buf)
|
||||
}
|
||||
|
||||
buf.WriteString(",\n")
|
||||
}
|
||||
|
||||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
|
||||
case reflect.Slice:
|
||||
nl, id, id2 := "", "", ""
|
||||
if v.Len() > 3 {
|
||||
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
|
||||
}
|
||||
buf.WriteString("[" + nl)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
buf.WriteString(id2)
|
||||
stringValue(v.Index(i), indent+2, buf)
|
||||
|
||||
if i < v.Len()-1 {
|
||||
buf.WriteString("," + nl)
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString(nl + id + "]")
|
||||
case reflect.Map:
|
||||
buf.WriteString("{\n")
|
||||
|
||||
for i, k := range v.MapKeys() {
|
||||
buf.WriteString(strings.Repeat(" ", indent+2))
|
||||
buf.WriteString(k.String() + ": ")
|
||||
stringValue(v.MapIndex(k), indent+2, buf)
|
||||
|
||||
if i < v.Len()-1 {
|
||||
buf.WriteString(",\n")
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
|
||||
default:
|
||||
format := "%v"
|
||||
switch v.Interface().(type) {
|
||||
case string:
|
||||
format = "%q"
|
||||
}
|
||||
fmt.Fprintf(buf, format, v.Interface())
|
||||
}
|
||||
}
|
||||
94
vendor/github.com/aws/aws-sdk-go/aws/client/client.go
generated
vendored
Normal file
94
vendor/github.com/aws/aws-sdk-go/aws/client/client.go
generated
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// A Config provides configuration to a service client instance.
|
||||
type Config struct {
|
||||
Config *aws.Config
|
||||
Handlers request.Handlers
|
||||
PartitionID string
|
||||
Endpoint string
|
||||
SigningRegion string
|
||||
SigningName string
|
||||
ResolvedRegion string
|
||||
|
||||
// States that the signing name did not come from a modeled source but
|
||||
// was derived based on other data. Used by service client constructors
|
||||
// to determine if the signin name can be overridden based on metadata the
|
||||
// service has.
|
||||
SigningNameDerived bool
|
||||
}
|
||||
|
||||
// ConfigProvider provides a generic way for a service client to receive
|
||||
// the ClientConfig without circular dependencies.
|
||||
type ConfigProvider interface {
|
||||
ClientConfig(serviceName string, cfgs ...*aws.Config) Config
|
||||
}
|
||||
|
||||
// ConfigNoResolveEndpointProvider same as ConfigProvider except it will not
|
||||
// resolve the endpoint automatically. The service client's endpoint must be
|
||||
// provided via the aws.Config.Endpoint field.
|
||||
type ConfigNoResolveEndpointProvider interface {
|
||||
ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) Config
|
||||
}
|
||||
|
||||
// A Client implements the base client request and response handling
|
||||
// used by all service clients.
|
||||
type Client struct {
|
||||
request.Retryer
|
||||
metadata.ClientInfo
|
||||
|
||||
Config aws.Config
|
||||
Handlers request.Handlers
|
||||
}
|
||||
|
||||
// New will return a pointer to a new initialized service client.
|
||||
func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, options ...func(*Client)) *Client {
|
||||
svc := &Client{
|
||||
Config: cfg,
|
||||
ClientInfo: info,
|
||||
Handlers: handlers.Copy(),
|
||||
}
|
||||
|
||||
switch retryer, ok := cfg.Retryer.(request.Retryer); {
|
||||
case ok:
|
||||
svc.Retryer = retryer
|
||||
case cfg.Retryer != nil && cfg.Logger != nil:
|
||||
s := fmt.Sprintf("WARNING: %T does not implement request.Retryer; using DefaultRetryer instead", cfg.Retryer)
|
||||
cfg.Logger.Log(s)
|
||||
fallthrough
|
||||
default:
|
||||
maxRetries := aws.IntValue(cfg.MaxRetries)
|
||||
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
|
||||
maxRetries = DefaultRetryerMaxNumRetries
|
||||
}
|
||||
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
|
||||
}
|
||||
|
||||
svc.AddDebugHandlers()
|
||||
|
||||
for _, option := range options {
|
||||
option(svc)
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// NewRequest returns a new Request pointer for the service API
|
||||
// operation and parameters.
|
||||
func (c *Client) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request {
|
||||
return request.New(c.Config, c.ClientInfo, c.Handlers, c.Retryer, operation, params, data)
|
||||
}
|
||||
|
||||
// AddDebugHandlers injects debug logging handlers into the service to log request
|
||||
// debug information.
|
||||
func (c *Client) AddDebugHandlers() {
|
||||
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
|
||||
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
|
||||
}
|
||||
177
vendor/github.com/aws/aws-sdk-go/aws/client/default_retryer.go
generated
vendored
Normal file
177
vendor/github.com/aws/aws-sdk-go/aws/client/default_retryer.go
generated
vendored
Normal file
@@ -0,0 +1,177 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkrand"
|
||||
)
|
||||
|
||||
// DefaultRetryer implements basic retry logic using exponential backoff for
|
||||
// most services. If you want to implement custom retry logic, you can implement the
|
||||
// request.Retryer interface.
|
||||
//
|
||||
type DefaultRetryer struct {
|
||||
// Num max Retries is the number of max retries that will be performed.
|
||||
// By default, this is zero.
|
||||
NumMaxRetries int
|
||||
|
||||
// MinRetryDelay is the minimum retry delay after which retry will be performed.
|
||||
// If not set, the value is 0ns.
|
||||
MinRetryDelay time.Duration
|
||||
|
||||
// MinThrottleRetryDelay is the minimum retry delay when throttled.
|
||||
// If not set, the value is 0ns.
|
||||
MinThrottleDelay time.Duration
|
||||
|
||||
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
|
||||
// If not set, the value is 0ns.
|
||||
MaxRetryDelay time.Duration
|
||||
|
||||
// MaxThrottleDelay is the maximum retry delay when throttled.
|
||||
// If not set, the value is 0ns.
|
||||
MaxThrottleDelay time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultRetryerMaxNumRetries sets maximum number of retries
|
||||
DefaultRetryerMaxNumRetries = 3
|
||||
|
||||
// DefaultRetryerMinRetryDelay sets minimum retry delay
|
||||
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
|
||||
|
||||
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
|
||||
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultRetryerMaxRetryDelay sets maximum retry delay
|
||||
DefaultRetryerMaxRetryDelay = 300 * time.Second
|
||||
|
||||
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
|
||||
DefaultRetryerMaxThrottleDelay = 300 * time.Second
|
||||
)
|
||||
|
||||
// MaxRetries returns the number of maximum returns the service will use to make
|
||||
// an individual API request.
|
||||
func (d DefaultRetryer) MaxRetries() int {
|
||||
return d.NumMaxRetries
|
||||
}
|
||||
|
||||
// setRetryerDefaults sets the default values of the retryer if not set
|
||||
func (d *DefaultRetryer) setRetryerDefaults() {
|
||||
if d.MinRetryDelay == 0 {
|
||||
d.MinRetryDelay = DefaultRetryerMinRetryDelay
|
||||
}
|
||||
if d.MaxRetryDelay == 0 {
|
||||
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
|
||||
}
|
||||
if d.MinThrottleDelay == 0 {
|
||||
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
|
||||
}
|
||||
if d.MaxThrottleDelay == 0 {
|
||||
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
|
||||
}
|
||||
}
|
||||
|
||||
// RetryRules returns the delay duration before retrying this request again
|
||||
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
|
||||
|
||||
// if number of max retries is zero, no retries will be performed.
|
||||
if d.NumMaxRetries == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sets default value for retryer members
|
||||
d.setRetryerDefaults()
|
||||
|
||||
// minDelay is the minimum retryer delay
|
||||
minDelay := d.MinRetryDelay
|
||||
|
||||
var initialDelay time.Duration
|
||||
|
||||
isThrottle := r.IsErrorThrottle()
|
||||
if isThrottle {
|
||||
if delay, ok := getRetryAfterDelay(r); ok {
|
||||
initialDelay = delay
|
||||
}
|
||||
minDelay = d.MinThrottleDelay
|
||||
}
|
||||
|
||||
retryCount := r.RetryCount
|
||||
|
||||
// maxDelay the maximum retryer delay
|
||||
maxDelay := d.MaxRetryDelay
|
||||
|
||||
if isThrottle {
|
||||
maxDelay = d.MaxThrottleDelay
|
||||
}
|
||||
|
||||
var delay time.Duration
|
||||
|
||||
// Logic to cap the retry count based on the minDelay provided
|
||||
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
|
||||
if actualRetryCount < 63-retryCount {
|
||||
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
|
||||
if delay > maxDelay {
|
||||
delay = getJitterDelay(maxDelay / 2)
|
||||
}
|
||||
} else {
|
||||
delay = getJitterDelay(maxDelay / 2)
|
||||
}
|
||||
return delay + initialDelay
|
||||
}
|
||||
|
||||
// getJitterDelay returns a jittered delay for retry
|
||||
func getJitterDelay(duration time.Duration) time.Duration {
|
||||
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
|
||||
}
|
||||
|
||||
// ShouldRetry returns true if the request should be retried.
|
||||
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
|
||||
|
||||
// ShouldRetry returns false if number of max retries is 0.
|
||||
if d.NumMaxRetries == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// If one of the other handlers already set the retry state
|
||||
// we don't want to override it based on the service's state
|
||||
if r.Retryable != nil {
|
||||
return *r.Retryable
|
||||
}
|
||||
return r.IsErrorRetryable() || r.IsErrorThrottle()
|
||||
}
|
||||
|
||||
// This will look in the Retry-After header, RFC 7231, for how long
|
||||
// it will wait before attempting another request
|
||||
func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
|
||||
if !canUseRetryAfterHeader(r) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
delayStr := r.HTTPResponse.Header.Get("Retry-After")
|
||||
if len(delayStr) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
delay, err := strconv.Atoi(delayStr)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return time.Duration(delay) * time.Second, true
|
||||
}
|
||||
|
||||
// Will look at the status code to see if the retry header pertains to
|
||||
// the status code.
|
||||
func canUseRetryAfterHeader(r *request.Request) bool {
|
||||
switch r.HTTPResponse.StatusCode {
|
||||
case 429:
|
||||
case 503:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user