16 Commits

Author SHA1 Message Date
b2b77eb4da refactor: remove vendor directory to use Go proxy service
Removes local vendor directory in favor of Go module proxy for cleaner
repository and improved dependency management. Dependencies will now be
fetched automatically from proxy.golang.org during builds.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 23:10:28 +12:00
6558a09258 deps: update vendor dependencies for S3-compatible storage
Updates AWS SDK and removes Blazer B2 dependency in favor of unified
S3-compatible approach. Includes configuration examples and documentation.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 23:07:58 +12:00
f99a866e13 feat: implement unified S3-compatible storage system
Consolidates storage backends into a single S3-compatible driver that supports:
- AWS S3 (native)
- Backblaze B2 (S3-compatible API)
- Cloudflare R2 (S3-compatible API)
- MinIO and other S3-compatible services
- Local filesystem for development

This replaces the previous separate B2 driver with a unified approach,
reducing dependencies and complexity while adding support for more services.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 23:07:44 +12:00
e3152d9f40 fix: correct environment variables in justfile run-dev command
- Fix OA_DATABASE_DRIVER to OA_DATABASEDRIVER (no underscore)
- Fix OA_DATABASE_FILE to OA_DATABASEFILE (no underscore)
- Change database filename from test2.db to dev.db for better naming
- Ensure SQLite database is created in project root as expected

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:22:28 +12:00
e78098ad45 feat: update gitignore for attachment system
- Add .vscode/ to ignore IDE-specific files
- Add server to ignore build artifacts
- Add attachments/ to ignore uploaded attachment files
- Maintain clean repository without development artifacts

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:06:29 +12:00
7c43726abf fix: correct WebSocket message logging format
- Change format specifier from %s to %+v for struct logging
- Resolve compilation error in WebSocket message handling
- Maintain proper logging functionality for debugging

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:05:37 +12:00
b7ac4b0152 fix: add missing mock expectations in account tests
- Add GetSplitCountByAccountId mock expectations for CreateAccount tests
- Add GetSplitCountByAccountId mock expectations for UpdateAccount tests
- Resolve "unexpected method call" errors in account test suite
- Maintain existing test logic while fixing mock setup

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:05:21 +12:00
1b115fe0ff feat: add attachment methods to mock datastore
- Add InsertAttachment mock method for testing
- Add GetAttachment and GetAttachmentsByTransaction mock methods
- Add DeleteAttachment mock method for testing
- Maintain consistency with existing mock patterns

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:05:05 +12:00
a87df47231 feat: register attachment API routes
- Add 5 RESTful endpoints for transaction attachment management
- Include proper authentication middleware for all attachment operations
- Follow existing URL pattern: /orgs/:orgId/transactions/:transactionId/attachments
- Support nested resource access with proper authorization

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:50 +12:00
8b0a72c81f feat: implement attachment REST API endpoints
- Add POST /attachments for secure multi-file upload with validation
- Add GET /attachments for listing transaction attachments
- Add GET /attachments/:id for attachment metadata retrieval
- Add GET /attachments/:id/download for secure file download
- Add DELETE /attachments/:id for soft deletion
- Include comprehensive security validation: file type, size, content detection
- Implement proper error handling and cleanup on failures

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:32 +12:00
f64f83e66f feat: add attachment support to GORM model
- Implement AttachmentInterface methods in GormModel
- Add GetTransaction method for interface compliance
- Include placeholder implementation for future GORM repository development
- Maintain backward compatibility with existing GORM usage

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:15 +12:00
f5f0853040 feat: implement attachment business logic layer
- Add AttachmentInterface to main model interface
- Implement attachment CRUD operations with permission checking
- Add GetTransaction method for secure attachment access validation
- Add accountsContainReadAccess for permission verification
- Ensure users can only access attachments for authorized transactions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:04:01 +12:00
04653f2f02 feat: implement attachment database layer
- Add AttachmentInterface to main Datastore interface
- Implement CRUD operations for attachments following existing patterns
- Add proper SQL marshalling/unmarshalling with HEX/UNHEX for binary IDs
- Include soft deletion and proper indexing support

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:03:44 +12:00
3b89d8137e feat: add UUID utility functions for attachment system
- Add NewUUID() function for generating unique attachment identifiers
- Add IsValidUUID() function for validating UUID format in API requests
- Support 32-character hex string validation for secure file handling

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:03:30 +12:00
d10686e70f feat: add attachment model type definitions
- Add core attachment type with metadata fields for transaction files
- Add GORM model for attachment with proper relationships
- Include file information, upload timestamps, and soft deletion support

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:03:13 +12:00
c335c834ba feat: add database schema for transaction attachments
- Add attachment table with fields for file metadata and relationships
- Include indexes for optimal query performance on transactionId, orgId, userId, and uploaded fields
- Support for file storage with path tracking and soft deletion

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-01 11:02:58 +12:00
1203 changed files with 2816 additions and 701849 deletions

38
.env.storage.example Normal file
View File

@@ -0,0 +1,38 @@
# OpenAccounting Storage Configuration
# Copy this file to .env and modify as needed
# Database Configuration
OA_DATABASEDRIVER=sqlite
OA_DATABASEFILE=./openaccounting.db
OA_ADDRESS=localhost
OA_PORT=8080
OA_APIPREFIX=/api/v1
# Storage Backend Configuration
# Options: local, s3, b2
OA_STORAGE_BACKEND=local
# Local Storage Configuration
OA_STORAGE_LOCAL_ROOTDIR=./uploads
OA_STORAGE_LOCAL_BASEURL=
# Amazon S3 Storage Configuration (uncomment if using S3)
# OA_STORAGE_S3_REGION=us-east-1
# OA_STORAGE_S3_BUCKET=my-openaccounting-attachments
# OA_STORAGE_S3_PREFIX=attachments
# OA_STORAGE_S3_ACCESSKEYID=AKIA...
# OA_STORAGE_S3_SECRETACCESSKEY=...
# OA_STORAGE_S3_ENDPOINT=
# OA_STORAGE_S3_PATHSTYLE=false
# Backblaze B2 Storage Configuration (uncomment if using B2)
# OA_STORAGE_B2_ACCOUNTID=your-b2-account-id
# OA_STORAGE_B2_APPLICATIONKEY=your-b2-application-key
# OA_STORAGE_B2_BUCKET=my-openaccounting-attachments
# OA_STORAGE_B2_PREFIX=attachments
# Email Configuration (optional)
# OA_MAILGUNDOMAIN=
# OA_MAILGUNKEY=
# OA_MAILGUNEMAIL=
# OA_MAILGUNSENDER=

3
.gitignore vendored
View File

@@ -97,3 +97,6 @@ config.json
*.csr
*.sublime-project
*.sublime-workspace
.vscode/
server
attachments/

View File

@@ -6,6 +6,7 @@ Open Accounting Server is a modern financial accounting system built with Go, fe
- **GORM Integration**: Modern ORM with SQLite and MySQL support
- **Viper Configuration**: Flexible config management with environment variables
- **Modular Storage**: S3-compatible attachment storage (Local, AWS S3, Backblaze B2, Cloudflare R2, MinIO)
- **Docker Ready**: Containerized deployment with multi-stage builds
- **SQLite Support**: Easy local development and testing
- **Security**: Environment variable support for sensitive data
@@ -81,6 +82,35 @@ All configuration can be overridden with environment variables using the `OA_` p
| `OA_MAILGUN_EMAIL` | MailgunEmail | | Mailgun email |
| `OA_MAILGUN_SENDER` | MailgunSender | | Mailgun sender name |
#### Storage Configuration
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_STORAGE_BACKEND` | Storage.Backend | `local` | Storage backend: `local` or `s3` |
**Local Storage**
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_STORAGE_LOCAL_ROOTDIR` | Storage.Local.RootDir | `./uploads` | Root directory for file storage |
| `OA_STORAGE_LOCAL_BASEURL` | Storage.Local.BaseURL | | Base URL for serving files |
**S3-Compatible Storage** (AWS S3, Backblaze B2, Cloudflare R2, MinIO)
| Environment Variable | Config Field | Default | Description |
|---------------------|--------------|---------|-------------|
| `OA_STORAGE_S3_REGION` | Storage.S3.Region | | Region (use "auto" for Cloudflare R2) |
| `OA_STORAGE_S3_BUCKET` | Storage.S3.Bucket | | Bucket name |
| `OA_STORAGE_S3_PREFIX` | Storage.S3.Prefix | | Optional prefix for all objects |
| `OA_STORAGE_S3_ACCESSKEYID` | Storage.S3.AccessKeyID | | Access Key ID ⚠️ |
| `OA_STORAGE_S3_SECRETACCESSKEY` | Storage.S3.SecretAccessKey | | Secret Access Key ⚠️ |
| `OA_STORAGE_S3_ENDPOINT` | Storage.S3.Endpoint | | Custom endpoint (see examples below) |
| `OA_STORAGE_S3_PATHSTYLE` | Storage.S3.PathStyle | `false` | Use path-style addressing |
**S3-Compatible Service Endpoints:**
- **AWS S3**: Leave endpoint empty, set appropriate region
- **Backblaze B2**: `https://s3.us-west-004.backblazeb2.com` (replace region as needed)
- **Cloudflare R2**: `https://<account-id>.r2.cloudflarestorage.com`
- **MinIO**: `http://localhost:9000` (or your MinIO server URL)
⚠️ **Security**: Always use environment variables for sensitive data like passwords and API keys.
### Configuration Examples
@@ -109,6 +139,51 @@ export OA_MAILGUN_KEY=key-abc123
OA_DATABASE_DRIVER=mysql OA_PASSWORD=secret OA_MAILGUN_KEY=key-123 ./server
```
#### Storage Configuration Examples
```bash
# Local storage (default)
export OA_STORAGE_BACKEND=local
export OA_STORAGE_LOCAL_ROOTDIR=./uploads
./server
# AWS S3
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=us-west-2
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=your-access-key
export OA_STORAGE_S3_SECRETACCESSKEY=your-secret-key
./server
# Backblaze B2 (S3-compatible)
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=us-west-004
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=your-b2-key-id
export OA_STORAGE_S3_SECRETACCESSKEY=your-b2-application-key
export OA_STORAGE_S3_ENDPOINT=https://s3.us-west-004.backblazeb2.com
export OA_STORAGE_S3_PATHSTYLE=true
./server
# Cloudflare R2
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=auto
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=your-r2-access-key
export OA_STORAGE_S3_SECRETACCESSKEY=your-r2-secret-key
export OA_STORAGE_S3_ENDPOINT=https://your-account-id.r2.cloudflarestorage.com
./server
# MinIO (self-hosted)
export OA_STORAGE_BACKEND=s3
export OA_STORAGE_S3_REGION=us-east-1
export OA_STORAGE_S3_BUCKET=my-app-attachments
export OA_STORAGE_S3_ACCESSKEYID=minioadmin
export OA_STORAGE_S3_SECRETACCESSKEY=minioadmin
export OA_STORAGE_S3_ENDPOINT=http://localhost:9000
export OA_STORAGE_S3_PATHSTYLE=true
./server
```
#### Docker
```bash
# SQLite with volume mount
@@ -123,6 +198,25 @@ docker run -p 8080:8080 \
-e OA_DATABASE_ADDRESS=mysql:3306 \
-e OA_PASSWORD=secret \
openaccounting-server:latest
# With AWS S3 storage
docker run -p 8080:8080 \
-e OA_STORAGE_BACKEND=s3 \
-e OA_STORAGE_S3_REGION=us-west-2 \
-e OA_STORAGE_S3_BUCKET=my-attachments \
-e OA_STORAGE_S3_ACCESSKEYID=your-key \
-e OA_STORAGE_S3_SECRETACCESSKEY=your-secret \
openaccounting-server:latest
# With Cloudflare R2 storage
docker run -p 8080:8080 \
-e OA_STORAGE_BACKEND=s3 \
-e OA_STORAGE_S3_REGION=auto \
-e OA_STORAGE_S3_BUCKET=my-attachments \
-e OA_STORAGE_S3_ACCESSKEYID=your-r2-key \
-e OA_STORAGE_S3_SECRETACCESSKEY=your-r2-secret \
-e OA_STORAGE_S3_ENDPOINT=https://account-id.r2.cloudflarestorage.com \
openaccounting-server:latest
```
## Database Setup

239
STORAGE.md Normal file
View File

@@ -0,0 +1,239 @@
# Modular Storage System
The OpenAccounting server now supports multiple storage backends for file attachments. This allows you to choose between local filesystem storage for simple deployments or cloud storage for production/multi-user environments.
## Supported Storage Backends
### 1. Local Filesystem Storage
Perfect for self-hosted deployments or development environments.
**Configuration:**
```json
{
"storage": {
"backend": "local",
"local": {
"root_dir": "./uploads",
"base_url": "https://yourapp.com/files"
}
}
}
```
**Environment Variables:**
```bash
OA_STORAGE_BACKEND=local
OA_STORAGE_LOCAL_ROOT_DIR=./uploads
OA_STORAGE_LOCAL_BASE_URL=https://yourapp.com/files
```
### 2. Amazon S3 Storage
Reliable cloud storage for production deployments.
**Configuration:**
```json
{
"storage": {
"backend": "s3",
"s3": {
"region": "us-east-1",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments",
"access_key_id": "AKIA...",
"secret_access_key": "...",
"endpoint": "",
"path_style": false
}
}
}
```
**Environment Variables:**
```bash
OA_STORAGE_BACKEND=s3
OA_STORAGE_S3_REGION=us-east-1
OA_STORAGE_S3_BUCKET=my-openaccounting-attachments
OA_STORAGE_S3_PREFIX=attachments
OA_STORAGE_S3_ACCESS_KEY_ID=AKIA...
OA_STORAGE_S3_SECRET_ACCESS_KEY=...
```
**Features:**
- Automatic presigned URL generation
- Configurable expiry times
- Support for S3-compatible services (MinIO, DigitalOcean Spaces)
- IAM role support (leave credentials empty to use IAM)
### 3. Backblaze B2 Storage
Cost-effective cloud storage alternative to S3.
**Configuration:**
```json
{
"storage": {
"backend": "b2",
"b2": {
"account_id": "your-b2-account-id",
"application_key": "your-b2-application-key",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments"
}
}
}
```
**Environment Variables:**
```bash
OA_STORAGE_BACKEND=b2
OA_STORAGE_B2_ACCOUNT_ID=your-b2-account-id
OA_STORAGE_B2_APPLICATION_KEY=your-b2-application-key
OA_STORAGE_B2_BUCKET=my-openaccounting-attachments
OA_STORAGE_B2_PREFIX=attachments
```
## API Endpoints
The storage system provides both legacy and new endpoints:
### New Storage-Agnostic Endpoints
**Upload Attachment:**
```
POST /api/v1/attachments
Content-Type: multipart/form-data
transactionId: uuid
description: string (optional)
file: binary data
```
**Get Attachment Metadata:**
```
GET /api/v1/attachments/{id}
```
**Get Download URL:**
```
GET /api/v1/attachments/{id}/url
```
**Download File:**
```
GET /api/v1/attachments/{id}?download=true
```
**Delete Attachment:**
```
DELETE /api/v1/attachments/{id}
```
### Legacy Endpoints (Still Supported)
The original transaction-scoped endpoints remain available for backward compatibility:
- `GET/POST /api/v1/orgs/{orgId}/transactions/{transactionId}/attachments`
## Security Features
- **File type validation** - Only allowed MIME types are accepted
- **File size limits** - Configurable maximum file size (default 10MB)
- **Path traversal protection** - Prevents directory traversal attacks
- **Access control** - Files are linked to users and organizations
- **Presigned URLs** - Time-limited access for cloud storage
## File Organization
Files are automatically organized by date:
```
uploads/
├── 2025/
│ ├── 01/
│ │ ├── 15/
│ │ │ ├── uuid1.pdf
│ │ │ └── uuid2.png
│ │ └── 16/
│ │ └── uuid3.jpg
```
## Configuration Examples
### Development (Local Storage)
```json
{
"storage": {
"backend": "local",
"local": {
"root_dir": "./dev-uploads"
}
}
}
```
### Production (S3 with IAM)
```json
{
"storage": {
"backend": "s3",
"s3": {
"region": "us-west-2",
"bucket": "prod-openaccounting-files",
"prefix": "attachments"
}
}
}
```
### Cost-Optimized (Backblaze B2)
```json
{
"storage": {
"backend": "b2",
"b2": {
"account_id": "${B2_ACCOUNT_ID}",
"application_key": "${B2_APP_KEY}",
"bucket": "openaccounting-prod"
}
}
}
```
## Migration Between Storage Backends
When changing storage backends, existing attachments will remain in the old storage location. The database records contain the storage path, so files can be accessed until migrated.
To migrate:
1. Update configuration to new backend
2. Restart server
3. New uploads will use the new backend
4. Optional: Run migration script to move existing files
## Environment-Specific Considerations
### Self-Hosted
- Use local storage for simplicity
- Ensure backup strategy includes upload directory
- Consider disk space management
### Cloud Deployment
- Use S3 or B2 for reliability and scalability
- Configure proper IAM policies
- Enable versioning and lifecycle policies
### Multi-Region
- Use cloud storage with appropriate region selection
- Consider CDN integration for better performance
## Troubleshooting
**Storage backend not initialized:**
- Check configuration syntax
- Verify credentials for cloud backends
- Ensure storage directories/buckets exist
**Permission denied:**
- Check file system permissions for local storage
- Verify IAM policies for S3
- Confirm B2 application key permissions
**Large file uploads failing:**
- Check `MaxFileSize` configuration
- Verify network timeouts
- Consider multipart upload for large files

17
config.b2.json.sample Normal file
View File

@@ -0,0 +1,17 @@
{
"weburl": "https://yourapp.com",
"address": "localhost",
"port": 8080,
"apiprefix": "/api/v1",
"databasedriver": "sqlite",
"databasefile": "./openaccounting.db",
"storage": {
"backend": "b2",
"b2": {
"account_id": "your-b2-account-id",
"application_key": "your-b2-application-key",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments"
}
}
}

20
config.s3.json.sample Normal file
View File

@@ -0,0 +1,20 @@
{
"weburl": "https://yourapp.com",
"address": "localhost",
"port": 8080,
"apiprefix": "/api/v1",
"databasedriver": "sqlite",
"databasefile": "./openaccounting.db",
"storage": {
"backend": "s3",
"s3": {
"region": "us-east-1",
"bucket": "my-openaccounting-attachments",
"prefix": "attachments",
"access_key_id": "",
"secret_access_key": "",
"endpoint": "",
"path_style": false
}
}
}

View File

@@ -0,0 +1,15 @@
{
"weburl": "https://yourapp.com",
"address": "localhost",
"port": 8080,
"apiprefix": "/api/v1",
"databasedriver": "sqlite",
"databasefile": "./openaccounting.db",
"storage": {
"backend": "local",
"local": {
"root_dir": "./uploads",
"base_url": "https://yourapp.com/files"
}
}
}

313
core/api/attachment.go Normal file
View File

@@ -0,0 +1,313 @@
package api
import (
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/ant0ine/go-json-rest/rest"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
)
const (
MaxFileSize = 10 * 1024 * 1024 // 10MB
MaxFilesPerTx = 10
AttachmentDir = "attachments"
)
var AllowedMimeTypes = map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"application/pdf": true,
"text/plain": true,
"text/csv": true,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true, // .xlsx
"application/vnd.ms-excel": true, // .xls
}
func PostAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Parse multipart form
err := r.ParseMultipartForm(MaxFileSize)
if err != nil {
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
return
}
files := r.MultipartForm.File["files"]
if len(files) == 0 {
rest.Error(w, "No files provided", http.StatusBadRequest)
return
}
if len(files) > MaxFilesPerTx {
rest.Error(w, fmt.Sprintf("Too many files. Maximum %d files allowed", MaxFilesPerTx), http.StatusBadRequest)
return
}
// Verify transaction exists and user has permission
tx, err := model.Instance.GetTransaction(transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Transaction not found or access denied", http.StatusNotFound)
return
}
if tx == nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
var attachments []*types.Attachment
var description string
if desc := r.FormValue("description"); desc != "" {
description = desc
}
for _, fileHeader := range files {
attachment, err := processFileUpload(fileHeader, transactionId, orgId, user.Id, description)
if err != nil {
// Clean up any successfully uploaded files
for _, att := range attachments {
os.Remove(att.FilePath)
}
rest.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Save attachment to database
createdAttachment, err := model.Instance.CreateAttachment(attachment)
if err != nil {
// Clean up file and any previously uploaded files
os.Remove(attachment.FilePath)
for _, att := range attachments {
os.Remove(att.FilePath)
}
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
return
}
attachments = append(attachments, createdAttachment)
}
w.WriteJson(map[string]interface{}{
"attachments": attachments,
"count": len(attachments),
})
}
func GetAttachments(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Failed to retrieve attachments", http.StatusInternalServerError)
return
}
w.WriteJson(attachments)
}
func GetAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
attachmentId := r.PathParam("attachmentId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
return
}
w.WriteJson(attachment)
}
func DownloadAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
attachmentId := r.PathParam("attachmentId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
attachment, err := model.Instance.GetAttachment(attachmentId, transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Attachment not found or access denied", http.StatusNotFound)
return
}
// Check if file exists
if _, err := os.Stat(attachment.FilePath); os.IsNotExist(err) {
rest.Error(w, "File not found on disk", http.StatusNotFound)
return
}
// Set headers for file download
w.Header().Set("Content-Type", attachment.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
// Open and serve file
file, err := os.Open(attachment.FilePath)
if err != nil {
rest.Error(w, "Failed to open file", http.StatusInternalServerError)
return
}
defer file.Close()
io.Copy(w.(http.ResponseWriter), file)
}
func DeleteAttachment(w rest.ResponseWriter, r *rest.Request) {
orgId := r.PathParam("orgId")
transactionId := r.PathParam("transactionId")
attachmentId := r.PathParam("attachmentId")
if !util.IsValidUUID(orgId) || !util.IsValidUUID(transactionId) || !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid UUID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
err := model.Instance.DeleteAttachment(attachmentId, transactionId, orgId, user.Id)
if err != nil {
rest.Error(w, "Failed to delete attachment or access denied", http.StatusInternalServerError)
return
}
w.WriteJson(map[string]string{"status": "deleted"})
}
func processFileUpload(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
// Validate file size
if fileHeader.Size > MaxFileSize {
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
}
// Open uploaded file
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %v", err)
}
defer file.Close()
// Validate file type from header
contentType := fileHeader.Header.Get("Content-Type")
if !AllowedMimeTypes[contentType] {
return nil, fmt.Errorf("file type %s not allowed", contentType)
}
// Validate file type by detecting content (more secure)
buffer := make([]byte, 512)
n, err := file.Read(buffer)
if err != nil {
return nil, fmt.Errorf("failed to read file for content detection: %v", err)
}
// Reset file pointer to beginning
if _, err := file.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to reset file pointer: %v", err)
}
detectedType := http.DetectContentType(buffer[:n])
if !AllowedMimeTypes[detectedType] {
return nil, fmt.Errorf("detected file type %s not allowed (header claimed %s)", detectedType, contentType)
}
// Generate unique filename
attachmentId := util.NewUUID()
ext := filepath.Ext(fileHeader.Filename)
fileName := attachmentId + ext
// Create attachments directory if it doesn't exist
uploadDir := filepath.Join(AttachmentDir, orgId, transactionId)
if err := os.MkdirAll(uploadDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create upload directory: %v", err)
}
// Create file path
filePath := filepath.Join(uploadDir, fileName)
// Create destination file
dst, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to create destination file: %v", err)
}
defer dst.Close()
// Copy file contents
if _, err := io.Copy(dst, file); err != nil {
return nil, fmt.Errorf("failed to save file: %v", err)
}
// Create attachment object
attachment := &types.Attachment{
Id: attachmentId,
TransactionId: transactionId,
OrgId: orgId,
UserId: userId,
FileName: fileName,
OriginalName: fileHeader.Filename,
ContentType: contentType,
FileSize: fileHeader.Size,
FilePath: filePath,
Description: description,
Uploaded: time.Now(),
Deleted: false,
}
return attachment, nil
}
func sanitizeFilename(filename string) string {
// Remove potentially dangerous characters
filename = strings.ReplaceAll(filename, "..", "")
filename = strings.ReplaceAll(filename, "/", "")
filename = strings.ReplaceAll(filename, "\\", "")
filename = strings.ReplaceAll(filename, "\x00", "") // null bytes
filename = strings.ReplaceAll(filename, "\r", "") // carriage return
filename = strings.ReplaceAll(filename, "\n", "") // newline
// Limit filename length
if len(filename) > 255 {
ext := filepath.Ext(filename)
base := filename[:255-len(ext)]
filename = base + ext
}
return filename
}

View File

@@ -0,0 +1,306 @@
package api
import (
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/core/util/id"
"github.com/openaccounting/oa-server/database"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupTestDatabase(t *testing.T) (*gorm.DB, func()) {
// Create temporary database file
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
// Set global DB for database package
database.DB = db
// Run migrations
err = database.AutoMigrate()
if err != nil {
t.Fatalf("Failed to run auto migrations: %v", err)
}
err = database.Migrate()
if err != nil {
t.Fatalf("Failed to run custom migrations: %v", err)
}
// Cleanup function
cleanup := func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
os.RemoveAll(tmpDir)
}
return db, cleanup
}
func setupTestData(t *testing.T, db *gorm.DB) (orgID, userID, transactionID string) {
// Use hardcoded UUIDs without dashes for hex format
orgID = "550e8400e29b41d4a716446655440000"
userID = "550e8400e29b41d4a716446655440001"
transactionID = "550e8400e29b41d4a716446655440002"
accountID := "550e8400e29b41d4a716446655440003"
// Insert test data using raw SQL for reliability
now := time.Now()
// Insert org
err := db.Exec("INSERT INTO orgs (id, inserted, updated, name, currency, `precision`, timezone) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?)",
orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Org", "USD", 2, "UTC").Error
if err != nil {
t.Fatalf("Failed to insert org: %v", err)
}
// Insert user
err = db.Exec("INSERT INTO users (id, inserted, updated, firstName, lastName, email, passwordHash, agreeToTerms, passwordReset, emailVerified, emailVerifyCode, signupSource) VALUES (UNHEX(?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test", "User", "test@example.com", "hashedpassword", true, "", true, "", "test").Error
if err != nil {
t.Fatalf("Failed to insert user: %v", err)
}
// Insert user-org relationship
err = db.Exec("INSERT INTO user_orgs (userId, orgId, admin) VALUES (UNHEX(?), UNHEX(?), ?)",
userID, orgID, false).Error
if err != nil {
t.Fatalf("Failed to insert user-org: %v", err)
}
// Insert account
err = db.Exec("INSERT INTO accounts (id, orgId, inserted, updated, name, parent, currency, `precision`, debitBalance) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?, ?)",
accountID, orgID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Account", []byte{}, "USD", 2, true).Error
if err != nil {
t.Fatalf("Failed to insert account: %v", err)
}
// Insert transaction
err = db.Exec("INSERT INTO transactions (id, orgId, userId, inserted, updated, date, description, data, deleted) VALUES (UNHEX(?), UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
transactionID, orgID, userID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), "Test Transaction", "", false).Error
if err != nil {
t.Fatalf("Failed to insert transaction: %v", err)
}
// Insert split
err = db.Exec("INSERT INTO splits (transactionId, accountId, date, inserted, updated, amount, nativeAmount, deleted) VALUES (UNHEX(?), UNHEX(?), ?, ?, ?, ?, ?, ?)",
transactionID, accountID, now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), now.UnixNano()/int64(time.Millisecond), 100, 100, false).Error
if err != nil {
t.Fatalf("Failed to insert split: %v", err)
}
return orgID, userID, transactionID
}
func createTestFile(t *testing.T) (string, []byte) {
content := []byte("This is a test file content for attachment testing")
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.txt")
err := os.WriteFile(filePath, content, 0644)
if err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
return filePath, content
}
func TestAttachmentIntegration(t *testing.T) {
db, cleanup := setupTestDatabase(t)
defer cleanup()
orgID, userID, transactionID := setupTestData(t, db)
// Set up the model instance for the API handlers
bc := &util.StandardBcrypt{}
// Use the existing datastore model which has the attachment implementation
// We need to create it with the database connection
datastoreModel := model.NewModel(nil, bc, types.Config{})
model.Instance = datastoreModel
t.Run("Database Integration Test", func(t *testing.T) {
// Test direct database operations first
filePath, originalContent := createTestFile(t)
defer os.Remove(filePath)
// Create attachment record directly
attachmentID := id.String(id.New())
uploadTime := time.Now()
attachment := types.Attachment{
Id: attachmentID,
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
FileName: "stored_test.txt",
OriginalName: "test.txt",
ContentType: "text/plain",
FileSize: int64(len(originalContent)),
FilePath: "uploads/test/" + attachmentID + ".txt",
Description: "Test attachment description",
Uploaded: uploadTime,
Deleted: false,
}
// Insert using the existing model
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.NoError(t, err)
assert.NotNil(t, createdAttachment)
assert.Equal(t, attachmentID, createdAttachment.Id)
// Verify database persistence
var dbAttachment types.Attachment
err = db.Raw("SELECT HEX(id) as id, HEX(transactionId) as transactionId, HEX(orgId) as orgId, HEX(userId) as userId, fileName, originalName, contentType, fileSize, filePath, description, uploaded, deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&dbAttachment).Error
assert.NoError(t, err)
assert.Equal(t, attachmentID, dbAttachment.Id)
assert.Equal(t, transactionID, dbAttachment.TransactionId)
assert.Equal(t, "Test attachment description", dbAttachment.Description)
// Test retrieval
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
assert.NotNil(t, retrievedAttachment)
assert.Equal(t, attachmentID, retrievedAttachment.Id)
// Test listing by transaction
attachments, err := model.Instance.GetAttachmentsByTransaction(transactionID, orgID, userID)
assert.NoError(t, err)
assert.Len(t, attachments, 1)
assert.Equal(t, attachmentID, attachments[0].Id)
// Test soft deletion
err = model.Instance.DeleteAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
// Verify soft deletion in database
var deletedAttachment types.Attachment
err = db.Raw("SELECT deleted FROM attachment WHERE HEX(id) = ?", attachmentID).Scan(&deletedAttachment).Error
assert.NoError(t, err)
assert.True(t, deletedAttachment.Deleted)
// Verify attachment is no longer accessible
retrievedAttachment, err = model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.Error(t, err)
assert.Nil(t, retrievedAttachment)
})
t.Run("File Upload Integration Test", func(t *testing.T) {
// Test file upload functionality
filePath, originalContent := createTestFile(t)
defer os.Remove(filePath)
// Create upload directory
uploadDir := "uploads/test"
os.MkdirAll(uploadDir, 0755)
defer os.RemoveAll("uploads")
// Simulate file upload process
attachmentID := id.String(id.New())
storedFilePath := filepath.Join(uploadDir, attachmentID+".txt")
// Copy file to upload location
err := copyFile(filePath, storedFilePath)
assert.NoError(t, err)
// Create attachment record
attachment := types.Attachment{
Id: attachmentID,
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
FileName: filepath.Base(storedFilePath),
OriginalName: "test.txt",
ContentType: "text/plain",
FileSize: int64(len(originalContent)),
FilePath: storedFilePath,
Description: "Uploaded test file",
Uploaded: time.Now(),
Deleted: false,
}
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.NoError(t, err)
assert.NotNil(t, createdAttachment)
// Verify file exists
_, err = os.Stat(storedFilePath)
assert.NoError(t, err)
// Verify database record
retrievedAttachment, err := model.Instance.GetAttachment(attachmentID, transactionID, orgID, userID)
assert.NoError(t, err)
assert.Equal(t, storedFilePath, retrievedAttachment.FilePath)
assert.Equal(t, int64(len(originalContent)), retrievedAttachment.FileSize)
})
}
// Helper function to copy files
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
func TestAttachmentValidation(t *testing.T) {
db, cleanup := setupTestDatabase(t)
defer cleanup()
orgID, userID, transactionID := setupTestData(t, db)
// Set up the model instance
bc := &util.StandardBcrypt{}
gormModel := model.NewGormModel(db, bc, types.Config{})
model.Instance = gormModel
t.Run("Invalid attachment data", func(t *testing.T) {
// Test with missing required fields
attachment := types.Attachment{
// Missing ID
TransactionId: transactionID,
OrgId: orgID,
UserId: userID,
}
createdAttachment, err := model.Instance.CreateAttachment(&attachment)
assert.Error(t, err)
assert.Nil(t, createdAttachment)
})
t.Run("Non-existent attachment retrieval", func(t *testing.T) {
nonExistentID := id.String(id.New())
attachment, err := model.Instance.GetAttachment(nonExistentID, transactionID, orgID, userID)
assert.Error(t, err)
assert.Nil(t, attachment)
})
}

View File

@@ -0,0 +1,289 @@
package api
import (
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
"github.com/ant0ine/go-json-rest/rest"
"github.com/openaccounting/oa-server/core/model"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/storage"
"github.com/openaccounting/oa-server/core/util"
"github.com/openaccounting/oa-server/core/util/id"
)
// AttachmentHandler handles attachment operations with configurable storage
type AttachmentHandler struct {
storage storage.Storage
}
// Global attachment handler instance (will be initialized during server startup)
var attachmentHandler *AttachmentHandler
// InitializeAttachmentHandler initializes the global attachment handler with storage backend
func InitializeAttachmentHandler(storageConfig storage.Config) error {
storageBackend, err := storage.NewStorage(storageConfig)
if err != nil {
return fmt.Errorf("failed to initialize storage backend: %w", err)
}
attachmentHandler = &AttachmentHandler{
storage: storageBackend,
}
return nil
}
// PostAttachmentWithStorage handles file upload using the configured storage backend
func PostAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
transactionId := r.FormValue("transactionId")
if transactionId == "" {
rest.Error(w, "Transaction ID is required", http.StatusBadRequest)
return
}
if !util.IsValidUUID(transactionId) {
rest.Error(w, "Invalid transaction ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Parse multipart form
err := r.ParseMultipartForm(MaxFileSize)
if err != nil {
rest.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
return
}
files := r.MultipartForm.File["file"]
if len(files) == 0 {
rest.Error(w, "No file provided", http.StatusBadRequest)
return
}
fileHeader := files[0] // Take the first file
// Verify transaction exists and user has permission
tx, err := model.Instance.GetTransaction(transactionId, "", user.Id)
if err != nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
if tx == nil {
rest.Error(w, "Transaction not found", http.StatusNotFound)
return
}
// Process the file upload
attachment, err := attachmentHandler.processFileUploadWithStorage(fileHeader, transactionId, tx.OrgId, user.Id, r.FormValue("description"))
if err != nil {
rest.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Save attachment to database
createdAttachment, err := model.Instance.CreateAttachment(attachment)
if err != nil {
// Clean up the stored file on database error
attachmentHandler.storage.Delete(attachment.FilePath)
rest.Error(w, "Failed to save attachment", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
w.WriteJson(createdAttachment)
}
// GetAttachmentWithStorage retrieves an attachment using the configured storage backend
func GetAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Check if this is a download request
if r.URL.Query().Get("download") == "true" {
// Stream the file directly to the client
err := attachmentHandler.streamFile(w, attachment)
if err != nil {
rest.Error(w, "Failed to retrieve file", http.StatusInternalServerError)
return
}
return
}
// Return attachment metadata
w.WriteJson(attachment)
}
// GetAttachmentDownloadURL returns a download URL for an attachment
func GetAttachmentDownloadURL(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Generate download URL (valid for 1 hour)
url, err := attachmentHandler.storage.GetURL(attachment.FilePath, time.Hour)
if err != nil {
rest.Error(w, "Failed to generate download URL", http.StatusInternalServerError)
return
}
response := map[string]string{
"url": url,
"expiresIn": "3600", // 1 hour in seconds
}
w.WriteJson(response)
}
// DeleteAttachmentWithStorage deletes an attachment using the configured storage backend
func DeleteAttachmentWithStorage(w rest.ResponseWriter, r *rest.Request) {
if attachmentHandler == nil {
rest.Error(w, "Storage backend not initialized", http.StatusInternalServerError)
return
}
attachmentId := r.PathParam("id")
if !util.IsValidUUID(attachmentId) {
rest.Error(w, "Invalid attachment ID format", http.StatusBadRequest)
return
}
user := r.Env["USER"].(*types.User)
// Get attachment from database first
attachment, err := model.Instance.GetAttachment(attachmentId, "", "", user.Id)
if err != nil {
rest.Error(w, "Attachment not found", http.StatusNotFound)
return
}
// Delete from database (soft delete)
err = model.Instance.DeleteAttachment(attachmentId, attachment.TransactionId, attachment.OrgId, user.Id)
if err != nil {
rest.Error(w, "Failed to delete attachment", http.StatusInternalServerError)
return
}
// Delete from storage backend
// Note: For production, you might want to delay physical deletion
// and run a cleanup job later to handle any issues
err = attachmentHandler.storage.Delete(attachment.FilePath)
if err != nil {
// Log the error but don't fail the request since database deletion succeeded
// The file can be cleaned up later by a maintenance job
fmt.Printf("Warning: Failed to delete file from storage: %v\n", err)
}
w.WriteHeader(http.StatusOK)
w.WriteJson(map[string]string{"status": "deleted"})
}
// processFileUploadWithStorage processes a file upload using the storage backend
func (h *AttachmentHandler) processFileUploadWithStorage(fileHeader *multipart.FileHeader, transactionId, orgId, userId, description string) (*types.Attachment, error) {
// Validate file size
if fileHeader.Size > MaxFileSize {
return nil, fmt.Errorf("file too large. Maximum size is %d bytes", MaxFileSize)
}
// Validate content type
contentType := fileHeader.Header.Get("Content-Type")
if !AllowedMimeTypes[contentType] {
return nil, fmt.Errorf("unsupported file type: %s", contentType)
}
// Open the file
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
}
defer file.Close()
// Store the file using the storage backend
storagePath, err := h.storage.Store(fileHeader.Filename, file, contentType)
if err != nil {
return nil, fmt.Errorf("failed to store file: %w", err)
}
// Create attachment record
attachment := &types.Attachment{
Id: id.String(id.New()),
TransactionId: transactionId,
OrgId: orgId,
UserId: userId,
FileName: storagePath, // Store the storage path/key
OriginalName: fileHeader.Filename,
ContentType: contentType,
FileSize: fileHeader.Size,
FilePath: storagePath, // For backward compatibility
Description: description,
Uploaded: time.Now(),
Deleted: false,
}
return attachment, nil
}
// streamFile streams a file from storage to the HTTP response
func (h *AttachmentHandler) streamFile(w rest.ResponseWriter, attachment *types.Attachment) error {
// Get file from storage
reader, err := h.storage.Retrieve(attachment.FilePath)
if err != nil {
return fmt.Errorf("failed to retrieve file: %w", err)
}
defer reader.Close()
// Set appropriate headers
w.Header().Set("Content-Type", attachment.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", attachment.OriginalName))
// If we know the file size, set Content-Length
if attachment.FileSize > 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", attachment.FileSize))
}
// Stream the file to the client
_, err = io.Copy(w.(http.ResponseWriter), reader)
return err
}

View File

@@ -31,6 +31,17 @@ func GetRouter(auth *AuthMiddleware, prefix string) (rest.App, error) {
rest.Post(prefix+"/orgs/:orgId/transactions", auth.RequireAuth(PostTransaction)),
rest.Put(prefix+"/orgs/:orgId/transactions/:transactionId", auth.RequireAuth(PutTransaction)),
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId", auth.RequireAuth(DeleteTransaction)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments", auth.RequireAuth(GetAttachments)),
rest.Post(prefix+"/orgs/:orgId/transactions/:transactionId/attachments", auth.RequireAuth(PostAttachment)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(GetAttachment)),
rest.Get(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId/download", auth.RequireAuth(DownloadAttachment)),
rest.Delete(prefix+"/orgs/:orgId/transactions/:transactionId/attachments/:attachmentId", auth.RequireAuth(DeleteAttachment)),
// New storage-based attachment endpoints
rest.Post(prefix+"/attachments", auth.RequireAuth(PostAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id", auth.RequireAuth(GetAttachmentWithStorage)),
rest.Get(prefix+"/attachments/:id/url", auth.RequireAuth(GetAttachmentDownloadURL)),
rest.Delete(prefix+"/attachments/:id", auth.RequireAuth(DeleteAttachmentWithStorage)),
rest.Get(prefix+"/orgs/:orgId/prices", auth.RequireAuth(GetPrices)),
rest.Post(prefix+"/orgs/:orgId/prices", auth.RequireAuth(PostPrice)),
rest.Delete(prefix+"/orgs/:orgId/prices/:priceId", auth.RequireAuth(DeletePrice)),

View File

@@ -993,3 +993,74 @@ func (_m *Datastore) VerifyUser(_a0 string) error {
return r0
}
// Attachment interface mock methods
func (_m *Datastore) InsertAttachment(_a0 *types.Attachment) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(*types.Attachment) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
func (_m *Datastore) GetAttachment(_a0 string, _a1 string, _a2 string) (*types.Attachment, error) {
ret := _m.Called(_a0, _a1, _a2)
var r0 *types.Attachment
if rf, ok := ret.Get(0).(func(string, string, string) *types.Attachment); ok {
r0 = rf(_a0, _a1, _a2)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.Attachment)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string, string) error); ok {
r1 = rf(_a0, _a1, _a2)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (_m *Datastore) GetAttachmentsByTransaction(_a0 string, _a1 string) ([]*types.Attachment, error) {
ret := _m.Called(_a0, _a1)
var r0 []*types.Attachment
if rf, ok := ret.Get(0).(func(string, string) []*types.Attachment); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*types.Attachment)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (_m *Datastore) DeleteAttachment(_a0 string, _a1 string, _a2 string) error {
ret := _m.Called(_a0, _a1, _a2)
var r0 error
if rf, ok := ret.Get(0).(func(string, string, string) error); ok {
r0 = rf(_a0, _a1, _a2)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@@ -408,6 +408,15 @@ func (model *Model) accountsContainWriteAccess(accounts []*types.Account, accoun
return false
}
func (model *Model) accountsContainReadAccess(accounts []*types.Account, accountId string) bool {
for _, account := range accounts {
if account.Id == accountId {
return true
}
}
return false
}
func (model *Model) getAccountFromList(accounts []*types.Account, accountId string) *types.Account {
for _, account := range accounts {
if account.Id == accountId {

View File

@@ -162,6 +162,10 @@ func TestCreateAccount(t *testing.T) {
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
// Mock GetSplitCountByAccountId for parent account check
if test.account.Parent != "" {
td.On("GetSplitCountByAccountId", test.account.Parent).Return(int64(0), nil)
}
model := NewModel(td, nil, types.Config{})
@@ -206,6 +210,10 @@ func TestUpdateAccount(t *testing.T) {
td := &TdAccount{}
td.On("GetAccountsByOrgId", "1").Return(getTestAccounts(), nil)
// Mock GetSplitCountByAccountId for parent account check
if test.account.Parent != "" {
td.On("GetSplitCountByAccountId", test.account.Parent).Return(int64(0), nil)
}
model := NewModel(td, nil, types.Config{})

163
core/model/attachment.go Normal file
View File

@@ -0,0 +1,163 @@
package model
import (
"errors"
"time"
"github.com/openaccounting/oa-server/core/model/types"
)
type AttachmentInterface interface {
CreateAttachment(*types.Attachment) (*types.Attachment, error)
GetAttachmentsByTransaction(string, string, string) ([]*types.Attachment, error)
GetAttachment(string, string, string, string) (*types.Attachment, error)
DeleteAttachment(string, string, string, string) error
}
func (model *Model) CreateAttachment(attachment *types.Attachment) (*types.Attachment, error) {
if attachment.Id == "" {
return nil, errors.New("attachment ID required")
}
if attachment.TransactionId == "" {
return nil, errors.New("transaction ID required")
}
if attachment.OrgId == "" {
return nil, errors.New("organization ID required")
}
if attachment.UserId == "" {
return nil, errors.New("user ID required")
}
if attachment.FileName == "" {
return nil, errors.New("file name required")
}
if attachment.FilePath == "" {
return nil, errors.New("file path required")
}
// Set upload timestamp
attachment.Uploaded = time.Now()
attachment.Deleted = false
// Save to database
err := model.db.InsertAttachment(attachment)
if err != nil {
return nil, err
}
return attachment, nil
}
func (model *Model) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
if transactionId == "" {
return nil, errors.New("transaction ID required")
}
if orgId == "" {
return nil, errors.New("organization ID required")
}
if userId == "" {
return nil, errors.New("user ID required")
}
// First verify the user has access to the transaction
tx, err := model.GetTransaction(transactionId, orgId, userId)
if err != nil {
return nil, err
}
if tx == nil {
return nil, errors.New("transaction not found or access denied")
}
// Get attachments for the transaction
attachments, err := model.db.GetAttachmentsByTransaction(transactionId, orgId)
if err != nil {
return nil, err
}
return attachments, nil
}
func (model *Model) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
if attachmentId == "" {
return nil, errors.New("attachment ID required")
}
if transactionId == "" {
return nil, errors.New("transaction ID required")
}
if orgId == "" {
return nil, errors.New("organization ID required")
}
if userId == "" {
return nil, errors.New("user ID required")
}
// First verify the user has access to the transaction
tx, err := model.GetTransaction(transactionId, orgId, userId)
if err != nil {
return nil, err
}
if tx == nil {
return nil, errors.New("transaction not found or access denied")
}
// Get the attachment
attachment, err := model.db.GetAttachment(attachmentId, transactionId, orgId)
if err != nil {
return nil, err
}
return attachment, nil
}
func (model *Model) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
if attachmentId == "" {
return errors.New("attachment ID required")
}
if transactionId == "" {
return errors.New("transaction ID required")
}
if orgId == "" {
return errors.New("organization ID required")
}
if userId == "" {
return errors.New("user ID required")
}
// First verify the user has access to the transaction
tx, err := model.GetTransaction(transactionId, orgId, userId)
if err != nil {
return err
}
if tx == nil {
return errors.New("transaction not found or access denied")
}
// Verify the attachment exists and belongs to the transaction
attachment, err := model.db.GetAttachment(attachmentId, transactionId, orgId)
if err != nil {
return err
}
if attachment == nil {
return errors.New("attachment not found")
}
// Soft delete the attachment
err = model.db.DeleteAttachment(attachmentId, transactionId, orgId)
if err != nil {
return err
}
return nil
}

126
core/model/db/attachment.go Normal file
View File

@@ -0,0 +1,126 @@
package db
import (
"database/sql"
"github.com/openaccounting/oa-server/core/model/types"
"github.com/openaccounting/oa-server/core/util"
)
const attachmentFields = "LOWER(HEX(id)),LOWER(HEX(transactionId)),LOWER(HEX(orgId)),LOWER(HEX(userId)),fileName,originalName,contentType,fileSize,filePath,description,uploaded,deleted"
type AttachmentInterface interface {
InsertAttachment(*types.Attachment) error
GetAttachment(string, string, string) (*types.Attachment, error)
GetAttachmentsByTransaction(string, string) ([]*types.Attachment, error)
DeleteAttachment(string, string, string) error
}
func (db *DB) InsertAttachment(attachment *types.Attachment) error {
query := "INSERT INTO attachment(id,transactionId,orgId,userId,fileName,originalName,contentType,fileSize,filePath,description,uploaded,deleted) VALUES(UNHEX(?),UNHEX(?),UNHEX(?),UNHEX(?),?,?,?,?,?,?,?,?)"
_, err := db.Exec(
query,
attachment.Id,
attachment.TransactionId,
attachment.OrgId,
attachment.UserId,
attachment.FileName,
attachment.OriginalName,
attachment.ContentType,
attachment.FileSize,
attachment.FilePath,
attachment.Description,
util.TimeToMs(attachment.Uploaded),
attachment.Deleted,
)
return err
}
func (db *DB) GetAttachment(attachmentId, transactionId, orgId string) (*types.Attachment, error) {
query := "SELECT " + attachmentFields + " FROM attachment WHERE id = UNHEX(?) AND transactionId = UNHEX(?) AND orgId = UNHEX(?) AND deleted = false"
row := db.QueryRow(query, attachmentId, transactionId, orgId)
return db.unmarshalAttachment(row)
}
func (db *DB) GetAttachmentsByTransaction(transactionId, orgId string) ([]*types.Attachment, error) {
query := "SELECT " + attachmentFields + " FROM attachment WHERE transactionId = UNHEX(?) AND orgId = UNHEX(?) AND deleted = false ORDER BY uploaded DESC"
rows, err := db.Query(query, transactionId, orgId)
if err != nil {
return nil, err
}
return db.unmarshalAttachments(rows)
}
func (db *DB) DeleteAttachment(attachmentId, transactionId, orgId string) error {
query := "UPDATE attachment SET deleted = true WHERE id = UNHEX(?) AND transactionId = UNHEX(?) AND orgId = UNHEX(?)"
_, err := db.Exec(query, attachmentId, transactionId, orgId)
return err
}
func (db *DB) unmarshalAttachment(row *sql.Row) (*types.Attachment, error) {
attachment := &types.Attachment{}
var uploaded int64
err := row.Scan(
&attachment.Id,
&attachment.TransactionId,
&attachment.OrgId,
&attachment.UserId,
&attachment.FileName,
&attachment.OriginalName,
&attachment.ContentType,
&attachment.FileSize,
&attachment.FilePath,
&attachment.Description,
&uploaded,
&attachment.Deleted,
)
if err != nil {
return nil, err
}
attachment.Uploaded = util.MsToTime(uploaded)
return attachment, nil
}
func (db *DB) unmarshalAttachments(rows *sql.Rows) ([]*types.Attachment, error) {
defer rows.Close()
attachments := []*types.Attachment{}
for rows.Next() {
attachment := &types.Attachment{}
var uploaded int64
err := rows.Scan(
&attachment.Id,
&attachment.TransactionId,
&attachment.OrgId,
&attachment.UserId,
&attachment.FileName,
&attachment.OriginalName,
&attachment.ContentType,
&attachment.FileSize,
&attachment.FilePath,
&attachment.Description,
&uploaded,
&attachment.Deleted,
)
if err != nil {
return nil, err
}
attachment.Uploaded = util.MsToTime(uploaded)
attachments = append(attachments, attachment)
}
return attachments, nil
}

View File

@@ -15,6 +15,7 @@ type Datastore interface {
OrgInterface
AccountInterface
TransactionInterface
AttachmentInterface
PriceInterface
SessionInterface
ApiKeyInterface

View File

@@ -247,6 +247,38 @@ func (m *GormModel) InsertTransaction(transaction *types.Transaction) error {
return m.repository.InsertTransaction(transaction)
}
func (m *GormModel) GetTransaction(transactionId, orgId, userId string) (*types.Transaction, error) {
// For now, delegate to repository - in a full implementation, this would include permission checking
return m.repository.GetTransactionById(transactionId)
}
// AttachmentInterface implementation
func (m *GormModel) CreateAttachment(attachment *types.Attachment) (*types.Attachment, error) {
if attachment.Id == "" {
return nil, errors.New("attachment ID required")
}
// Set upload timestamp
attachment.Uploaded = time.Now()
attachment.Deleted = false
// For GORM implementation, we'd need to implement repository methods
// For now, return an error indicating not implemented
return nil, errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) GetAttachmentsByTransaction(transactionId, orgId, userId string) ([]*types.Attachment, error) {
return nil, errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) GetAttachment(attachmentId, transactionId, orgId, userId string) (*types.Attachment, error) {
return nil, errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) DeleteAttachment(attachmentId, transactionId, orgId, userId string) error {
return errors.New("attachment operations not yet implemented for GORM model")
}
func (m *GormModel) GetTransactionById(id string) (*types.Transaction, error) {
return m.repository.GetTransactionById(id)
}

View File

@@ -20,11 +20,13 @@ type Interface interface {
OrgInterface
AccountInterface
TransactionInterface
AttachmentInterface
PriceInterface
SessionInterface
ApiKeyInterface
SystemHealthInteface
BudgetInterface
GetTransaction(string, string, string) (*types.Transaction, error)
}
func NewModel(db db.Datastore, bcrypt util.Bcrypt, config types.Config) *Model {

View File

@@ -169,6 +169,31 @@ func (model *Model) getTransactionById(id string) (*types.Transaction, error) {
return model.db.GetTransactionById(id)
}
func (model *Model) GetTransaction(transactionId, orgId, userId string) (*types.Transaction, error) {
transaction, err := model.getTransactionById(transactionId)
if err != nil {
return nil, err
}
if transaction == nil || transaction.OrgId != orgId {
return nil, nil
}
// Check if user has access to all accounts in the transaction
userAccounts, err := model.GetAccounts(orgId, userId, "")
if err != nil {
return nil, err
}
for _, split := range transaction.Splits {
if !model.accountsContainReadAccess(userAccounts, split.AccountId) {
return nil, fmt.Errorf("user does not have permission to access account %s", split.AccountId)
}
}
return transaction, nil
}
func (model *Model) checkSplits(transaction *types.Transaction) (err error) {
if len(transaction.Splits) < 2 {
return errors.New("at least 2 splits are required")

View File

@@ -0,0 +1,20 @@
package types
import (
"time"
)
type Attachment struct {
Id string `json:"id"`
TransactionId string `json:"transactionId"`
OrgId string `json:"orgId"`
UserId string `json:"userId"`
FileName string `json:"fileName"`
OriginalName string `json:"originalName"`
ContentType string `json:"contentType"`
FileSize int64 `json:"fileSize"`
FilePath string `json:"filePath"`
Description string `json:"description"`
Uploaded time.Time `json:"uploaded"`
Deleted bool `json:"deleted"`
}

View File

@@ -1,5 +1,7 @@
package types
import "github.com/openaccounting/oa-server/core/storage"
type Config struct {
WebUrl string `mapstructure:"weburl"`
Address string `mapstructure:"address"`
@@ -15,8 +17,11 @@ type Config struct {
Password string `mapstructure:"password"` // Sensitive: use OA_PASSWORD env var
// SQLite specific
DatabaseFile string `mapstructure:"databasefile"`
// Email configuration
MailgunDomain string `mapstructure:"mailgundomain"`
MailgunKey string `mapstructure:"mailgunkey"` // Sensitive: use OA_MAILGUN_KEY env var
MailgunEmail string `mapstructure:"mailgunemail"`
MailgunSender string `mapstructure:"mailgunsender"`
// Storage configuration
Storage storage.Config `mapstructure:"storage"`
}

View File

@@ -5,6 +5,7 @@ import (
"log"
"net/http"
"strconv"
"strings"
"github.com/openaccounting/oa-server/core/api"
"github.com/openaccounting/oa-server/core/auth"
@@ -31,6 +32,22 @@ func main() {
viper.AutomaticEnv()
viper.SetEnvPrefix("OA") // will look for OA_DATABASE_PASSWORD, etc.
// Configure Viper to handle nested config with environment variables
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Bind specific storage environment variables for better support
// Using mapstructure field names (snake_case)
viper.BindEnv("Storage.backend", "OA_STORAGE_BACKEND")
viper.BindEnv("Storage.local.root_dir", "OA_STORAGE_LOCAL_ROOTDIR")
viper.BindEnv("Storage.local.base_url", "OA_STORAGE_LOCAL_BASEURL")
viper.BindEnv("Storage.s3.region", "OA_STORAGE_S3_REGION")
viper.BindEnv("Storage.s3.bucket", "OA_STORAGE_S3_BUCKET")
viper.BindEnv("Storage.s3.prefix", "OA_STORAGE_S3_PREFIX")
viper.BindEnv("Storage.s3.access_key_id", "OA_STORAGE_S3_ACCESSKEYID")
viper.BindEnv("Storage.s3.secret_access_key", "OA_STORAGE_S3_SECRETACCESSKEY")
viper.BindEnv("Storage.s3.endpoint", "OA_STORAGE_S3_ENDPOINT")
viper.BindEnv("Storage.s3.path_style", "OA_STORAGE_S3_PATHSTYLE")
// Set default values
viper.SetDefault("Address", "localhost")
viper.SetDefault("Port", 8080)
@@ -38,6 +55,11 @@ func main() {
viper.SetDefault("DatabaseFile", "./openaccounting.db")
viper.SetDefault("ApiPrefix", "/api/v1")
// Set storage defaults (using mapstructure field names)
viper.SetDefault("Storage.backend", "local")
viper.SetDefault("Storage.local.root_dir", "./uploads")
viper.SetDefault("Storage.local.base_url", "")
// Read configuration
err := viper.ReadInConfig()
if err != nil {
@@ -50,6 +72,14 @@ func main() {
if err != nil {
log.Fatal(fmt.Errorf("failed to unmarshal config: %s", err.Error()))
}
// Set storage defaults if not configured (Viper doesn't handle nested defaults well)
if config.Storage.Backend == "" {
config.Storage.Backend = "local"
}
if config.Storage.Local.RootDir == "" {
config.Storage.Local.RootDir = "./uploads"
}
// Parse database address (assuming format host:port for MySQL)
host := config.DatabaseAddress
@@ -105,6 +135,12 @@ func main() {
// Set the global model instance
model.Instance = gormModel
// Initialize storage backend for attachments
err = api.InitializeAttachmentHandler(config.Storage)
if err != nil {
log.Fatal(fmt.Errorf("failed to initialize storage backend: %s", err.Error()))
}
app, err := api.Init(config.ApiPrefix)
if err != nil {
log.Fatal(fmt.Errorf("failed to create api instance with: %s", err.Error()))

106
core/storage/interface.go Normal file
View File

@@ -0,0 +1,106 @@
package storage
import (
"io"
"time"
)
// Storage defines the interface for file storage backends
type Storage interface {
// Store saves a file and returns the storage path/key
Store(filename string, content io.Reader, contentType string) (string, error)
// Retrieve gets a file by its storage path/key
Retrieve(path string) (io.ReadCloser, error)
// Delete removes a file by its storage path/key
Delete(path string) error
// GetURL returns a URL for accessing the file (may be signed/temporary)
GetURL(path string, expiry time.Duration) (string, error)
// Exists checks if a file exists at the given path
Exists(path string) (bool, error)
// GetMetadata returns file metadata (size, last modified, etc.)
GetMetadata(path string) (*FileMetadata, error)
}
// FileMetadata contains information about a stored file
type FileMetadata struct {
Size int64
LastModified time.Time
ContentType string
ETag string
}
// Config holds configuration for storage backends
type Config struct {
// Storage backend type: "local", "s3"
Backend string `mapstructure:"backend"`
// Local filesystem configuration
Local LocalConfig `mapstructure:"local"`
// S3-compatible storage configuration (S3, B2, R2, etc.)
S3 S3Config `mapstructure:"s3"`
}
// LocalConfig configures local filesystem storage
type LocalConfig struct {
// Root directory for file storage
RootDir string `mapstructure:"root_dir"`
// Base URL for serving files (optional)
BaseURL string `mapstructure:"base_url"`
}
// S3Config configures S3-compatible storage (AWS S3, Backblaze B2, Cloudflare R2, etc.)
type S3Config struct {
// AWS Region (use "auto" for Cloudflare R2)
Region string `mapstructure:"region"`
// S3 Bucket name
Bucket string `mapstructure:"bucket"`
// Optional prefix for all objects
Prefix string `mapstructure:"prefix"`
// Access Key ID
AccessKeyID string `mapstructure:"access_key_id"`
// Secret Access Key
SecretAccessKey string `mapstructure:"secret_access_key"`
// Custom endpoint URL for S3-compatible services:
// - Backblaze B2: https://s3.us-west-004.backblazeb2.com
// - Cloudflare R2: https://<account-id>.r2.cloudflarestorage.com
// - MinIO: http://localhost:9000
// Leave empty for AWS S3
Endpoint string `mapstructure:"endpoint"`
// Use path-style addressing (required for some S3-compatible services)
PathStyle bool `mapstructure:"path_style"`
}
// NewStorage creates a new storage backend based on configuration
func NewStorage(config Config) (Storage, error) {
switch config.Backend {
case "local", "":
return NewLocalStorage(config.Local)
case "s3":
return NewS3Storage(config.S3)
default:
return nil, &UnsupportedBackendError{Backend: config.Backend}
}
}
// UnsupportedBackendError is returned when an unknown storage backend is requested
type UnsupportedBackendError struct {
Backend string
}
func (e *UnsupportedBackendError) Error() string {
return "unsupported storage backend: " + e.Backend
}

View File

@@ -0,0 +1,101 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewStorage(t *testing.T) {
t.Run("Local Storage", func(t *testing.T) {
config := Config{
Backend: "local",
Local: LocalConfig{
RootDir: t.TempDir(),
},
}
storage, err := NewStorage(config)
assert.NoError(t, err)
assert.IsType(t, &LocalStorage{}, storage)
})
t.Run("Default to Local Storage", func(t *testing.T) {
config := Config{
// No backend specified
Local: LocalConfig{
RootDir: t.TempDir(),
},
}
storage, err := NewStorage(config)
assert.NoError(t, err)
assert.IsType(t, &LocalStorage{}, storage)
})
t.Run("S3 Storage", func(t *testing.T) {
config := Config{
Backend: "s3",
S3: S3Config{
Region: "us-east-1",
Bucket: "test-bucket",
},
}
// This might succeed if AWS credentials are available via IAM roles or env vars
// Let's just check that we get an S3Storage instance or an error
storage, err := NewStorage(config)
if err != nil {
// If it fails, that's expected in test environments without AWS access
assert.Nil(t, storage)
} else {
// If it succeeds, we should get an S3Storage instance
assert.IsType(t, &S3Storage{}, storage)
}
})
t.Run("B2 Storage", func(t *testing.T) {
config := Config{
Backend: "b2",
B2: B2Config{
AccountID: "test-account",
ApplicationKey: "test-key",
Bucket: "test-bucket",
},
}
// This will fail because we don't have real B2 credentials
storage, err := NewStorage(config)
assert.Error(t, err) // Expected to fail without credentials
assert.Nil(t, storage)
})
t.Run("Unsupported Backend", func(t *testing.T) {
config := Config{
Backend: "unsupported",
}
storage, err := NewStorage(config)
assert.Error(t, err)
assert.IsType(t, &UnsupportedBackendError{}, err)
assert.Nil(t, storage)
assert.Contains(t, err.Error(), "unsupported")
})
}
func TestStorageErrors(t *testing.T) {
t.Run("UnsupportedBackendError", func(t *testing.T) {
err := &UnsupportedBackendError{Backend: "ftp"}
assert.Equal(t, "unsupported storage backend: ftp", err.Error())
})
t.Run("FileNotFoundError", func(t *testing.T) {
err := &FileNotFoundError{Path: "missing.txt"}
assert.Equal(t, "file not found: missing.txt", err.Error())
})
t.Run("InvalidPathError", func(t *testing.T) {
err := &InvalidPathError{Path: "../../../etc/passwd"}
assert.Equal(t, "invalid path: ../../../etc/passwd", err.Error())
})
}

243
core/storage/local.go Normal file
View File

@@ -0,0 +1,243 @@
package storage
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/openaccounting/oa-server/core/util/id"
)
// LocalStorage implements the Storage interface for local filesystem
type LocalStorage struct {
rootDir string
baseURL string
}
// NewLocalStorage creates a new local filesystem storage backend
func NewLocalStorage(config LocalConfig) (*LocalStorage, error) {
rootDir := config.RootDir
if rootDir == "" {
rootDir = "./uploads"
}
// Ensure the root directory exists
if err := os.MkdirAll(rootDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create storage directory: %w", err)
}
return &LocalStorage{
rootDir: rootDir,
baseURL: config.BaseURL,
}, nil
}
// Store saves a file to the local filesystem
func (l *LocalStorage) Store(filename string, content io.Reader, contentType string) (string, error) {
// Generate a unique storage path
storagePath := l.generateStoragePath(filename)
fullPath := filepath.Join(l.rootDir, storagePath)
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}
// Create and write the file
file, err := os.Create(fullPath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
_, err = io.Copy(file, content)
if err != nil {
// Clean up the file if write failed
os.Remove(fullPath)
return "", fmt.Errorf("failed to write file: %w", err)
}
return storagePath, nil
}
// Retrieve gets a file from the local filesystem
func (l *LocalStorage) Retrieve(path string) (io.ReadCloser, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return nil, err
}
fullPath := filepath.Join(l.rootDir, path)
file, err := os.Open(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, &FileNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to open file: %w", err)
}
return file, nil
}
// Delete removes a file from the local filesystem
func (l *LocalStorage) Delete(path string) error {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return err
}
fullPath := filepath.Join(l.rootDir, path)
err := os.Remove(fullPath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to delete file: %w", err)
}
// Try to remove empty parent directories
l.cleanupEmptyDirs(filepath.Dir(fullPath))
return nil
}
// GetURL returns a URL for accessing the file
func (l *LocalStorage) GetURL(path string, expiry time.Duration) (string, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return "", err
}
// Check if file exists
exists, err := l.Exists(path)
if err != nil {
return "", err
}
if !exists {
return "", &FileNotFoundError{Path: path}
}
if l.baseURL != "" {
// Return a public URL if base URL is configured
return l.baseURL + "/" + path, nil
}
// For local storage without a base URL, return the file path
// In a real application, you might serve these through an endpoint
return "/files/" + path, nil
}
// Exists checks if a file exists at the given path
func (l *LocalStorage) Exists(path string) (bool, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return false, err
}
fullPath := filepath.Join(l.rootDir, path)
_, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check file existence: %w", err)
}
return true, nil
}
// GetMetadata returns file metadata
func (l *LocalStorage) GetMetadata(path string) (*FileMetadata, error) {
// Validate path to prevent directory traversal
if err := l.validatePath(path); err != nil {
return nil, err
}
fullPath := filepath.Join(l.rootDir, path)
info, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, &FileNotFoundError{Path: path}
}
return nil, fmt.Errorf("failed to get file metadata: %w", err)
}
return &FileMetadata{
Size: info.Size(),
LastModified: info.ModTime(),
ContentType: "", // Local storage doesn't store content type
ETag: "", // Local storage doesn't have ETags
}, nil
}
// generateStoragePath creates a unique storage path for a file
func (l *LocalStorage) generateStoragePath(filename string) string {
// Generate a unique ID for the file
fileID := id.String(id.New())
// Extract file extension
ext := filepath.Ext(filename)
// Create a path structure: YYYY/MM/DD/uuid.ext
now := time.Now()
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
return filepath.Join(datePath, fileID+ext)
}
// validatePath ensures the path doesn't contain directory traversal attempts
func (l *LocalStorage) validatePath(path string) error {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Reject paths that try to go up directories
if strings.Contains(cleanPath, "..") {
return &InvalidPathError{Path: path}
}
// Reject absolute paths
if filepath.IsAbs(cleanPath) {
return &InvalidPathError{Path: path}
}
return nil
}
// cleanupEmptyDirs removes empty parent directories up to the root
func (l *LocalStorage) cleanupEmptyDirs(dir string) {
// Don't remove the root directory
if dir == l.rootDir {
return
}
// Check if directory is empty
entries, err := os.ReadDir(dir)
if err != nil || len(entries) > 0 {
return
}
// Remove empty directory
if err := os.Remove(dir); err == nil {
// Recursively clean parent directories
l.cleanupEmptyDirs(filepath.Dir(dir))
}
}
// FileNotFoundError is returned when a file doesn't exist
type FileNotFoundError struct {
Path string
}
func (e *FileNotFoundError) Error() string {
return "file not found: " + e.Path
}
// InvalidPathError is returned when a path is invalid or contains traversal attempts
type InvalidPathError struct {
Path string
}
func (e *InvalidPathError) Error() string {
return "invalid path: " + e.Path
}

202
core/storage/local_test.go Normal file
View File

@@ -0,0 +1,202 @@
package storage
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLocalStorage(t *testing.T) {
// Create temporary directory for testing
tmpDir := t.TempDir()
config := LocalConfig{
RootDir: tmpDir,
BaseURL: "http://localhost:8080/files",
}
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.NotNil(t, storage)
t.Run("Store and Retrieve File", func(t *testing.T) {
content := []byte("test file content")
reader := bytes.NewReader(content)
// Store file
path, err := storage.Store("test.txt", reader, "text/plain")
assert.NoError(t, err)
assert.NotEmpty(t, path)
// Verify file exists
exists, err := storage.Exists(path)
assert.NoError(t, err)
assert.True(t, exists)
// Retrieve file
retrievedReader, err := storage.Retrieve(path)
assert.NoError(t, err)
defer retrievedReader.Close()
retrievedContent, err := io.ReadAll(retrievedReader)
assert.NoError(t, err)
assert.Equal(t, content, retrievedContent)
})
t.Run("Get File Metadata", func(t *testing.T) {
content := []byte("metadata test content")
reader := bytes.NewReader(content)
path, err := storage.Store("metadata.txt", reader, "text/plain")
assert.NoError(t, err)
metadata, err := storage.GetMetadata(path)
assert.NoError(t, err)
assert.Equal(t, int64(len(content)), metadata.Size)
assert.False(t, metadata.LastModified.IsZero())
})
t.Run("Get File URL", func(t *testing.T) {
content := []byte("url test content")
reader := bytes.NewReader(content)
path, err := storage.Store("url.txt", reader, "text/plain")
assert.NoError(t, err)
url, err := storage.GetURL(path, time.Hour)
assert.NoError(t, err)
assert.Contains(t, url, path)
assert.Contains(t, url, config.BaseURL)
})
t.Run("Delete File", func(t *testing.T) {
content := []byte("delete test content")
reader := bytes.NewReader(content)
path, err := storage.Store("delete.txt", reader, "text/plain")
assert.NoError(t, err)
// Verify file exists
exists, err := storage.Exists(path)
assert.NoError(t, err)
assert.True(t, exists)
// Delete file
err = storage.Delete(path)
assert.NoError(t, err)
// Verify file no longer exists
exists, err = storage.Exists(path)
assert.NoError(t, err)
assert.False(t, exists)
})
t.Run("Path Validation", func(t *testing.T) {
// Test directory traversal prevention
_, err := storage.Retrieve("../../../etc/passwd")
assert.Error(t, err)
assert.IsType(t, &InvalidPathError{}, err)
// Test absolute path rejection
_, err = storage.Retrieve("/etc/passwd")
assert.Error(t, err)
assert.IsType(t, &InvalidPathError{}, err)
})
t.Run("File Not Found", func(t *testing.T) {
_, err := storage.Retrieve("nonexistent.txt")
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
_, err = storage.GetMetadata("nonexistent.txt")
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
_, err = storage.GetURL("nonexistent.txt", time.Hour)
assert.Error(t, err)
assert.IsType(t, &FileNotFoundError{}, err)
})
t.Run("Storage Path Generation", func(t *testing.T) {
content := []byte("path test content")
reader1 := bytes.NewReader(content)
reader2 := bytes.NewReader(content)
// Store two files with same name
path1, err := storage.Store("same.txt", reader1, "text/plain")
assert.NoError(t, err)
path2, err := storage.Store("same.txt", reader2, "text/plain")
assert.NoError(t, err)
// Paths should be different (unique)
assert.NotEqual(t, path1, path2)
// Both should exist
exists1, err := storage.Exists(path1)
assert.NoError(t, err)
assert.True(t, exists1)
exists2, err := storage.Exists(path2)
assert.NoError(t, err)
assert.True(t, exists2)
// Both should have correct extension
assert.True(t, strings.HasSuffix(path1, ".txt"))
assert.True(t, strings.HasSuffix(path2, ".txt"))
// Should be organized by date
now := time.Now()
expectedPrefix := filepath.Join(
fmt.Sprintf("%04d", now.Year()),
fmt.Sprintf("%02d", now.Month()),
fmt.Sprintf("%02d", now.Day()),
)
assert.True(t, strings.HasPrefix(path1, expectedPrefix))
assert.True(t, strings.HasPrefix(path2, expectedPrefix))
})
}
func TestLocalStorageConfig(t *testing.T) {
t.Run("Default Root Directory", func(t *testing.T) {
config := LocalConfig{} // Empty config
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.NotNil(t, storage)
// Should create default uploads directory
assert.Equal(t, "./uploads", storage.rootDir)
// Verify directory was created
_, err = os.Stat("./uploads")
assert.NoError(t, err)
// Clean up
os.RemoveAll("./uploads")
})
t.Run("Custom Root Directory", func(t *testing.T) {
tmpDir := t.TempDir()
customDir := filepath.Join(tmpDir, "custom", "storage")
config := LocalConfig{
RootDir: customDir,
}
storage, err := NewLocalStorage(config)
assert.NoError(t, err)
assert.Equal(t, customDir, storage.rootDir)
// Verify custom directory was created
_, err = os.Stat(customDir)
assert.NoError(t, err)
})
}

236
core/storage/s3.go Normal file
View File

@@ -0,0 +1,236 @@
package storage
import (
"fmt"
"io"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/openaccounting/oa-server/core/util/id"
)
// S3Storage implements the Storage interface for Amazon S3
type S3Storage struct {
client *s3.S3
uploader *s3manager.Uploader
bucket string
prefix string
}
// NewS3Storage creates a new S3 storage backend
func NewS3Storage(config S3Config) (*S3Storage, error) {
if config.Bucket == "" {
return nil, fmt.Errorf("S3 bucket name is required")
}
// Create AWS config
awsConfig := &aws.Config{
Region: aws.String(config.Region),
}
// Set custom endpoint if provided (for S3-compatible services)
if config.Endpoint != "" {
awsConfig.Endpoint = aws.String(config.Endpoint)
awsConfig.S3ForcePathStyle = aws.Bool(config.PathStyle)
}
// Set credentials if provided
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
awsConfig.Credentials = credentials.NewStaticCredentials(
config.AccessKeyID,
config.SecretAccessKey,
"",
)
}
// Create session
sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create AWS session: %w", err)
}
// Create S3 client
client := s3.New(sess)
uploader := s3manager.NewUploader(sess)
return &S3Storage{
client: client,
uploader: uploader,
bucket: config.Bucket,
prefix: config.Prefix,
}, nil
}
// Store saves a file to S3
func (s *S3Storage) Store(filename string, content io.Reader, contentType string) (string, error) {
// Generate a unique storage key
storageKey := s.generateStorageKey(filename)
// Prepare upload input
input := &s3manager.UploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(storageKey),
Body: content,
}
// Set content type if provided
if contentType != "" {
input.ContentType = aws.String(contentType)
}
// Upload the file
_, err := s.uploader.Upload(input)
if err != nil {
return "", fmt.Errorf("failed to upload file to S3: %w", err)
}
return storageKey, nil
}
// Retrieve gets a file from S3
func (s *S3Storage) Retrieve(path string) (io.ReadCloser, error) {
input := &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
result, err := s.client.GetObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey:
return nil, &FileNotFoundError{Path: path}
}
}
return nil, fmt.Errorf("failed to retrieve file from S3: %w", err)
}
return result.Body, nil
}
// Delete removes a file from S3
func (s *S3Storage) Delete(path string) error {
input := &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
_, err := s.client.DeleteObject(input)
if err != nil {
return fmt.Errorf("failed to delete file from S3: %w", err)
}
return nil
}
// GetURL returns a presigned URL for accessing the file
func (s *S3Storage) GetURL(path string, expiry time.Duration) (string, error) {
// Check if file exists first
exists, err := s.Exists(path)
if err != nil {
return "", err
}
if !exists {
return "", &FileNotFoundError{Path: path}
}
// Generate presigned URL
req, _ := s.client.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
})
url, err := req.Presign(expiry)
if err != nil {
return "", fmt.Errorf("failed to generate presigned URL: %w", err)
}
return url, nil
}
// Exists checks if a file exists in S3
func (s *S3Storage) Exists(path string) (bool, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
_, err := s.client.HeadObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey, "NotFound":
return false, nil
}
}
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
}
return true, nil
}
// GetMetadata returns file metadata from S3
func (s *S3Storage) GetMetadata(path string) (*FileMetadata, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(path),
}
result, err := s.client.HeadObject(input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case s3.ErrCodeNoSuchKey, "NotFound":
return nil, &FileNotFoundError{Path: path}
}
}
return nil, fmt.Errorf("failed to get file metadata from S3: %w", err)
}
metadata := &FileMetadata{
Size: aws.Int64Value(result.ContentLength),
}
if result.LastModified != nil {
metadata.LastModified = *result.LastModified
}
if result.ContentType != nil {
metadata.ContentType = *result.ContentType
}
if result.ETag != nil {
metadata.ETag = strings.Trim(*result.ETag, "\"")
}
return metadata, nil
}
// generateStorageKey creates a unique storage key for a file
func (s *S3Storage) generateStorageKey(filename string) string {
// Generate a unique ID for the file
fileID := id.String(id.New())
// Extract file extension
ext := path.Ext(filename)
// Create a key structure: prefix/YYYY/MM/DD/uuid.ext
now := time.Now()
datePath := fmt.Sprintf("%04d/%02d/%02d", now.Year(), now.Month(), now.Day())
key := path.Join(datePath, fileID+ext)
// Add prefix if configured
if s.prefix != "" {
key = path.Join(s.prefix, key)
}
return key
}

View File

@@ -3,6 +3,7 @@ package util
import (
"crypto/rand"
"encoding/hex"
"regexp"
"time"
)
@@ -44,3 +45,23 @@ func NewInviteId() (string, error) {
return hex.EncodeToString(byteArray), nil
}
func NewUUID() string {
guid, err := NewGuid()
if err != nil {
// Fallback to timestamp-based UUID if random generation fails
return hex.EncodeToString([]byte(time.Now().Format("20060102150405")))
}
return guid
}
func IsValidUUID(uuid string) bool {
// Check if the string is a valid 32-character hex string (16 bytes * 2 hex chars)
if len(uuid) != 32 {
return false
}
// Check if all characters are valid hex characters
matched, _ := regexp.MatchString("^[0-9a-f]{32}$", uuid)
return matched
}

View File

@@ -67,7 +67,7 @@ func Handler(w rest.ResponseWriter, r *rest.Request) {
continue
}
log.Printf("recv: %s", message)
log.Printf("recv: %+v", message)
// check version
err = checkVersion(message.Version)

View File

@@ -102,6 +102,7 @@ func AutoMigrate() error {
&models.APIKey{},
&models.Invite{},
&models.BudgetItem{},
&models.Attachment{},
)
}
@@ -131,6 +132,10 @@ func createIndexes() error {
"CREATE INDEX IF NOT EXISTS split_date_index ON splits(date)",
"CREATE INDEX IF NOT EXISTS split_updated_index ON splits(updated)",
"CREATE INDEX IF NOT EXISTS budgetitem_orgId_index ON budget_items(orgId)",
"CREATE INDEX IF NOT EXISTS attachment_transactionId_index ON attachment(transactionId)",
"CREATE INDEX IF NOT EXISTS attachment_orgId_index ON attachment(orgId)",
"CREATE INDEX IF NOT EXISTS attachment_userId_index ON attachment(userId)",
"CREATE INDEX IF NOT EXISTS attachment_uploaded_index ON attachment(uploaded)",
// Additional useful indexes for performance
"CREATE INDEX IF NOT EXISTS idx_transaction_date ON transactions(date)",

BIN
dev.db Normal file

Binary file not shown.

4
go.mod
View File

@@ -5,6 +5,7 @@ go 1.24.2
require (
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01
github.com/aws/aws-sdk-go v1.44.0
github.com/go-sql-driver/mysql v1.8.1
github.com/gorilla/websocket v0.0.0-20180605202552-5ed622c449da
github.com/mailgun/mailgun-go/v4 v4.3.0
@@ -18,6 +19,7 @@ require (
require (
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
@@ -42,7 +44,7 @@ require (
github.com/json-iterator/go v1.1.10 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect
github.com/pkg/errors v0.8.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/text v0.21.0 // indirect

19
go.sum
View File

@@ -4,6 +4,8 @@ github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf h1:BMUJnVJI5J50
github.com/Masterminds/semver v0.0.0-20180807142431-c84ddcca87bf/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01 h1:oYAjCHMjyRaNBo3nUEepDce4LC+Kuh+6jU6y+AllvnU=
github.com/ant0ine/go-json-rest v0.0.0-20170913041208-ebb33769ae01/go.mod h1:q6aCt0GfU6LhpBsnZ/2U+mwe+0XB5WStbmwyoPfc+sk=
github.com/aws/aws-sdk-go v1.44.0 h1:jwtHuNqfnJxL4DKHBUVUmQlfueQqBW7oXP6yebZR/R0=
github.com/aws/aws-sdk-go v1.44.0/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -34,6 +36,10 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -52,8 +58,9 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLD
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
@@ -84,13 +91,23 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=

View File

@@ -3,4 +3,8 @@ CREATE INDEX split_accountId_index ON split (accountId);
CREATE INDEX split_transactionId_index ON split (transactionId);
CREATE INDEX split_date_index ON split (date);
CREATE INDEX split_updated_index ON split (updated);
CREATE INDEX budgetitem_orgId_index ON budgetitem (orgId);
CREATE INDEX budgetitem_orgId_index ON budgetitem (orgId);
CREATE INDEX attachment_transactionId_index ON attachment (transactionId);
CREATE INDEX attachment_orgId_index ON attachment (orgId);
CREATE INDEX attachment_userId_index ON attachment (userId);
CREATE INDEX attachment_uploaded_index ON attachment (uploaded);

View File

@@ -22,7 +22,7 @@ run: build
# Run with custom environment
run-dev: build
@echo "Starting server in development mode..."
OA_DATABASE_DRIVER=sqlite OA_DATABASE_FILE=./dev.db OA_PORT=8080 ./server
OA_DATABASEDRIVER=sqlite OA_DATABASEFILE=./dev.db OA_PORT=8080 ./server
# Run tests
test:

28
models/attachment.go Normal file
View File

@@ -0,0 +1,28 @@
package models
import (
"time"
)
type Attachment struct {
ID []byte `gorm:"type:BINARY(16);primaryKey"`
TransactionID []byte `gorm:"column:transactionId;type:BINARY(16);not null"`
OrgID []byte `gorm:"column:orgId;type:BINARY(16);not null"`
UserID []byte `gorm:"column:userId;type:BINARY(16);not null"`
FileName string `gorm:"column:fileName;size:255;not null"`
OriginalName string `gorm:"column:originalName;size:255;not null"`
ContentType string `gorm:"column:contentType;size:100;not null"`
FileSize int64 `gorm:"column:fileSize;not null"`
FilePath string `gorm:"column:filePath;size:500;not null"`
Description string `gorm:"column:description;size:500"`
Uploaded time.Time `gorm:"column:uploaded;not null"`
Deleted bool `gorm:"column:deleted;default:false"`
Transaction Transaction `gorm:"foreignKey:TransactionID"`
Org Org `gorm:"foreignKey:OrgID"`
User User `gorm:"foreignKey:UserID"`
}
func (Attachment) TableName() string {
return "attachment"
}

View File

@@ -30,4 +30,6 @@ CREATE TABLE apikey (id BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL,
CREATE TABLE invite (id VARCHAR(32) NOT NULL, orgId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, updated BIGINT UNSIGNED NOT NULL, email VARCHAR(100) NOT NULL, accepted BOOLEAN NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
CREATE TABLE budgetitem (id INT UNSIGNED NOT NULL AUTO_INCREMENT, orgId BINARY(16) NOT NULL, accountId BINARY(16) NOT NULL, inserted BIGINT UNSIGNED NOT NULL, amount BIGINT NOT NULL, PRIMARY KEY(id)) ENGINE=InnoDB;
CREATE TABLE attachment (id BINARY(16) NOT NULL, transactionId BINARY(16) NOT NULL, orgId BINARY(16) NOT NULL, userId BINARY(16) NOT NULL, fileName VARCHAR(255) NOT NULL, originalName VARCHAR(255) NOT NULL, contentType VARCHAR(100) NOT NULL, fileSize BIGINT NOT NULL, filePath VARCHAR(500) NOT NULL, description VARCHAR(500), uploaded BIGINT UNSIGNED NOT NULL, deleted BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY(id)) ENGINE=InnoDB;

View File

@@ -1,27 +0,0 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,14 +0,0 @@
# filippo.io/edwards25519
```
import "filippo.io/edwards25519"
```
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.

View File

@@ -1,20 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/internal/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519

View File

@@ -1,427 +0,0 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"errors"
"filippo.io/edwards25519/field"
)
// Point types.
type projP1xP1 struct {
X, Y, Z, T field.Element
}
type projP2 struct {
X, Y, Z field.Element
}
// Point represents a point on the edwards25519 curve.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is NOT valid, and it may be used only as a receiver.
type Point struct {
// Make the type not comparable (i.e. used with == or as a map key), as
// equivalent points can be represented by different Go values.
_ incomparable
// The point is internally represented in extended coordinates (X, Y, Z, T)
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
x, y, z, t field.Element
}
type incomparable [0]func()
func checkInitialized(points ...*Point) {
for _, p := range points {
if p.x == (field.Element{}) && p.y == (field.Element{}) {
panic("edwards25519: use of uninitialized Point")
}
}
}
type projCached struct {
YplusX, YminusX, Z, T2d field.Element
}
type affineCached struct {
YplusX, YminusX, T2d field.Element
}
// Constructors.
func (v *projP2) Zero() *projP2 {
v.X.Zero()
v.Y.One()
v.Z.One()
return v
}
// identity is the point at infinity.
var identity, _ = new(Point).SetBytes([]byte{
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
var generator, _ = new(Point).SetBytes([]byte{
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
return new(Point).Set(generator)
}
func (v *projCached) Zero() *projCached {
v.YplusX.One()
v.YminusX.One()
v.Z.One()
v.T2d.Zero()
return v
}
func (v *affineCached) Zero() *affineCached {
v.YplusX.One()
v.YminusX.One()
v.T2d.Zero()
return v
}
// Assignments.
// Set sets v = u, and returns v.
func (v *Point) Set(u *Point) *Point {
*v = *u
return v
}
// Encoding.
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytes(&buf)
}
func (v *Point) bytes(buf *[32]byte) []byte {
checkInitialized(v)
var zInv, x, y field.Element
zInv.Invert(&v.z) // zInv = 1 / Z
x.Multiply(&v.x, &zInv) // x = X / Z
y.Multiply(&v.y, &zInv) // y = Y / Z
out := copyFieldElement(buf, &y)
out[31] |= byte(x.IsNegative() << 7)
return out
}
var feOne = new(field.Element).One()
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
// Note that SetBytes accepts all non-canonical encodings of valid points.
// That is, it follows decoding rules that match most implementations in
// the ecosystem rather than RFC 8032.
func (v *Point) SetBytes(x []byte) (*Point, error) {
// Specifically, the non-canonical encodings that are accepted are
// 1) the ones where the field element is not reduced (see the
// (*field.Element).SetBytes docs) and
// 2) the ones where the x-coordinate is zero and the sign bit is set.
//
// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
// specifically the "Canonical A, R" section.
y, err := new(field.Element).SetBytes(x)
if err != nil {
return nil, errors.New("edwards25519: invalid point encoding length")
}
// -x² + y² = 1 + dx²y²
// x² + dx²y² = x²(dy² + 1) = y² - 1
// x² = (y² - 1) / (dy² + 1)
// u = y² - 1
y2 := new(field.Element).Square(y)
u := new(field.Element).Subtract(y2, feOne)
// v = dy² + 1
vv := new(field.Element).Multiply(y2, d)
vv = vv.Add(vv, feOne)
// x = +√(u/v)
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
if wasSquare == 0 {
return nil, errors.New("edwards25519: invalid point encoding")
}
// Select the negative square root if the sign bit is set.
xxNeg := new(field.Element).Negate(xx)
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
v.x.Set(xx)
v.y.Set(y)
v.z.One()
v.t.Multiply(xx, y) // xy = T / Z
return v, nil
}
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
copy(buf[:], v.Bytes())
return buf[:]
}
// Conversions.
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
v.X.Multiply(&p.X, &p.T)
v.Y.Multiply(&p.Y, &p.Z)
v.Z.Multiply(&p.Z, &p.T)
return v
}
func (v *projP2) FromP3(p *Point) *projP2 {
v.X.Set(&p.x)
v.Y.Set(&p.y)
v.Z.Set(&p.z)
return v
}
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
v.x.Multiply(&p.X, &p.T)
v.y.Multiply(&p.Y, &p.Z)
v.z.Multiply(&p.Z, &p.T)
v.t.Multiply(&p.X, &p.Y)
return v
}
func (v *Point) fromP2(p *projP2) *Point {
v.x.Multiply(&p.X, &p.Z)
v.y.Multiply(&p.Y, &p.Z)
v.z.Square(&p.Z)
v.t.Multiply(&p.X, &p.Y)
return v
}
// d is a constant in the curve equation.
var d, _ = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
var d2 = new(field.Element).Add(d, d)
func (v *projCached) FromP3(p *Point) *projCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.Z.Set(&p.z)
v.T2d.Multiply(&p.t, d2)
return v
}
func (v *affineCached) FromP3(p *Point) *affineCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.T2d.Multiply(&p.t, d2)
var invZ field.Element
invZ.Invert(&p.z)
v.YplusX.Multiply(&v.YplusX, &invZ)
v.YminusX.Multiply(&v.YminusX, &invZ)
v.T2d.Multiply(&v.T2d, &invZ)
return v
}
// (Re)addition and subtraction.
// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Sub(p, qCached)
return v.fromP1xP1(result)
}
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&ZZ2, &TT2d)
v.T.Subtract(&ZZ2, &TT2d)
return v
}
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
v.T.Add(&ZZ2, &TT2d) // flipped sign
return v
}
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&Z2, &TT2d)
v.T.Subtract(&Z2, &TT2d)
return v
}
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&Z2, &TT2d) // flipped sign
v.T.Add(&Z2, &TT2d) // flipped sign
return v
}
// Doubling.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
var XX, YY, ZZ2, XplusYsq field.Element
XX.Square(&p.X)
YY.Square(&p.Y)
ZZ2.Square(&p.Z)
ZZ2.Add(&ZZ2, &ZZ2)
XplusYsq.Add(&p.X, &p.Y)
XplusYsq.Square(&XplusYsq)
v.Y.Add(&YY, &XX)
v.Z.Subtract(&YY, &XX)
v.X.Subtract(&XplusYsq, &v.Y)
v.T.Subtract(&ZZ2, &v.Z)
return v
}
// Negation.
// Negate sets v = -p, and returns v.
func (v *Point) Negate(p *Point) *Point {
checkInitialized(p)
v.x.Negate(&p.x)
v.y.Set(&p.y)
v.z.Set(&p.z)
v.t.Negate(&p.t)
return v
}
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
func (v *Point) Equal(u *Point) int {
checkInitialized(v, u)
var t1, t2, t3, t4 field.Element
t1.Multiply(&v.x, &u.z)
t2.Multiply(&u.x, &v.z)
t3.Multiply(&v.y, &u.z)
t4.Multiply(&u.y, &v.z)
return t1.Equal(&t2) & t3.Equal(&t4)
}
// Constant-time operations
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.Z.Select(&a.Z, &b.Z, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *projCached) CondNeg(cond int) *projCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *affineCached) CondNeg(cond int) *affineCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}

View File

@@ -1,349 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// This file contains additional functionality that is not included in the
// upstream crypto/internal/edwards25519 package.
import (
"errors"
"filippo.io/edwards25519/field"
)
// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap. Don't change the style without making
// sure it doesn't increase the inliner cost.
var e [4]field.Element
X, Y, Z, T = v.extendedCoordinates(&e)
return
}
func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
checkInitialized(v)
X = e[0].Set(&v.x)
Y = e[1].Set(&v.y)
Z = e[2].Set(&v.z)
T = e[3].Set(&v.t)
return
}
// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
//
// If the coordinates are invalid or don't represent a valid point on the curve,
// SetExtendedCoordinates returns nil and an error and the receiver is
// unchanged. Otherwise, SetExtendedCoordinates returns v.
func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
if !isOnCurve(X, Y, Z, T) {
return nil, errors.New("edwards25519: invalid point coordinates")
}
v.x.Set(X)
v.y.Set(Y)
v.z.Set(Z)
v.t.Set(T)
return v, nil
}
func isOnCurve(X, Y, Z, T *field.Element) bool {
var lhs, rhs field.Element
XX := new(field.Element).Square(X)
YY := new(field.Element).Square(Y)
ZZ := new(field.Element).Square(Z)
TT := new(field.Element).Square(T)
// -x² + y² = 1 + dx²y²
// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
// -X² + Y² = Z² + dT²
lhs.Subtract(YY, XX)
rhs.Multiply(d, TT).Add(&rhs, ZZ)
if lhs.Equal(&rhs) != 1 {
return false
}
// xy = T/Z
// XY/Z² = T/Z
// XY = TZ
lhs.Multiply(X, Y)
rhs.Multiply(T, Z)
return lhs.Equal(&rhs) == 1
}
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
//
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
// while every valid edwards25519 point has a unique u-coordinate Montgomery
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
// to any edwards25519 point, and every other X25519 input corresponds to two
// edwards25519 points.
func (v *Point) BytesMontgomery() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytesMontgomery(&buf)
}
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
checkInitialized(v)
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
// Montgomery u-coordinate
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
var y, recip, u field.Element
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
return copyFieldElement(buf, &u)
}
// MultByCofactor sets v = 8 * p, and returns v.
func (v *Point) MultByCofactor(p *Point) *Point {
checkInitialized(p)
result := projP1xP1{}
pp := (&projP2{}).FromP3(p)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
return v.fromP1xP1(&result)
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert returns zero.
func (s *Scalar) Invert(t *Scalar) *Scalar {
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Multiply(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Multiply(&table[i], &tt)
}
// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[5/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(5 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(9 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
return s
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called MultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].nonAdjacentForm(5)
}
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficent, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

View File

@@ -1,420 +0,0 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"errors"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
// Between operations, all limbs are expected to be lower than 2^52.
l0 uint64
l1 uint64
l2 uint64
l3 uint64
l4 uint64
}
const maskLow51Bits uint64 = (1 << 51) - 1
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation.
return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
return v.carryPropagate()
}
// Negate sets v = -a, and returns v.
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set sets v = a, and returns v.
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
// not of the right length, SetBytes returns nil and an error, and the
// receiver is unchanged.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032, but
// consistent with most Ed25519 implementations.
func (v *Element) SetBytes(x []byte) (*Element, error) {
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid field element input size")
}
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
v.l0 = binary.LittleEndian.Uint64(x[0:8])
v.l0 &= maskLow51Bits
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
v.l1 &= maskLow51Bits
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
v.l2 &= maskLow51Bits
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
v.l3 &= maskLow51Bits
// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
// Note: not bytes 25:33, shift 4, to avoid overread.
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
v.l4 &= maskLow51Bits
return v, nil
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
v.l3 = (m & a.l3) | (^m & b.l3)
v.l4 = (m & a.l4) | (^m & b.l4)
return v
}
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
t = m & (v.l1 ^ u.l1)
v.l1 ^= t
u.l1 ^= t
t = m & (v.l2 ^ u.l2)
v.l2 ^= t
u.l2 ^= t
t = m & (v.l3 ^ u.l3)
v.l3 ^= t
u.l3 ^= t
t = m & (v.l4 ^ u.l4)
v.l4 ^= t
u.l4 ^= t
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
x3lo, x3hi := mul51(x.l3, y)
x4lo, x4hi := mul51(x.l4, y)
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
v.l1 = x1lo + x0hi
v.l2 = x2lo + x1hi
v.l3 = x3lo + x2hi
v.l4 = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
t0 := new(Element)
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := new(Element).Square(v)
uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
uv7 := new(Element).Multiply(uv3, t0.Square(v2))
rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))
check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2
uNeg := new(Element).Negate(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))
rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(rr) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}

View File

@@ -1,16 +0,0 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
package field
// feMul sets out = a * b. It works like feMulGeneric.
//
//go:noescape
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//
//go:noescape
func feSquare(out *Element, a *Element)

View File

@@ -1,379 +0,0 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
// r0 = a0×b0
MOVQ (CX), AX
MULQ (BX)
MOVQ AX, DI
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
MULQ 8(BX)
MOVQ AX, R9
MOVQ DX, R8
// r1 += a1×b0
MOVQ 8(CX), AX
MULQ (BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
MULQ 16(BX)
MOVQ AX, R11
MOVQ DX, R10
// r2 += a1×b1
MOVQ 8(CX), AX
MULQ 8(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += a2×b0
MOVQ 16(CX), AX
MULQ (BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
MULQ 24(BX)
MOVQ AX, R13
MOVQ DX, R12
// r3 += a1×b2
MOVQ 8(CX), AX
MULQ 16(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a2×b1
MOVQ 16(CX), AX
MULQ 8(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a3×b0
MOVQ 24(CX), AX
MULQ (BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
MULQ 32(BX)
MOVQ AX, R15
MOVQ DX, R14
// r4 += a1×b3
MOVQ 8(CX), AX
MULQ 24(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a2×b2
MOVQ 16(CX), AX
MULQ 16(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a3×b1
MOVQ 24(CX), AX
MULQ 8(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a4×b0
MOVQ 32(CX), AX
MULQ (BX)
ADDQ AX, R15
ADCQ DX, R14
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, DI, SI
SHLQ $0x0d, R9, R8
SHLQ $0x0d, R11, R10
SHLQ $0x0d, R13, R12
SHLQ $0x0d, R15, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Second reduction chain (carryPropagate)
MOVQ DI, SI
SHRQ $0x33, SI
MOVQ R9, R8
SHRQ $0x33, R8
MOVQ R11, R10
SHRQ $0x33, R10
MOVQ R13, R12
SHRQ $0x33, R12
MOVQ R15, R14
SHRQ $0x33, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Store output
MOVQ out+0(FP), AX
MOVQ DI, (AX)
MOVQ R9, 8(AX)
MOVQ R11, 16(AX)
MOVQ R13, 24(AX)
MOVQ R15, 32(AX)
RET
// func feSquare(out *Element, a *Element)
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX
// r0 = l0×l0
MOVQ (CX), AX
MULQ (CX)
MOVQ AX, SI
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 8(CX)
MOVQ AX, R8
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
MOVQ AX, R10
MOVQ DX, R9
// r2 += l1×l1
MOVQ 8(CX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
MOVQ AX, R12
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 32(CX)
MOVQ AX, R14
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R13
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, SI, BX
SHLQ $0x0d, R8, DI
SHLQ $0x0d, R10, R9
SHLQ $0x0d, R12, R11
SHLQ $0x0d, R14, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Second reduction chain (carryPropagate)
MOVQ SI, BX
SHRQ $0x33, BX
MOVQ R8, DI
SHRQ $0x33, DI
MOVQ R10, R9
SHRQ $0x33, R9
MOVQ R12, R11
SHRQ $0x33, R11
MOVQ R14, R13
SHRQ $0x33, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Store output
MOVQ out+0(FP), AX
MOVQ SI, (AX)
MOVQ R8, 8(AX)
MOVQ R10, 16(AX)
MOVQ R12, 24(AX)
MOVQ R14, 32(AX)
RET

View File

@@ -1,12 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package field
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

View File

@@ -1,16 +0,0 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field
//go:noescape
func carryPropagate(v *Element)
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}

View File

@@ -1,42 +0,0 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET

View File

@@ -1,12 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}

View File

@@ -1,50 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "errors"
// This file contains additional functionality that is not included in the
// upstream crypto/ed25519/edwards25519/field package.
// SetWideBytes sets v to x, where x is a 64-byte little-endian encoding, which
// is reduced modulo the field order. If x is not of the right length,
// SetWideBytes returns nil and an error, and the receiver is unchanged.
//
// SetWideBytes is not necessary to select a uniformly distributed value, and is
// only provided for compatibility: SetBytes can be used instead as the chance
// of bias is less than 2⁻²⁵⁰.
func (v *Element) SetWideBytes(x []byte) (*Element, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetWideBytes input size")
}
// Split the 64 bytes into two elements, and extract the most significant
// bit of each, which is ignored by SetBytes.
lo, _ := new(Element).SetBytes(x[:32])
loMSB := uint64(x[31] >> 7)
hi, _ := new(Element).SetBytes(x[32:])
hiMSB := uint64(x[63] >> 7)
// The output we want is
//
// v = lo + loMSB * 2²⁵⁵ + hi * 2²⁵⁶ + hiMSB * 2⁵¹¹
//
// which applying the reduction identity comes out to
//
// v = lo + loMSB * 19 + hi * 2 * 19 + hiMSB * 2 * 19²
//
// l0 will be the sum of a 52 bits value (lo.l0), plus a 5 bits value
// (loMSB * 19), a 6 bits value (hi.l0 * 2 * 19), and a 10 bits value
// (hiMSB * 2 * 19²), so it fits in a uint64.
v.l0 = lo.l0 + loMSB*19 + hi.l0*2*19 + hiMSB*2*19*19
v.l1 = lo.l1 + hi.l1*2*19
v.l2 = lo.l2 + hi.l2*2*19
v.l3 = lo.l3 + hi.l3*2*19
v.l4 = lo.l4 + hi.l4*2*19
return v.carryPropagate(), nil
}

View File

@@ -1,266 +0,0 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}

View File

@@ -1,343 +0,0 @@
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
// s is the scalar in the Montgomery domain, in the format of the
// fiat-crypto implementation.
s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
return &Scalar{}
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Make a copy of z in case it aliases s.
zCopy := new(Scalar).Set(z)
return s.Multiply(x, y).Add(s, zCopy)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
fiatScalarAdd(&s.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
fiatScalarSub(&s.s, &x.s, &y.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
fiatScalarOpp(&s.s, &x.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
fiatScalarMul(&s.s, &x.s, &y.s)
return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
*s = *x
return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
}
// We have a value x of 512 bits, but our fiatScalarFromBytes function
// expects an input lower than l, which is a little over 252 bits.
//
// Instead of writing a reduction function that operates on wider inputs, we
// can interpret x as the sum of three shorter values a, b, and c.
//
// x = a + b * 2^168 + c * 2^336 mod l
//
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
// with two multiplications and two additions.
s.setShortBytes(x[:21])
t := new(Scalar).setShortBytes(x[21:42])
s.Add(s, t.Multiply(t, scalarTwo168))
t.setShortBytes(x[42:])
s.Add(s, t.Multiply(t, scalarTwo336))
return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
if len(x) >= 32 {
panic("edwards25519: internal error: setShortBytes called with a long string")
}
var buf [32]byte
copy(buf[:], x)
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
}
if !isReduced(x) {
return nil, errors.New("invalid scalar encoding")
}
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
if len(s) != 32 {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
// The description above omits the purpose of the high bits of the clamping
// for brevity, but those are also lost to reductions, and are also
// irrelevant to edwards25519 as they protect against a specific
// implementation bug that was once observed in a generic Montgomery ladder.
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
}
// We need to use the wide reduction from SetUniformBytes, since clamping
// sets the 2^254 bit, making the value higher than the order.
var wideBytes [64]byte
copy(wideBytes[:], x[:])
wideBytes[0] &= 248
wideBytes[31] &= 63
wideBytes[31] |= 64
return s.SetUniformBytes(wideBytes[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var encoded [32]byte
return s.bytes(&encoded)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
var ss fiatScalarNonMontgomeryDomainFieldElement
fiatScalarFromMontgomery(&ss, &s.s)
fiatScalarToBytes(out, (*[4]uint64)(&ss))
return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
var diff fiatScalarMontgomeryDomainFieldElement
fiatScalarSub(&diff, &s.s, &t.s)
var nonzero uint64
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
panic("w must be at least 2 by the definition of NAF")
} else if w > 8 {
panic("NAF digits must fit in int8")
}
var naf [256]int8
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
}
width := uint64(1 << w)
windowMask := uint64(width - 1)
pos := uint(0)
carry := uint64(0)
for pos < 256 {
indexU64 := pos / 64
indexBit := pos % 64
var bitBuf uint64
if indexBit < 64-w {
// This window's bits are contained in a single u64
bitBuf = digits[indexU64] >> indexBit
} else {
// Combine the current 64 bits with bits from the next 64
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
}
// Add carry into the current window
window := carry + (bitBuf & windowMask)
if window&1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0,
// then the next carry should be 0
// If carry == 1 and window & 1 == 0,
// then bit_buf & 1 == 1 so the next carry should be 1
pos += 1
continue
}
if window < width/2 {
carry = 0
naf[pos] = int8(window)
} else {
carry = 1
naf[pos] = int8(window) - int8(width)
}
pos += w
}
return naf
}
func (s *Scalar) signedRadix16() [64]int8 {
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
var digits [64]int8
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(b[i] & 15)
digits[2*i+1] = int8((b[i] >> 4) & 15)
}
// Recenter coefficients:
for i := 0; i < 63; i++ {
carry := (digits[i] + 8) >> 4
digits[i] -= carry << 4
digits[i+1] += carry
}
return digits
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,214 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.signedRadix16()
multiple := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Accumulate the odd components first
v.Set(NewIdentityPoint())
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.fromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
return v
}
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
checkInitialized(q)
var table projLookupTable
table.FromP3(q)
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.signedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
table.SelectInto(multiple, digits[63])
v.Set(NewIdentityPoint())
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.fromP1xP1(tmp1)
return v
}
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
checkInitialized(A)
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.nonAdjacentForm(5)
bNaf := b.nonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
break
}
}
multA := &projCached{}
multB := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

View File

@@ -1,129 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]affineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]affineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
*dest = v.points[x/2]
}

View File

@@ -1,27 +0,0 @@
language: go
go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- tip
# Setting sudo access to false will let Travis CI use containers rather than
# VMs to run the tests. For more details see:
# - http://docs.travis-ci.com/user/workers/container-based-infrastructure/
# - http://docs.travis-ci.com/user/workers/standard-infrastructure/
sudo: false
script:
- make setup
- make test
notifications:
webhooks:
urls:
- https://webhooks.gitter.im/e/06e3328629952dabe3e0
on_success: change # options: [always|never|change] default: always
on_failure: always # options: [always|never|change] default: always
on_start: never # options: [always|never|change] default: always

View File

@@ -1,86 +0,0 @@
# 1.4.2 (2018-04-10)
## Changed
- #72: Updated the docs to point to vert for a console appliaction
- #71: Update the docs on pre-release comparator handling
## Fixed
- #70: Fix the handling of pre-releases and the 0.0.0 release edge case
# 1.4.1 (2018-04-02)
## Fixed
- Fixed #64: Fix pre-release precedence issue (thanks @uudashr)
# 1.4.0 (2017-10-04)
## Changed
- #61: Update NewVersion to parse ints with a 64bit int size (thanks @zknill)
# 1.3.1 (2017-07-10)
## Fixed
- Fixed #57: number comparisons in prerelease sometimes inaccurate
# 1.3.0 (2017-05-02)
## Added
- #45: Added json (un)marshaling support (thanks @mh-cbon)
- Stability marker. See https://masterminds.github.io/stability/
## Fixed
- #51: Fix handling of single digit tilde constraint (thanks @dgodd)
## Changed
- #55: The godoc icon moved from png to svg
# 1.2.3 (2017-04-03)
## Fixed
- #46: Fixed 0.x.x and 0.0.x in constraints being treated as *
# Release 1.2.2 (2016-12-13)
## Fixed
- #34: Fixed issue where hyphen range was not working with pre-release parsing.
# Release 1.2.1 (2016-11-28)
## Fixed
- #24: Fixed edge case issue where constraint "> 0" does not handle "0.0.1-alpha"
properly.
# Release 1.2.0 (2016-11-04)
## Added
- #20: Added MustParse function for versions (thanks @adamreese)
- #15: Added increment methods on versions (thanks @mh-cbon)
## Fixed
- Issue #21: Per the SemVer spec (section 9) a pre-release is unstable and
might not satisfy the intended compatibility. The change here ignores pre-releases
on constraint checks (e.g., ~ or ^) when a pre-release is not part of the
constraint. For example, `^1.2.3` will ignore pre-releases while
`^1.2.3-alpha` will include them.
# Release 1.1.1 (2016-06-30)
## Changed
- Issue #9: Speed up version comparison performance (thanks @sdboyer)
- Issue #8: Added benchmarks (thanks @sdboyer)
- Updated Go Report Card URL to new location
- Updated Readme to add code snippet formatting (thanks @mh-cbon)
- Updating tagging to v[SemVer] structure for compatibility with other tools.
# Release 1.1.0 (2016-03-11)
- Issue #2: Implemented validation to provide reasons a versions failed a
constraint.
# Release 1.0.1 (2015-12-31)
- Fixed #1: * constraint failing on valid versions.
# Release 1.0.0 (2015-10-20)
- Initial release

View File

@@ -1,20 +0,0 @@
The Masterminds
Copyright (C) 2014-2015, Matt Butcher and Matt Farina
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -1,36 +0,0 @@
.PHONY: setup
setup:
go get -u gopkg.in/alecthomas/gometalinter.v1
gometalinter.v1 --install
.PHONY: test
test: validate lint
@echo "==> Running tests"
go test -v
.PHONY: validate
validate:
@echo "==> Running static validations"
@gometalinter.v1 \
--disable-all \
--enable deadcode \
--severity deadcode:error \
--enable gofmt \
--enable gosimple \
--enable ineffassign \
--enable misspell \
--enable vet \
--tests \
--vendor \
--deadline 60s \
./... || exit_code=1
.PHONY: lint
lint:
@echo "==> Running linters"
@gometalinter.v1 \
--disable-all \
--enable golint \
--vendor \
--deadline 60s \
./... || :

View File

@@ -1,186 +0,0 @@
# SemVer
The `semver` package provides the ability to work with [Semantic Versions](http://semver.org) in Go. Specifically it provides the ability to:
* Parse semantic versions
* Sort semantic versions
* Check if a semantic version fits within a set of constraints
* Optionally work with a `v` prefix
[![Stability:
Active](https://masterminds.github.io/stability/active.svg)](https://masterminds.github.io/stability/active.html)
[![Build Status](https://travis-ci.org/Masterminds/semver.svg)](https://travis-ci.org/Masterminds/semver) [![Build status](https://ci.appveyor.com/api/projects/status/jfk66lib7hb985k8/branch/master?svg=true&passingText=windows%20build%20passing&failingText=windows%20build%20failing)](https://ci.appveyor.com/project/mattfarina/semver/branch/master) [![GoDoc](https://godoc.org/github.com/Masterminds/semver?status.svg)](https://godoc.org/github.com/Masterminds/semver) [![Go Report Card](https://goreportcard.com/badge/github.com/Masterminds/semver)](https://goreportcard.com/report/github.com/Masterminds/semver)
If you are looking for a command line tool for version comparisons please see
[vert](https://github.com/Masterminds/vert) which uses this library.
## Parsing Semantic Versions
To parse a semantic version use the `NewVersion` function. For example,
```go
v, err := semver.NewVersion("1.2.3-beta.1+build345")
```
If there is an error the version wasn't parseable. The version object has methods
to get the parts of the version, compare it to other versions, convert the
version back into a string, and get the original string. For more details
please see the [documentation](https://godoc.org/github.com/Masterminds/semver).
## Sorting Semantic Versions
A set of versions can be sorted using the [`sort`](https://golang.org/pkg/sort/)
package from the standard library. For example,
```go
raw := []string{"1.2.3", "1.0", "1.3", "2", "0.4.2",}
vs := make([]*semver.Version, len(raw))
for i, r := range raw {
v, err := semver.NewVersion(r)
if err != nil {
t.Errorf("Error parsing version: %s", err)
}
vs[i] = v
}
sort.Sort(semver.Collection(vs))
```
## Checking Version Constraints
Checking a version against version constraints is one of the most featureful
parts of the package.
```go
c, err := semver.NewConstraint(">= 1.2.3")
if err != nil {
// Handle constraint not being parseable.
}
v, _ := semver.NewVersion("1.3")
if err != nil {
// Handle version not being parseable.
}
// Check if the version meets the constraints. The a variable will be true.
a := c.Check(v)
```
## Basic Comparisons
There are two elements to the comparisons. First, a comparison string is a list
of comma separated and comparisons. These are then separated by || separated or
comparisons. For example, `">= 1.2, < 3.0.0 || >= 4.2.3"` is looking for a
comparison that's greater than or equal to 1.2 and less than 3.0.0 or is
greater than or equal to 4.2.3.
The basic comparisons are:
* `=`: equal (aliased to no operator)
* `!=`: not equal
* `>`: greater than
* `<`: less than
* `>=`: greater than or equal to
* `<=`: less than or equal to
## Working With Pre-release Versions
Pre-releases, for those not familiar with them, are used for software releases
prior to stable or generally available releases. Examples of pre-releases include
development, alpha, beta, and release candidate releases. A pre-release may be
a version such as `1.2.3-beta.1` while the stable release would be `1.2.3`. In the
order of precidence, pre-releases come before their associated releases. In this
example `1.2.3-beta.1 < 1.2.3`.
According to the Semantic Version specification pre-releases may not be
API compliant with their release counterpart. It says,
> A pre-release version indicates that the version is unstable and might not satisfy the intended compatibility requirements as denoted by its associated normal version.
SemVer comparisons without a pre-release comparator will skip pre-release versions.
For example, `>=1.2.3` will skip pre-releases when looking at a list of releases
while `>=1.2.3-0` will evaluate and find pre-releases.
The reason for the `0` as a pre-release version in the example comparison is
because pre-releases can only contain ASCII alphanumerics and hyphens (along with
`.` separators), per the spec. Sorting happens in ASCII sort order, again per the spec. The lowest character is a `0` in ASCII sort order (see an [ASCII Table](http://www.asciitable.com/))
Understanding ASCII sort ordering is important because A-Z comes before a-z. That
means `>=1.2.3-BETA` will return `1.2.3-alpha`. What you might expect from case
sensitivity doesn't apply here. This is due to ASCII sort ordering which is what
the spec specifies.
## Hyphen Range Comparisons
There are multiple methods to handle ranges and the first is hyphens ranges.
These look like:
* `1.2 - 1.4.5` which is equivalent to `>= 1.2, <= 1.4.5`
* `2.3.4 - 4.5` which is equivalent to `>= 2.3.4, <= 4.5`
## Wildcards In Comparisons
The `x`, `X`, and `*` characters can be used as a wildcard character. This works
for all comparison operators. When used on the `=` operator it falls
back to the pack level comparison (see tilde below). For example,
* `1.2.x` is equivalent to `>= 1.2.0, < 1.3.0`
* `>= 1.2.x` is equivalent to `>= 1.2.0`
* `<= 2.x` is equivalent to `< 3`
* `*` is equivalent to `>= 0.0.0`
## Tilde Range Comparisons (Patch)
The tilde (`~`) comparison operator is for patch level ranges when a minor
version is specified and major level changes when the minor number is missing.
For example,
* `~1.2.3` is equivalent to `>= 1.2.3, < 1.3.0`
* `~1` is equivalent to `>= 1, < 2`
* `~2.3` is equivalent to `>= 2.3, < 2.4`
* `~1.2.x` is equivalent to `>= 1.2.0, < 1.3.0`
* `~1.x` is equivalent to `>= 1, < 2`
## Caret Range Comparisons (Major)
The caret (`^`) comparison operator is for major level changes. This is useful
when comparisons of API versions as a major change is API breaking. For example,
* `^1.2.3` is equivalent to `>= 1.2.3, < 2.0.0`
* `^1.2.x` is equivalent to `>= 1.2.0, < 2.0.0`
* `^2.3` is equivalent to `>= 2.3, < 3`
* `^2.x` is equivalent to `>= 2.0.0, < 3`
# Validation
In addition to testing a version against a constraint, a version can be validated
against a constraint. When validation fails a slice of errors containing why a
version didn't meet the constraint is returned. For example,
```go
c, err := semver.NewConstraint("<= 1.2.3, >= 1.4")
if err != nil {
// Handle constraint not being parseable.
}
v, _ := semver.NewVersion("1.3")
if err != nil {
// Handle version not being parseable.
}
// Validate a version against a constraint.
a, msgs := c.Validate(v)
// a is false
for _, m := range msgs {
fmt.Println(m)
// Loops over the errors which would read
// "1.3 is greater than 1.2.3"
// "1.3 is less than 1.4"
}
```
# Contribute
If you find an issue or want to contribute please file an [issue](https://github.com/Masterminds/semver/issues)
or [create a pull request](https://github.com/Masterminds/semver/pulls).

View File

@@ -1,44 +0,0 @@
version: build-{build}.{branch}
clone_folder: C:\gopath\src\github.com\Masterminds\semver
shallow_clone: true
environment:
GOPATH: C:\gopath
platform:
- x64
install:
- go version
- go env
- go get -u gopkg.in/alecthomas/gometalinter.v1
- set PATH=%PATH%;%GOPATH%\bin
- gometalinter.v1.exe --install
build_script:
- go install -v ./...
test_script:
- "gometalinter.v1 \
--disable-all \
--enable deadcode \
--severity deadcode:error \
--enable gofmt \
--enable gosimple \
--enable ineffassign \
--enable misspell \
--enable vet \
--tests \
--vendor \
--deadline 60s \
./... || exit_code=1"
- "gometalinter.v1 \
--disable-all \
--enable golint \
--vendor \
--deadline 60s \
./... || :"
- go test -v
deploy: off

View File

@@ -1,24 +0,0 @@
package semver
// Collection is a collection of Version instances and implements the sort
// interface. See the sort package for more details.
// https://golang.org/pkg/sort/
type Collection []*Version
// Len returns the length of a collection. The number of Version instances
// on the slice.
func (c Collection) Len() int {
return len(c)
}
// Less is needed for the sort interface to compare two Version objects on the
// slice. If checks if one is less than the other.
func (c Collection) Less(i, j int) bool {
return c[i].LessThan(c[j])
}
// Swap is needed for the sort interface to replace the Version objects
// at two different positions in the slice.
func (c Collection) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
}

View File

@@ -1,406 +0,0 @@
package semver
import (
"errors"
"fmt"
"regexp"
"strings"
)
// Constraints is one or more constraint that a semantic version can be
// checked against.
type Constraints struct {
constraints [][]*constraint
}
// NewConstraint returns a Constraints instance that a Version instance can
// be checked against. If there is a parse error it will be returned.
func NewConstraint(c string) (*Constraints, error) {
// Rewrite - ranges into a comparison operation.
c = rewriteRange(c)
ors := strings.Split(c, "||")
or := make([][]*constraint, len(ors))
for k, v := range ors {
cs := strings.Split(v, ",")
result := make([]*constraint, len(cs))
for i, s := range cs {
pc, err := parseConstraint(s)
if err != nil {
return nil, err
}
result[i] = pc
}
or[k] = result
}
o := &Constraints{constraints: or}
return o, nil
}
// Check tests if a version satisfies the constraints.
func (cs Constraints) Check(v *Version) bool {
// loop over the ORs and check the inner ANDs
for _, o := range cs.constraints {
joy := true
for _, c := range o {
if !c.check(v) {
joy = false
break
}
}
if joy {
return true
}
}
return false
}
// Validate checks if a version satisfies a constraint. If not a slice of
// reasons for the failure are returned in addition to a bool.
func (cs Constraints) Validate(v *Version) (bool, []error) {
// loop over the ORs and check the inner ANDs
var e []error
for _, o := range cs.constraints {
joy := true
for _, c := range o {
if !c.check(v) {
em := fmt.Errorf(c.msg, v, c.orig)
e = append(e, em)
joy = false
}
}
if joy {
return true, []error{}
}
}
return false, e
}
var constraintOps map[string]cfunc
var constraintMsg map[string]string
var constraintRegex *regexp.Regexp
func init() {
constraintOps = map[string]cfunc{
"": constraintTildeOrEqual,
"=": constraintTildeOrEqual,
"!=": constraintNotEqual,
">": constraintGreaterThan,
"<": constraintLessThan,
">=": constraintGreaterThanEqual,
"=>": constraintGreaterThanEqual,
"<=": constraintLessThanEqual,
"=<": constraintLessThanEqual,
"~": constraintTilde,
"~>": constraintTilde,
"^": constraintCaret,
}
constraintMsg = map[string]string{
"": "%s is not equal to %s",
"=": "%s is not equal to %s",
"!=": "%s is equal to %s",
">": "%s is less than or equal to %s",
"<": "%s is greater than or equal to %s",
">=": "%s is less than %s",
"=>": "%s is less than %s",
"<=": "%s is greater than %s",
"=<": "%s is greater than %s",
"~": "%s does not have same major and minor version as %s",
"~>": "%s does not have same major and minor version as %s",
"^": "%s does not have same major version as %s",
}
ops := make([]string, 0, len(constraintOps))
for k := range constraintOps {
ops = append(ops, regexp.QuoteMeta(k))
}
constraintRegex = regexp.MustCompile(fmt.Sprintf(
`^\s*(%s)\s*(%s)\s*$`,
strings.Join(ops, "|"),
cvRegex))
constraintRangeRegex = regexp.MustCompile(fmt.Sprintf(
`\s*(%s)\s+-\s+(%s)\s*`,
cvRegex, cvRegex))
}
// An individual constraint
type constraint struct {
// The callback function for the restraint. It performs the logic for
// the constraint.
function cfunc
msg string
// The version used in the constraint check. For example, if a constraint
// is '<= 2.0.0' the con a version instance representing 2.0.0.
con *Version
// The original parsed version (e.g., 4.x from != 4.x)
orig string
// When an x is used as part of the version (e.g., 1.x)
minorDirty bool
dirty bool
patchDirty bool
}
// Check if a version meets the constraint
func (c *constraint) check(v *Version) bool {
return c.function(v, c)
}
type cfunc func(v *Version, c *constraint) bool
func parseConstraint(c string) (*constraint, error) {
m := constraintRegex.FindStringSubmatch(c)
if m == nil {
return nil, fmt.Errorf("improper constraint: %s", c)
}
ver := m[2]
orig := ver
minorDirty := false
patchDirty := false
dirty := false
if isX(m[3]) {
ver = "0.0.0"
dirty = true
} else if isX(strings.TrimPrefix(m[4], ".")) || m[4] == "" {
minorDirty = true
dirty = true
ver = fmt.Sprintf("%s.0.0%s", m[3], m[6])
} else if isX(strings.TrimPrefix(m[5], ".")) {
dirty = true
patchDirty = true
ver = fmt.Sprintf("%s%s.0%s", m[3], m[4], m[6])
}
con, err := NewVersion(ver)
if err != nil {
// The constraintRegex should catch any regex parsing errors. So,
// we should never get here.
return nil, errors.New("constraint Parser Error")
}
cs := &constraint{
function: constraintOps[m[1]],
msg: constraintMsg[m[1]],
con: con,
orig: orig,
minorDirty: minorDirty,
patchDirty: patchDirty,
dirty: dirty,
}
return cs, nil
}
// Constraint functions
func constraintNotEqual(v *Version, c *constraint) bool {
if c.dirty {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
if c.con.Major() != v.Major() {
return true
}
if c.con.Minor() != v.Minor() && !c.minorDirty {
return true
} else if c.minorDirty {
return false
}
return false
}
return !v.Equal(c.con)
}
func constraintGreaterThan(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
return v.Compare(c.con) == 1
}
func constraintLessThan(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
if !c.dirty {
return v.Compare(c.con) < 0
}
if v.Major() > c.con.Major() {
return false
} else if v.Minor() > c.con.Minor() && !c.minorDirty {
return false
}
return true
}
func constraintGreaterThanEqual(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
return v.Compare(c.con) >= 0
}
func constraintLessThanEqual(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
if !c.dirty {
return v.Compare(c.con) <= 0
}
if v.Major() > c.con.Major() {
return false
} else if v.Minor() > c.con.Minor() && !c.minorDirty {
return false
}
return true
}
// ~*, ~>* --> >= 0.0.0 (any)
// ~2, ~2.x, ~2.x.x, ~>2, ~>2.x ~>2.x.x --> >=2.0.0, <3.0.0
// ~2.0, ~2.0.x, ~>2.0, ~>2.0.x --> >=2.0.0, <2.1.0
// ~1.2, ~1.2.x, ~>1.2, ~>1.2.x --> >=1.2.0, <1.3.0
// ~1.2.3, ~>1.2.3 --> >=1.2.3, <1.3.0
// ~1.2.0, ~>1.2.0 --> >=1.2.0, <1.3.0
func constraintTilde(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
if v.LessThan(c.con) {
return false
}
// ~0.0.0 is a special case where all constraints are accepted. It's
// equivalent to >= 0.0.0.
if c.con.Major() == 0 && c.con.Minor() == 0 && c.con.Patch() == 0 &&
!c.minorDirty && !c.patchDirty {
return true
}
if v.Major() != c.con.Major() {
return false
}
if v.Minor() != c.con.Minor() && !c.minorDirty {
return false
}
return true
}
// When there is a .x (dirty) status it automatically opts in to ~. Otherwise
// it's a straight =
func constraintTildeOrEqual(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
if c.dirty {
c.msg = constraintMsg["~"]
return constraintTilde(v, c)
}
return v.Equal(c.con)
}
// ^* --> (any)
// ^2, ^2.x, ^2.x.x --> >=2.0.0, <3.0.0
// ^2.0, ^2.0.x --> >=2.0.0, <3.0.0
// ^1.2, ^1.2.x --> >=1.2.0, <2.0.0
// ^1.2.3 --> >=1.2.3, <2.0.0
// ^1.2.0 --> >=1.2.0, <2.0.0
func constraintCaret(v *Version, c *constraint) bool {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
}
if v.LessThan(c.con) {
return false
}
if v.Major() != c.con.Major() {
return false
}
return true
}
var constraintRangeRegex *regexp.Regexp
const cvRegex string = `v?([0-9|x|X|\*]+)(\.[0-9|x|X|\*]+)?(\.[0-9|x|X|\*]+)?` +
`(-([0-9A-Za-z\-]+(\.[0-9A-Za-z\-]+)*))?` +
`(\+([0-9A-Za-z\-]+(\.[0-9A-Za-z\-]+)*))?`
func isX(x string) bool {
switch x {
case "x", "*", "X":
return true
default:
return false
}
}
func rewriteRange(i string) string {
m := constraintRangeRegex.FindAllStringSubmatch(i, -1)
if m == nil {
return i
}
o := i
for _, v := range m {
t := fmt.Sprintf(">= %s, <= %s", v[1], v[11])
o = strings.Replace(o, v[0], t, 1)
}
return o
}

View File

@@ -1,115 +0,0 @@
/*
Package semver provides the ability to work with Semantic Versions (http://semver.org) in Go.
Specifically it provides the ability to:
* Parse semantic versions
* Sort semantic versions
* Check if a semantic version fits within a set of constraints
* Optionally work with a `v` prefix
Parsing Semantic Versions
To parse a semantic version use the `NewVersion` function. For example,
v, err := semver.NewVersion("1.2.3-beta.1+build345")
If there is an error the version wasn't parseable. The version object has methods
to get the parts of the version, compare it to other versions, convert the
version back into a string, and get the original string. For more details
please see the documentation at https://godoc.org/github.com/Masterminds/semver.
Sorting Semantic Versions
A set of versions can be sorted using the `sort` package from the standard library.
For example,
raw := []string{"1.2.3", "1.0", "1.3", "2", "0.4.2",}
vs := make([]*semver.Version, len(raw))
for i, r := range raw {
v, err := semver.NewVersion(r)
if err != nil {
t.Errorf("Error parsing version: %s", err)
}
vs[i] = v
}
sort.Sort(semver.Collection(vs))
Checking Version Constraints
Checking a version against version constraints is one of the most featureful
parts of the package.
c, err := semver.NewConstraint(">= 1.2.3")
if err != nil {
// Handle constraint not being parseable.
}
v, err := semver.NewVersion("1.3")
if err != nil {
// Handle version not being parseable.
}
// Check if the version meets the constraints. The a variable will be true.
a := c.Check(v)
Basic Comparisons
There are two elements to the comparisons. First, a comparison string is a list
of comma separated and comparisons. These are then separated by || separated or
comparisons. For example, `">= 1.2, < 3.0.0 || >= 4.2.3"` is looking for a
comparison that's greater than or equal to 1.2 and less than 3.0.0 or is
greater than or equal to 4.2.3.
The basic comparisons are:
* `=`: equal (aliased to no operator)
* `!=`: not equal
* `>`: greater than
* `<`: less than
* `>=`: greater than or equal to
* `<=`: less than or equal to
Hyphen Range Comparisons
There are multiple methods to handle ranges and the first is hyphens ranges.
These look like:
* `1.2 - 1.4.5` which is equivalent to `>= 1.2, <= 1.4.5`
* `2.3.4 - 4.5` which is equivalent to `>= 2.3.4, <= 4.5`
Wildcards In Comparisons
The `x`, `X`, and `*` characters can be used as a wildcard character. This works
for all comparison operators. When used on the `=` operator it falls
back to the pack level comparison (see tilde below). For example,
* `1.2.x` is equivalent to `>= 1.2.0, < 1.3.0`
* `>= 1.2.x` is equivalent to `>= 1.2.0`
* `<= 2.x` is equivalent to `<= 3`
* `*` is equivalent to `>= 0.0.0`
Tilde Range Comparisons (Patch)
The tilde (`~`) comparison operator is for patch level ranges when a minor
version is specified and major level changes when the minor number is missing.
For example,
* `~1.2.3` is equivalent to `>= 1.2.3, < 1.3.0`
* `~1` is equivalent to `>= 1, < 2`
* `~2.3` is equivalent to `>= 2.3, < 2.4`
* `~1.2.x` is equivalent to `>= 1.2.0, < 1.3.0`
* `~1.x` is equivalent to `>= 1, < 2`
Caret Range Comparisons (Major)
The caret (`^`) comparison operator is for major level changes. This is useful
when comparisons of API versions as a major change is API breaking. For example,
* `^1.2.3` is equivalent to `>= 1.2.3, < 2.0.0`
* `^1.2.x` is equivalent to `>= 1.2.0, < 2.0.0`
* `^2.3` is equivalent to `>= 2.3, < 3`
* `^2.x` is equivalent to `>= 2.0.0, < 3`
*/
package semver

View File

@@ -1,421 +0,0 @@
package semver
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
)
// The compiled version of the regex created at init() is cached here so it
// only needs to be created once.
var versionRegex *regexp.Regexp
var validPrereleaseRegex *regexp.Regexp
var (
// ErrInvalidSemVer is returned a version is found to be invalid when
// being parsed.
ErrInvalidSemVer = errors.New("Invalid Semantic Version")
// ErrInvalidMetadata is returned when the metadata is an invalid format
ErrInvalidMetadata = errors.New("Invalid Metadata string")
// ErrInvalidPrerelease is returned when the pre-release is an invalid format
ErrInvalidPrerelease = errors.New("Invalid Prerelease string")
)
// SemVerRegex is the regular expression used to parse a semantic version.
const SemVerRegex string = `v?([0-9]+)(\.[0-9]+)?(\.[0-9]+)?` +
`(-([0-9A-Za-z\-]+(\.[0-9A-Za-z\-]+)*))?` +
`(\+([0-9A-Za-z\-]+(\.[0-9A-Za-z\-]+)*))?`
// ValidPrerelease is the regular expression which validates
// both prerelease and metadata values.
const ValidPrerelease string = `^([0-9A-Za-z\-]+(\.[0-9A-Za-z\-]+)*)`
// Version represents a single semantic version.
type Version struct {
major, minor, patch int64
pre string
metadata string
original string
}
func init() {
versionRegex = regexp.MustCompile("^" + SemVerRegex + "$")
validPrereleaseRegex = regexp.MustCompile(ValidPrerelease)
}
// NewVersion parses a given version and returns an instance of Version or
// an error if unable to parse the version.
func NewVersion(v string) (*Version, error) {
m := versionRegex.FindStringSubmatch(v)
if m == nil {
return nil, ErrInvalidSemVer
}
sv := &Version{
metadata: m[8],
pre: m[5],
original: v,
}
var temp int64
temp, err := strconv.ParseInt(m[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("Error parsing version segment: %s", err)
}
sv.major = temp
if m[2] != "" {
temp, err = strconv.ParseInt(strings.TrimPrefix(m[2], "."), 10, 64)
if err != nil {
return nil, fmt.Errorf("Error parsing version segment: %s", err)
}
sv.minor = temp
} else {
sv.minor = 0
}
if m[3] != "" {
temp, err = strconv.ParseInt(strings.TrimPrefix(m[3], "."), 10, 64)
if err != nil {
return nil, fmt.Errorf("Error parsing version segment: %s", err)
}
sv.patch = temp
} else {
sv.patch = 0
}
return sv, nil
}
// MustParse parses a given version and panics on error.
func MustParse(v string) *Version {
sv, err := NewVersion(v)
if err != nil {
panic(err)
}
return sv
}
// String converts a Version object to a string.
// Note, if the original version contained a leading v this version will not.
// See the Original() method to retrieve the original value. Semantic Versions
// don't contain a leading v per the spec. Instead it's optional on
// implementation.
func (v *Version) String() string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%d.%d.%d", v.major, v.minor, v.patch)
if v.pre != "" {
fmt.Fprintf(&buf, "-%s", v.pre)
}
if v.metadata != "" {
fmt.Fprintf(&buf, "+%s", v.metadata)
}
return buf.String()
}
// Original returns the original value passed in to be parsed.
func (v *Version) Original() string {
return v.original
}
// Major returns the major version.
func (v *Version) Major() int64 {
return v.major
}
// Minor returns the minor version.
func (v *Version) Minor() int64 {
return v.minor
}
// Patch returns the patch version.
func (v *Version) Patch() int64 {
return v.patch
}
// Prerelease returns the pre-release version.
func (v *Version) Prerelease() string {
return v.pre
}
// Metadata returns the metadata on the version.
func (v *Version) Metadata() string {
return v.metadata
}
// originalVPrefix returns the original 'v' prefix if any.
func (v *Version) originalVPrefix() string {
// Note, only lowercase v is supported as a prefix by the parser.
if v.original != "" && v.original[:1] == "v" {
return v.original[:1]
}
return ""
}
// IncPatch produces the next patch version.
// If the current version does not have prerelease/metadata information,
// it unsets metadata and prerelease values, increments patch number.
// If the current version has any of prerelease or metadata information,
// it unsets both values and keeps curent patch value
func (v Version) IncPatch() Version {
vNext := v
// according to http://semver.org/#spec-item-9
// Pre-release versions have a lower precedence than the associated normal version.
// according to http://semver.org/#spec-item-10
// Build metadata SHOULD be ignored when determining version precedence.
if v.pre != "" {
vNext.metadata = ""
vNext.pre = ""
} else {
vNext.metadata = ""
vNext.pre = ""
vNext.patch = v.patch + 1
}
vNext.original = v.originalVPrefix() + "" + vNext.String()
return vNext
}
// IncMinor produces the next minor version.
// Sets patch to 0.
// Increments minor number.
// Unsets metadata.
// Unsets prerelease status.
func (v Version) IncMinor() Version {
vNext := v
vNext.metadata = ""
vNext.pre = ""
vNext.patch = 0
vNext.minor = v.minor + 1
vNext.original = v.originalVPrefix() + "" + vNext.String()
return vNext
}
// IncMajor produces the next major version.
// Sets patch to 0.
// Sets minor to 0.
// Increments major number.
// Unsets metadata.
// Unsets prerelease status.
func (v Version) IncMajor() Version {
vNext := v
vNext.metadata = ""
vNext.pre = ""
vNext.patch = 0
vNext.minor = 0
vNext.major = v.major + 1
vNext.original = v.originalVPrefix() + "" + vNext.String()
return vNext
}
// SetPrerelease defines the prerelease value.
// Value must not include the required 'hypen' prefix.
func (v Version) SetPrerelease(prerelease string) (Version, error) {
vNext := v
if len(prerelease) > 0 && !validPrereleaseRegex.MatchString(prerelease) {
return vNext, ErrInvalidPrerelease
}
vNext.pre = prerelease
vNext.original = v.originalVPrefix() + "" + vNext.String()
return vNext, nil
}
// SetMetadata defines metadata value.
// Value must not include the required 'plus' prefix.
func (v Version) SetMetadata(metadata string) (Version, error) {
vNext := v
if len(metadata) > 0 && !validPrereleaseRegex.MatchString(metadata) {
return vNext, ErrInvalidMetadata
}
vNext.metadata = metadata
vNext.original = v.originalVPrefix() + "" + vNext.String()
return vNext, nil
}
// LessThan tests if one version is less than another one.
func (v *Version) LessThan(o *Version) bool {
return v.Compare(o) < 0
}
// GreaterThan tests if one version is greater than another one.
func (v *Version) GreaterThan(o *Version) bool {
return v.Compare(o) > 0
}
// Equal tests if two versions are equal to each other.
// Note, versions can be equal with different metadata since metadata
// is not considered part of the comparable version.
func (v *Version) Equal(o *Version) bool {
return v.Compare(o) == 0
}
// Compare compares this version to another one. It returns -1, 0, or 1 if
// the version smaller, equal, or larger than the other version.
//
// Versions are compared by X.Y.Z. Build metadata is ignored. Prerelease is
// lower than the version without a prerelease.
func (v *Version) Compare(o *Version) int {
// Compare the major, minor, and patch version for differences. If a
// difference is found return the comparison.
if d := compareSegment(v.Major(), o.Major()); d != 0 {
return d
}
if d := compareSegment(v.Minor(), o.Minor()); d != 0 {
return d
}
if d := compareSegment(v.Patch(), o.Patch()); d != 0 {
return d
}
// At this point the major, minor, and patch versions are the same.
ps := v.pre
po := o.Prerelease()
if ps == "" && po == "" {
return 0
}
if ps == "" {
return 1
}
if po == "" {
return -1
}
return comparePrerelease(ps, po)
}
// UnmarshalJSON implements JSON.Unmarshaler interface.
func (v *Version) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
temp, err := NewVersion(s)
if err != nil {
return err
}
v.major = temp.major
v.minor = temp.minor
v.patch = temp.patch
v.pre = temp.pre
v.metadata = temp.metadata
v.original = temp.original
temp = nil
return nil
}
// MarshalJSON implements JSON.Marshaler interface.
func (v *Version) MarshalJSON() ([]byte, error) {
return json.Marshal(v.String())
}
func compareSegment(v, o int64) int {
if v < o {
return -1
}
if v > o {
return 1
}
return 0
}
func comparePrerelease(v, o string) int {
// split the prelease versions by their part. The separator, per the spec,
// is a .
sparts := strings.Split(v, ".")
oparts := strings.Split(o, ".")
// Find the longer length of the parts to know how many loop iterations to
// go through.
slen := len(sparts)
olen := len(oparts)
l := slen
if olen > slen {
l = olen
}
// Iterate over each part of the prereleases to compare the differences.
for i := 0; i < l; i++ {
// Since the lentgh of the parts can be different we need to create
// a placeholder. This is to avoid out of bounds issues.
stemp := ""
if i < slen {
stemp = sparts[i]
}
otemp := ""
if i < olen {
otemp = oparts[i]
}
d := comparePrePart(stemp, otemp)
if d != 0 {
return d
}
}
// Reaching here means two versions are of equal value but have different
// metadata (the part following a +). They are not identical in string form
// but the version comparison finds them to be equal.
return 0
}
func comparePrePart(s, o string) int {
// Fastpath if they are equal
if s == o {
return 0
}
// When s or o are empty we can use the other in an attempt to determine
// the response.
if s == "" {
if o != "" {
return -1
}
return 1
}
if o == "" {
if s != "" {
return 1
}
return -1
}
// When comparing strings "99" is greater than "103". To handle
// cases like this we need to detect numbers and compare them.
oi, n1 := strconv.ParseInt(o, 10, 64)
si, n2 := strconv.ParseInt(s, 10, 64)
// The case where both are strings compare the strings
if n1 != nil && n2 != nil {
if s > o {
return 1
}
return -1
} else if n1 != nil {
// o is a string and s is a number
return -1
} else if n2 != nil {
// s is a string and o is a number
return 1
}
// Both are numbers
if si > oi {
return 1
}
return -1
}

View File

@@ -1,9 +0,0 @@
Copyright (c) 2013-2016 Antoine Imbert
The MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -1,236 +0,0 @@
package rest
import (
"bytes"
"fmt"
"log"
"net"
"os"
"strings"
"text/template"
"time"
)
// TODO Future improvements:
// * support %{strftime}t ?
// * support %{<header>}o to print headers
// AccessLogFormat defines the format of the access log record.
// This implementation is a subset of Apache mod_log_config.
// (See http://httpd.apache.org/docs/2.0/mod/mod_log_config.html)
//
// %b content length in bytes, - if 0
// %B content length in bytes
// %D response elapsed time in microseconds
// %h remote address
// %H server protocol
// %l identd logname, not supported, -
// %m http method
// %P process id
// %q query string
// %r first line of the request
// %s status code
// %S status code preceeded by a terminal color
// %t time of the request
// %T response elapsed time in seconds, 3 decimals
// %u remote user, - if missing
// %{User-Agent}i user agent, - if missing
// %{Referer}i referer, - is missing
//
// Some predefined formats are provided as contants.
type AccessLogFormat string
const (
// CommonLogFormat is the Common Log Format (CLF).
CommonLogFormat = "%h %l %u %t \"%r\" %s %b"
// CombinedLogFormat is the NCSA extended/combined log format.
CombinedLogFormat = "%h %l %u %t \"%r\" %s %b \"%{Referer}i\" \"%{User-Agent}i\""
// DefaultLogFormat is the default format, colored output and response time, convenient for development.
DefaultLogFormat = "%t %S\033[0m \033[36;1m%Dμs\033[0m \"%r\" \033[1;30m%u \"%{User-Agent}i\"\033[0m"
)
// AccessLogApacheMiddleware produces the access log following a format inspired by Apache
// mod_log_config. It depends on TimerMiddleware and RecorderMiddleware that should be in the wrapped
// middlewares. It also uses request.Env["REMOTE_USER"].(string) set by the auth middlewares.
type AccessLogApacheMiddleware struct {
// Logger points to the logger object used by this middleware, it defaults to
// log.New(os.Stderr, "", 0).
Logger *log.Logger
// Format defines the format of the access log record. See AccessLogFormat for the details.
// It defaults to DefaultLogFormat.
Format AccessLogFormat
textTemplate *template.Template
}
// MiddlewareFunc makes AccessLogApacheMiddleware implement the Middleware interface.
func (mw *AccessLogApacheMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
// set the default Logger
if mw.Logger == nil {
mw.Logger = log.New(os.Stderr, "", 0)
}
// set default format
if mw.Format == "" {
mw.Format = DefaultLogFormat
}
mw.convertFormat()
return func(w ResponseWriter, r *Request) {
// call the handler
h(w, r)
util := &accessLogUtil{w, r}
mw.Logger.Print(mw.executeTextTemplate(util))
}
}
var apacheAdapter = strings.NewReplacer(
"%b", "{{.BytesWritten | dashIf0}}",
"%B", "{{.BytesWritten}}",
"%D", "{{.ResponseTime | microseconds}}",
"%h", "{{.ApacheRemoteAddr}}",
"%H", "{{.R.Proto}}",
"%l", "-",
"%m", "{{.R.Method}}",
"%P", "{{.Pid}}",
"%q", "{{.ApacheQueryString}}",
"%r", "{{.R.Method}} {{.R.URL.RequestURI}} {{.R.Proto}}",
"%s", "{{.StatusCode}}",
"%S", "\033[{{.StatusCode | statusCodeColor}}m{{.StatusCode}}",
"%t", "{{if .StartTime}}{{.StartTime.Format \"02/Jan/2006:15:04:05 -0700\"}}{{end}}",
"%T", "{{if .ResponseTime}}{{.ResponseTime.Seconds | printf \"%.3f\"}}{{end}}",
"%u", "{{.RemoteUser | dashIfEmptyStr}}",
"%{User-Agent}i", "{{.R.UserAgent | dashIfEmptyStr}}",
"%{Referer}i", "{{.R.Referer | dashIfEmptyStr}}",
)
// Convert the Apache access log format into a text/template
func (mw *AccessLogApacheMiddleware) convertFormat() {
tmplText := apacheAdapter.Replace(string(mw.Format))
funcMap := template.FuncMap{
"dashIfEmptyStr": func(value string) string {
if value == "" {
return "-"
}
return value
},
"dashIf0": func(value int64) string {
if value == 0 {
return "-"
}
return fmt.Sprintf("%d", value)
},
"microseconds": func(dur *time.Duration) string {
if dur != nil {
return fmt.Sprintf("%d", dur.Nanoseconds()/1000)
}
return ""
},
"statusCodeColor": func(statusCode int) string {
if statusCode >= 400 && statusCode < 500 {
return "1;33"
} else if statusCode >= 500 {
return "0;31"
}
return "0;32"
},
}
var err error
mw.textTemplate, err = template.New("accessLog").Funcs(funcMap).Parse(tmplText)
if err != nil {
panic(err)
}
}
// Execute the text template with the data derived from the request, and return a string.
func (mw *AccessLogApacheMiddleware) executeTextTemplate(util *accessLogUtil) string {
buf := bytes.NewBufferString("")
err := mw.textTemplate.Execute(buf, util)
if err != nil {
panic(err)
}
return buf.String()
}
// accessLogUtil provides a collection of utility functions that devrive data from the Request object.
// This object is used to provide data to the Apache Style template and the the JSON log record.
type accessLogUtil struct {
W ResponseWriter
R *Request
}
// As stored by the auth middlewares.
func (u *accessLogUtil) RemoteUser() string {
if u.R.Env["REMOTE_USER"] != nil {
return u.R.Env["REMOTE_USER"].(string)
}
return ""
}
// If qs exists then return it with a leadin "?", apache log style.
func (u *accessLogUtil) ApacheQueryString() string {
if u.R.URL.RawQuery != "" {
return "?" + u.R.URL.RawQuery
}
return ""
}
// When the request entered the timer middleware.
func (u *accessLogUtil) StartTime() *time.Time {
if u.R.Env["START_TIME"] != nil {
return u.R.Env["START_TIME"].(*time.Time)
}
return nil
}
// If remoteAddr is set then return is without the port number, apache log style.
func (u *accessLogUtil) ApacheRemoteAddr() string {
remoteAddr := u.R.RemoteAddr
if remoteAddr != "" {
if ip, _, err := net.SplitHostPort(remoteAddr); err == nil {
return ip
}
}
return ""
}
// As recorded by the recorder middleware.
func (u *accessLogUtil) StatusCode() int {
if u.R.Env["STATUS_CODE"] != nil {
return u.R.Env["STATUS_CODE"].(int)
}
return 0
}
// As mesured by the timer middleware.
func (u *accessLogUtil) ResponseTime() *time.Duration {
if u.R.Env["ELAPSED_TIME"] != nil {
return u.R.Env["ELAPSED_TIME"].(*time.Duration)
}
return nil
}
// Process id.
func (u *accessLogUtil) Pid() int {
return os.Getpid()
}
// As recorded by the recorder middleware.
func (u *accessLogUtil) BytesWritten() int64 {
if u.R.Env["BYTES_WRITTEN"] != nil {
return u.R.Env["BYTES_WRITTEN"].(int64)
}
return 0
}

View File

@@ -1,88 +0,0 @@
package rest
import (
"encoding/json"
"log"
"os"
"time"
)
// AccessLogJsonMiddleware produces the access log with records written as JSON. This middleware
// depends on TimerMiddleware and RecorderMiddleware that must be in the wrapped middlewares. It
// also uses request.Env["REMOTE_USER"].(string) set by the auth middlewares.
type AccessLogJsonMiddleware struct {
// Logger points to the logger object used by this middleware, it defaults to
// log.New(os.Stderr, "", 0).
Logger *log.Logger
}
// MiddlewareFunc makes AccessLogJsonMiddleware implement the Middleware interface.
func (mw *AccessLogJsonMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
// set the default Logger
if mw.Logger == nil {
mw.Logger = log.New(os.Stderr, "", 0)
}
return func(w ResponseWriter, r *Request) {
// call the handler
h(w, r)
mw.Logger.Printf("%s", makeAccessLogJsonRecord(r).asJson())
}
}
// AccessLogJsonRecord is the data structure used by AccessLogJsonMiddleware to create the JSON
// records. (Public for documentation only, no public method uses it)
type AccessLogJsonRecord struct {
Timestamp *time.Time
StatusCode int
ResponseTime *time.Duration
HttpMethod string
RequestURI string
RemoteUser string
UserAgent string
}
func makeAccessLogJsonRecord(r *Request) *AccessLogJsonRecord {
var timestamp *time.Time
if r.Env["START_TIME"] != nil {
timestamp = r.Env["START_TIME"].(*time.Time)
}
var statusCode int
if r.Env["STATUS_CODE"] != nil {
statusCode = r.Env["STATUS_CODE"].(int)
}
var responseTime *time.Duration
if r.Env["ELAPSED_TIME"] != nil {
responseTime = r.Env["ELAPSED_TIME"].(*time.Duration)
}
var remoteUser string
if r.Env["REMOTE_USER"] != nil {
remoteUser = r.Env["REMOTE_USER"].(string)
}
return &AccessLogJsonRecord{
Timestamp: timestamp,
StatusCode: statusCode,
ResponseTime: responseTime,
HttpMethod: r.Method,
RequestURI: r.URL.RequestURI(),
RemoteUser: remoteUser,
UserAgent: r.UserAgent(),
}
}
func (r *AccessLogJsonRecord) asJson() []byte {
b, err := json.Marshal(r)
if err != nil {
panic(err)
}
return b
}

View File

@@ -1,83 +0,0 @@
package rest
import (
"net/http"
)
// Api defines a stack of Middlewares and an App.
type Api struct {
stack []Middleware
app App
}
// NewApi makes a new Api object. The Middleware stack is empty, and the App is nil.
func NewApi() *Api {
return &Api{
stack: []Middleware{},
app: nil,
}
}
// Use pushes one or multiple middlewares to the stack for middlewares
// maintained in the Api object.
func (api *Api) Use(middlewares ...Middleware) {
api.stack = append(api.stack, middlewares...)
}
// SetApp sets the App in the Api object.
func (api *Api) SetApp(app App) {
api.app = app
}
// MakeHandler wraps all the Middlewares of the stack and the App together, and returns an
// http.Handler ready to be used. If the Middleware stack is empty the App is used directly. If the
// App is nil, a HandlerFunc that does nothing is used instead.
func (api *Api) MakeHandler() http.Handler {
var appFunc HandlerFunc
if api.app != nil {
appFunc = api.app.AppFunc()
} else {
appFunc = func(w ResponseWriter, r *Request) {}
}
return http.HandlerFunc(
adapterFunc(
WrapMiddlewares(api.stack, appFunc),
),
)
}
// Defines a stack of middlewares convenient for development. Among other things:
// console friendly logging, JSON indentation, error stack strace in the response.
var DefaultDevStack = []Middleware{
&AccessLogApacheMiddleware{},
&TimerMiddleware{},
&RecorderMiddleware{},
&PoweredByMiddleware{},
&RecoverMiddleware{
EnableResponseStackTrace: true,
},
&JsonIndentMiddleware{},
&ContentTypeCheckerMiddleware{},
}
// Defines a stack of middlewares convenient for production. Among other things:
// Apache CombinedLogFormat logging, gzip compression.
var DefaultProdStack = []Middleware{
&AccessLogApacheMiddleware{
Format: CombinedLogFormat,
},
&TimerMiddleware{},
&RecorderMiddleware{},
&PoweredByMiddleware{},
&RecoverMiddleware{},
&GzipMiddleware{},
&ContentTypeCheckerMiddleware{},
}
// Defines a stack of middlewares that should be common to most of the middleware stacks.
var DefaultCommonStack = []Middleware{
&TimerMiddleware{},
&RecorderMiddleware{},
&PoweredByMiddleware{},
&RecoverMiddleware{},
}

View File

@@ -1,100 +0,0 @@
package rest
import (
"encoding/base64"
"errors"
"log"
"net/http"
"strings"
)
// AuthBasicMiddleware provides a simple AuthBasic implementation. On failure, a 401 HTTP response
//is returned. On success, the wrapped middleware is called, and the userId is made available as
// request.Env["REMOTE_USER"].(string)
type AuthBasicMiddleware struct {
// Realm name to display to the user. Required.
Realm string
// Callback function that should perform the authentication of the user based on userId and
// password. Must return true on success, false on failure. Required.
Authenticator func(userId string, password string) bool
// Callback function that should perform the authorization of the authenticated user. Called
// only after an authentication success. Must return true on success, false on failure.
// Optional, default to success.
Authorizator func(userId string, request *Request) bool
}
// MiddlewareFunc makes AuthBasicMiddleware implement the Middleware interface.
func (mw *AuthBasicMiddleware) MiddlewareFunc(handler HandlerFunc) HandlerFunc {
if mw.Realm == "" {
log.Fatal("Realm is required")
}
if mw.Authenticator == nil {
log.Fatal("Authenticator is required")
}
if mw.Authorizator == nil {
mw.Authorizator = func(userId string, request *Request) bool {
return true
}
}
return func(writer ResponseWriter, request *Request) {
authHeader := request.Header.Get("Authorization")
if authHeader == "" {
mw.unauthorized(writer)
return
}
providedUserId, providedPassword, err := mw.decodeBasicAuthHeader(authHeader)
if err != nil {
Error(writer, "Invalid authentication", http.StatusBadRequest)
return
}
if !mw.Authenticator(providedUserId, providedPassword) {
mw.unauthorized(writer)
return
}
if !mw.Authorizator(providedUserId, request) {
mw.unauthorized(writer)
return
}
request.Env["REMOTE_USER"] = providedUserId
handler(writer, request)
}
}
func (mw *AuthBasicMiddleware) unauthorized(writer ResponseWriter) {
writer.Header().Set("WWW-Authenticate", "Basic realm="+mw.Realm)
Error(writer, "Not Authorized", http.StatusUnauthorized)
}
func (mw *AuthBasicMiddleware) decodeBasicAuthHeader(header string) (user string, password string, err error) {
parts := strings.SplitN(header, " ", 2)
if !(len(parts) == 2 && parts[0] == "Basic") {
return "", "", errors.New("Invalid authentication")
}
decoded, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return "", "", errors.New("Invalid base64")
}
creds := strings.SplitN(string(decoded), ":", 2)
if len(creds) != 2 {
return "", "", errors.New("Invalid authentication")
}
return creds[0], creds[1], nil
}

View File

@@ -1,40 +0,0 @@
package rest
import (
"mime"
"net/http"
"strings"
)
// ContentTypeCheckerMiddleware verifies the request Content-Type header and returns a
// StatusUnsupportedMediaType (415) HTTP error response if it's incorrect. The expected
// Content-Type is 'application/json' if the content is non-null. Note: If a charset parameter
// exists, it MUST be UTF-8.
type ContentTypeCheckerMiddleware struct{}
// MiddlewareFunc makes ContentTypeCheckerMiddleware implement the Middleware interface.
func (mw *ContentTypeCheckerMiddleware) MiddlewareFunc(handler HandlerFunc) HandlerFunc {
return func(w ResponseWriter, r *Request) {
mediatype, params, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
charset, ok := params["charset"]
if !ok {
charset = "UTF-8"
}
// per net/http doc, means that the length is known and non-null
if r.ContentLength > 0 &&
!(mediatype == "application/json" && strings.ToUpper(charset) == "UTF-8") {
Error(w,
"Bad Content-Type or charset, expected 'application/json'",
http.StatusUnsupportedMediaType,
)
return
}
// call the wrapped handler
handler(w, r)
}
}

View File

@@ -1,135 +0,0 @@
package rest
import (
"net/http"
"strconv"
"strings"
)
// Possible improvements:
// If AllowedMethods["*"] then Access-Control-Allow-Methods is set to the requested methods
// If AllowedHeaderss["*"] then Access-Control-Allow-Headers is set to the requested headers
// Put some presets in AllowedHeaders
// Put some presets in AccessControlExposeHeaders
// CorsMiddleware provides a configurable CORS implementation.
type CorsMiddleware struct {
allowedMethods map[string]bool
allowedMethodsCsv string
allowedHeaders map[string]bool
allowedHeadersCsv string
// Reject non CORS requests if true. See CorsInfo.IsCors.
RejectNonCorsRequests bool
// Function excecuted for every CORS requests to validate the Origin. (Required)
// Must return true if valid, false if invalid.
// For instance: simple equality, regexp, DB lookup, ...
OriginValidator func(origin string, request *Request) bool
// List of allowed HTTP methods. Note that the comparison will be made in
// uppercase to avoid common mistakes. And that the
// Access-Control-Allow-Methods response header also uses uppercase.
// (see CorsInfo.AccessControlRequestMethod)
AllowedMethods []string
// List of allowed HTTP Headers. Note that the comparison will be made with
// noarmalized names (http.CanonicalHeaderKey). And that the response header
// also uses normalized names.
// (see CorsInfo.AccessControlRequestHeaders)
AllowedHeaders []string
// List of headers used to set the Access-Control-Expose-Headers header.
AccessControlExposeHeaders []string
// User to se the Access-Control-Allow-Credentials response header.
AccessControlAllowCredentials bool
// Used to set the Access-Control-Max-Age response header, in seconds.
AccessControlMaxAge int
}
// MiddlewareFunc makes CorsMiddleware implement the Middleware interface.
func (mw *CorsMiddleware) MiddlewareFunc(handler HandlerFunc) HandlerFunc {
// precompute as much as possible at init time
mw.allowedMethods = map[string]bool{}
normedMethods := []string{}
for _, allowedMethod := range mw.AllowedMethods {
normed := strings.ToUpper(allowedMethod)
mw.allowedMethods[normed] = true
normedMethods = append(normedMethods, normed)
}
mw.allowedMethodsCsv = strings.Join(normedMethods, ",")
mw.allowedHeaders = map[string]bool{}
normedHeaders := []string{}
for _, allowedHeader := range mw.AllowedHeaders {
normed := http.CanonicalHeaderKey(allowedHeader)
mw.allowedHeaders[normed] = true
normedHeaders = append(normedHeaders, normed)
}
mw.allowedHeadersCsv = strings.Join(normedHeaders, ",")
return func(writer ResponseWriter, request *Request) {
corsInfo := request.GetCorsInfo()
// non CORS requests
if !corsInfo.IsCors {
if mw.RejectNonCorsRequests {
Error(writer, "Non CORS request", http.StatusForbidden)
return
}
// continue, execute the wrapped middleware
handler(writer, request)
return
}
// Validate the Origin
if mw.OriginValidator(corsInfo.Origin, request) == false {
Error(writer, "Invalid Origin", http.StatusForbidden)
return
}
if corsInfo.IsPreflight {
// check the request methods
if mw.allowedMethods[corsInfo.AccessControlRequestMethod] == false {
Error(writer, "Invalid Preflight Request", http.StatusForbidden)
return
}
// check the request headers
for _, requestedHeader := range corsInfo.AccessControlRequestHeaders {
if mw.allowedHeaders[requestedHeader] == false {
Error(writer, "Invalid Preflight Request", http.StatusForbidden)
return
}
}
writer.Header().Set("Access-Control-Allow-Methods", mw.allowedMethodsCsv)
writer.Header().Set("Access-Control-Allow-Headers", mw.allowedHeadersCsv)
writer.Header().Set("Access-Control-Allow-Origin", corsInfo.Origin)
if mw.AccessControlAllowCredentials == true {
writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
writer.Header().Set("Access-Control-Max-Age", strconv.Itoa(mw.AccessControlMaxAge))
writer.WriteHeader(http.StatusOK)
return
}
// Non-preflight requests
for _, exposed := range mw.AccessControlExposeHeaders {
writer.Header().Add("Access-Control-Expose-Headers", exposed)
}
writer.Header().Set("Access-Control-Allow-Origin", corsInfo.Origin)
if mw.AccessControlAllowCredentials == true {
writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
// continure, execute the wrapped middleware
handler(writer, request)
return
}
}

View File

@@ -1,47 +0,0 @@
// A quick and easy way to setup a RESTful JSON API
//
// http://ant0ine.github.io/go-json-rest/
//
// Go-Json-Rest is a thin layer on top of net/http that helps building RESTful JSON APIs easily.
// It provides fast and scalable request routing using a Trie based implementation, helpers to deal
// with JSON requests and responses, and middlewares for functionalities like CORS, Auth, Gzip,
// Status, ...
//
// Example:
//
// package main
//
// import (
// "github.com/ant0ine/go-json-rest/rest"
// "log"
// "net/http"
// )
//
// type User struct {
// Id string
// Name string
// }
//
// func GetUser(w rest.ResponseWriter, req *rest.Request) {
// user := User{
// Id: req.PathParam("id"),
// Name: "Antoine",
// }
// w.WriteJson(&user)
// }
//
// func main() {
// api := rest.NewApi()
// api.Use(rest.DefaultDevStack...)
// router, err := rest.MakeRouter(
// rest.Get("/users/:id", GetUser),
// )
// if err != nil {
// log.Fatal(err)
// }
// api.SetApp(router)
// log.Fatal(http.ListenAndServe(":8080", api.MakeHandler()))
// }
//
//
package rest

View File

@@ -1,132 +0,0 @@
package rest
import (
"bufio"
"compress/gzip"
"net"
"net/http"
"strings"
)
// GzipMiddleware is responsible for compressing the payload with gzip and setting the proper
// headers when supported by the client. It must be wrapped by TimerMiddleware for the
// compression time to be captured. And It must be wrapped by RecorderMiddleware for the
// compressed BYTES_WRITTEN to be captured.
type GzipMiddleware struct{}
// MiddlewareFunc makes GzipMiddleware implement the Middleware interface.
func (mw *GzipMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
return func(w ResponseWriter, r *Request) {
// gzip support enabled
canGzip := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip")
// client accepts gzip ?
writer := &gzipResponseWriter{w, false, canGzip, nil}
defer func() {
// need to close gzip writer
if writer.gzipWriter != nil {
writer.gzipWriter.Close()
}
}()
// call the handler with the wrapped writer
h(writer, r)
}
}
// Private responseWriter intantiated by the gzip middleware.
// It encodes the payload with gzip and set the proper headers.
// It implements the following interfaces:
// ResponseWriter
// http.ResponseWriter
// http.Flusher
// http.CloseNotifier
// http.Hijacker
type gzipResponseWriter struct {
ResponseWriter
wroteHeader bool
canGzip bool
gzipWriter *gzip.Writer
}
// Set the right headers for gzip encoded responses.
func (w *gzipResponseWriter) WriteHeader(code int) {
// Always set the Vary header, even if this particular request
// is not gzipped.
w.Header().Add("Vary", "Accept-Encoding")
if w.canGzip {
w.Header().Set("Content-Encoding", "gzip")
}
w.ResponseWriter.WriteHeader(code)
w.wroteHeader = true
}
// Make sure the local Write is called.
func (w *gzipResponseWriter) WriteJson(v interface{}) error {
b, err := w.EncodeJson(v)
if err != nil {
return err
}
_, err = w.Write(b)
if err != nil {
return err
}
return nil
}
// Make sure the local WriteHeader is called, and call the parent Flush.
// Provided in order to implement the http.Flusher interface.
func (w *gzipResponseWriter) Flush() {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
flusher := w.ResponseWriter.(http.Flusher)
flusher.Flush()
}
// Call the parent CloseNotify.
// Provided in order to implement the http.CloseNotifier interface.
func (w *gzipResponseWriter) CloseNotify() <-chan bool {
notifier := w.ResponseWriter.(http.CloseNotifier)
return notifier.CloseNotify()
}
// Provided in order to implement the http.Hijacker interface.
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker := w.ResponseWriter.(http.Hijacker)
return hijacker.Hijack()
}
// Make sure the local WriteHeader is called, and encode the payload if necessary.
// Provided in order to implement the http.ResponseWriter interface.
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
writer := w.ResponseWriter.(http.ResponseWriter)
if w.canGzip {
// Write can be called multiple times for a given response.
// (see the streaming example:
// https://github.com/ant0ine/go-json-rest-examples/tree/master/streaming)
// The gzipWriter is instantiated only once, and flushed after
// each write.
if w.gzipWriter == nil {
w.gzipWriter = gzip.NewWriter(writer)
}
count, errW := w.gzipWriter.Write(b)
errF := w.gzipWriter.Flush()
if errW != nil {
return count, errW
}
if errF != nil {
return count, errF
}
return count, nil
}
return writer.Write(b)
}

View File

@@ -1,53 +0,0 @@
package rest
import (
"log"
)
// IfMiddleware evaluates at runtime a condition based on the current request, and decides to
// execute one of the other Middleware based on this boolean.
type IfMiddleware struct {
// Runtime condition that decides of the execution of IfTrue of IfFalse.
Condition func(r *Request) bool
// Middleware to run when the condition is true. Note that the middleware is initialized
// weather if will be used or not. (Optional, pass-through if not set)
IfTrue Middleware
// Middleware to run when the condition is false. Note that the middleware is initialized
// weather if will be used or not. (Optional, pass-through if not set)
IfFalse Middleware
}
// MiddlewareFunc makes TimerMiddleware implement the Middleware interface.
func (mw *IfMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
if mw.Condition == nil {
log.Fatal("IfMiddleware Condition is required")
}
var ifTrueHandler HandlerFunc
if mw.IfTrue != nil {
ifTrueHandler = mw.IfTrue.MiddlewareFunc(h)
} else {
ifTrueHandler = h
}
var ifFalseHandler HandlerFunc
if mw.IfFalse != nil {
ifFalseHandler = mw.IfFalse.MiddlewareFunc(h)
} else {
ifFalseHandler = h
}
return func(w ResponseWriter, r *Request) {
if mw.Condition(r) {
ifTrueHandler(w, r)
} else {
ifFalseHandler(w, r)
}
}
}

View File

@@ -1,113 +0,0 @@
package rest
import (
"bufio"
"encoding/json"
"net"
"net/http"
)
// JsonIndentMiddleware provides JSON encoding with indentation.
// It could be convenient to use it during development.
// It works by "subclassing" the responseWriter provided by the wrapping middleware,
// replacing the writer.EncodeJson and writer.WriteJson implementations,
// and making the parent implementations ignored.
type JsonIndentMiddleware struct {
// prefix string, as in json.MarshalIndent
Prefix string
// indentation string, as in json.MarshalIndent
Indent string
}
// MiddlewareFunc makes JsonIndentMiddleware implement the Middleware interface.
func (mw *JsonIndentMiddleware) MiddlewareFunc(handler HandlerFunc) HandlerFunc {
if mw.Indent == "" {
mw.Indent = " "
}
return func(w ResponseWriter, r *Request) {
writer := &jsonIndentResponseWriter{w, false, mw.Prefix, mw.Indent}
// call the wrapped handler
handler(writer, r)
}
}
// Private responseWriter intantiated by the middleware.
// It implements the following interfaces:
// ResponseWriter
// http.ResponseWriter
// http.Flusher
// http.CloseNotifier
// http.Hijacker
type jsonIndentResponseWriter struct {
ResponseWriter
wroteHeader bool
prefix string
indent string
}
// Replace the parent EncodeJson to provide indentation.
func (w *jsonIndentResponseWriter) EncodeJson(v interface{}) ([]byte, error) {
b, err := json.MarshalIndent(v, w.prefix, w.indent)
if err != nil {
return nil, err
}
return b, nil
}
// Make sure the local EncodeJson and local Write are called.
// Does not call the parent WriteJson.
func (w *jsonIndentResponseWriter) WriteJson(v interface{}) error {
b, err := w.EncodeJson(v)
if err != nil {
return err
}
_, err = w.Write(b)
if err != nil {
return err
}
return nil
}
// Call the parent WriteHeader.
func (w *jsonIndentResponseWriter) WriteHeader(code int) {
w.ResponseWriter.WriteHeader(code)
w.wroteHeader = true
}
// Make sure the local WriteHeader is called, and call the parent Flush.
// Provided in order to implement the http.Flusher interface.
func (w *jsonIndentResponseWriter) Flush() {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
flusher := w.ResponseWriter.(http.Flusher)
flusher.Flush()
}
// Call the parent CloseNotify.
// Provided in order to implement the http.CloseNotifier interface.
func (w *jsonIndentResponseWriter) CloseNotify() <-chan bool {
notifier := w.ResponseWriter.(http.CloseNotifier)
return notifier.CloseNotify()
}
// Provided in order to implement the http.Hijacker interface.
func (w *jsonIndentResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker := w.ResponseWriter.(http.Hijacker)
return hijacker.Hijack()
}
// Make sure the local WriteHeader is called, and call the parent Write.
// Provided in order to implement the http.ResponseWriter interface.
func (w *jsonIndentResponseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
writer := w.ResponseWriter.(http.ResponseWriter)
return writer.Write(b)
}

View File

@@ -1,116 +0,0 @@
package rest
import (
"bufio"
"net"
"net/http"
)
// JsonpMiddleware provides JSONP responses on demand, based on the presence
// of a query string argument specifying the callback name.
type JsonpMiddleware struct {
// Name of the query string parameter used to specify the
// the name of the JS callback used for the padding.
// Defaults to "callback".
CallbackNameKey string
}
// MiddlewareFunc returns a HandlerFunc that implements the middleware.
func (mw *JsonpMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
if mw.CallbackNameKey == "" {
mw.CallbackNameKey = "callback"
}
return func(w ResponseWriter, r *Request) {
callbackName := r.URL.Query().Get(mw.CallbackNameKey)
// TODO validate the callbackName ?
if callbackName != "" {
// the client request JSONP, instantiate JsonpMiddleware.
writer := &jsonpResponseWriter{w, false, callbackName}
// call the handler with the wrapped writer
h(writer, r)
} else {
// do nothing special
h(w, r)
}
}
}
// Private responseWriter intantiated by the JSONP middleware.
// It adds the padding to the payload and set the proper headers.
// It implements the following interfaces:
// ResponseWriter
// http.ResponseWriter
// http.Flusher
// http.CloseNotifier
// http.Hijacker
type jsonpResponseWriter struct {
ResponseWriter
wroteHeader bool
callbackName string
}
// Overwrite the Content-Type to be text/javascript
func (w *jsonpResponseWriter) WriteHeader(code int) {
w.Header().Set("Content-Type", "text/javascript")
w.ResponseWriter.WriteHeader(code)
w.wroteHeader = true
}
// Make sure the local Write is called.
func (w *jsonpResponseWriter) WriteJson(v interface{}) error {
b, err := w.EncodeJson(v)
if err != nil {
return err
}
// JSONP security fix (http://miki.it/blog/2014/7/8/abusing-jsonp-with-rosetta-flash/)
w.Header().Set("Content-Disposition", "filename=f.txt")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Write([]byte("/**/" + w.callbackName + "("))
w.Write(b)
w.Write([]byte(")"))
return nil
}
// Make sure the local WriteHeader is called, and call the parent Flush.
// Provided in order to implement the http.Flusher interface.
func (w *jsonpResponseWriter) Flush() {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
flusher := w.ResponseWriter.(http.Flusher)
flusher.Flush()
}
// Call the parent CloseNotify.
// Provided in order to implement the http.CloseNotifier interface.
func (w *jsonpResponseWriter) CloseNotify() <-chan bool {
notifier := w.ResponseWriter.(http.CloseNotifier)
return notifier.CloseNotify()
}
// Provided in order to implement the http.Hijacker interface.
func (w *jsonpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker := w.ResponseWriter.(http.Hijacker)
return hijacker.Hijack()
}
// Make sure the local WriteHeader is called.
// Provided in order to implement the http.ResponseWriter interface.
func (w *jsonpResponseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
writer := w.ResponseWriter.(http.ResponseWriter)
return writer.Write(b)
}

View File

@@ -1,72 +0,0 @@
package rest
import (
"net/http"
)
// HandlerFunc defines the handler function. It is the go-json-rest equivalent of http.HandlerFunc.
type HandlerFunc func(ResponseWriter, *Request)
// App defines the interface that an object should implement to be used as an app in this framework
// stack. The App is the top element of the stack, the other elements being middlewares.
type App interface {
AppFunc() HandlerFunc
}
// AppSimple is an adapter type that makes it easy to write an App with a simple function.
// eg: rest.NewApi(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { ... }))
type AppSimple HandlerFunc
// AppFunc makes AppSimple implement the App interface.
func (as AppSimple) AppFunc() HandlerFunc {
return HandlerFunc(as)
}
// Middleware defines the interface that objects must implement in order to wrap a HandlerFunc and
// be used in the middleware stack.
type Middleware interface {
MiddlewareFunc(handler HandlerFunc) HandlerFunc
}
// MiddlewareSimple is an adapter type that makes it easy to write a Middleware with a simple
// function. eg: api.Use(rest.MiddlewareSimple(func(h HandlerFunc) Handlerfunc { ... }))
type MiddlewareSimple func(handler HandlerFunc) HandlerFunc
// MiddlewareFunc makes MiddlewareSimple implement the Middleware interface.
func (ms MiddlewareSimple) MiddlewareFunc(handler HandlerFunc) HandlerFunc {
return ms(handler)
}
// WrapMiddlewares calls the MiddlewareFunc methods in the reverse order and returns an HandlerFunc
// ready to be executed. This can be used to wrap a set of middlewares, post routing, on a per Route
// basis.
func WrapMiddlewares(middlewares []Middleware, handler HandlerFunc) HandlerFunc {
wrapped := handler
for i := len(middlewares) - 1; i >= 0; i-- {
wrapped = middlewares[i].MiddlewareFunc(wrapped)
}
return wrapped
}
// Handle the transition between net/http and go-json-rest objects.
// It intanciates the rest.Request and rest.ResponseWriter, ...
func adapterFunc(handler HandlerFunc) http.HandlerFunc {
return func(origWriter http.ResponseWriter, origRequest *http.Request) {
// instantiate the rest objects
request := &Request{
origRequest,
nil,
map[string]interface{}{},
}
writer := &responseWriter{
origWriter,
false,
}
// call the wrapped handler
handler(writer, request)
}
}

View File

@@ -1,29 +0,0 @@
package rest
const xPoweredByDefault = "go-json-rest"
// PoweredByMiddleware adds the "X-Powered-By" header to the HTTP response.
type PoweredByMiddleware struct {
// If specified, used as the value for the "X-Powered-By" response header.
// Defaults to "go-json-rest".
XPoweredBy string
}
// MiddlewareFunc makes PoweredByMiddleware implement the Middleware interface.
func (mw *PoweredByMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
poweredBy := xPoweredByDefault
if mw.XPoweredBy != "" {
poweredBy = mw.XPoweredBy
}
return func(w ResponseWriter, r *Request) {
w.Header().Add("X-Powered-By", poweredBy)
// call the handler
h(w, r)
}
}

View File

@@ -1,100 +0,0 @@
package rest
import (
"bufio"
"net"
"net/http"
)
// RecorderMiddleware keeps a record of the HTTP status code of the response,
// and the number of bytes written.
// The result is available to the wrapping handlers as request.Env["STATUS_CODE"].(int),
// and as request.Env["BYTES_WRITTEN"].(int64)
type RecorderMiddleware struct{}
// MiddlewareFunc makes RecorderMiddleware implement the Middleware interface.
func (mw *RecorderMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
return func(w ResponseWriter, r *Request) {
writer := &recorderResponseWriter{w, 0, false, 0}
// call the handler
h(writer, r)
r.Env["STATUS_CODE"] = writer.statusCode
r.Env["BYTES_WRITTEN"] = writer.bytesWritten
}
}
// Private responseWriter intantiated by the recorder middleware.
// It keeps a record of the HTTP status code of the response.
// It implements the following interfaces:
// ResponseWriter
// http.ResponseWriter
// http.Flusher
// http.CloseNotifier
// http.Hijacker
type recorderResponseWriter struct {
ResponseWriter
statusCode int
wroteHeader bool
bytesWritten int64
}
// Record the status code.
func (w *recorderResponseWriter) WriteHeader(code int) {
w.ResponseWriter.WriteHeader(code)
if w.wroteHeader {
return
}
w.statusCode = code
w.wroteHeader = true
}
// Make sure the local Write is called.
func (w *recorderResponseWriter) WriteJson(v interface{}) error {
b, err := w.EncodeJson(v)
if err != nil {
return err
}
_, err = w.Write(b)
if err != nil {
return err
}
return nil
}
// Make sure the local WriteHeader is called, and call the parent Flush.
// Provided in order to implement the http.Flusher interface.
func (w *recorderResponseWriter) Flush() {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
flusher := w.ResponseWriter.(http.Flusher)
flusher.Flush()
}
// Call the parent CloseNotify.
// Provided in order to implement the http.CloseNotifier interface.
func (w *recorderResponseWriter) CloseNotify() <-chan bool {
notifier := w.ResponseWriter.(http.CloseNotifier)
return notifier.CloseNotify()
}
// Provided in order to implement the http.Hijacker interface.
func (w *recorderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker := w.ResponseWriter.(http.Hijacker)
return hijacker.Hijack()
}
// Make sure the local WriteHeader is called, and call the parent Write.
// Provided in order to implement the http.ResponseWriter interface.
func (w *recorderResponseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
writer := w.ResponseWriter.(http.ResponseWriter)
written, err := writer.Write(b)
w.bytesWritten += int64(written)
return written, err
}

View File

@@ -1,74 +0,0 @@
package rest
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"runtime/debug"
)
// RecoverMiddleware catches the panic errors that occur in the wrapped HandleFunc,
// and convert them to 500 responses.
type RecoverMiddleware struct {
// Custom logger used for logging the panic errors,
// optional, defaults to log.New(os.Stderr, "", 0)
Logger *log.Logger
// If true, the log records will be printed as JSON. Convenient for log parsing.
EnableLogAsJson bool
// If true, when a "panic" happens, the error string and the stack trace will be
// printed in the 500 response body.
EnableResponseStackTrace bool
}
// MiddlewareFunc makes RecoverMiddleware implement the Middleware interface.
func (mw *RecoverMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
// set the default Logger
if mw.Logger == nil {
mw.Logger = log.New(os.Stderr, "", 0)
}
return func(w ResponseWriter, r *Request) {
// catch user code's panic, and convert to http response
defer func() {
if reco := recover(); reco != nil {
trace := debug.Stack()
// log the trace
message := fmt.Sprintf("%s\n%s", reco, trace)
mw.logError(message)
// write error response
if mw.EnableResponseStackTrace {
Error(w, message, http.StatusInternalServerError)
} else {
Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
}()
// call the handler
h(w, r)
}
}
func (mw *RecoverMiddleware) logError(message string) {
if mw.EnableLogAsJson {
record := map[string]string{
"error": message,
}
b, err := json.Marshal(&record)
if err != nil {
panic(err)
}
mw.Logger.Printf("%s", b)
} else {
mw.Logger.Print(message)
}
}

View File

@@ -1,148 +0,0 @@
package rest
import (
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
var (
// ErrJsonPayloadEmpty is returned when the JSON payload is empty.
ErrJsonPayloadEmpty = errors.New("JSON payload is empty")
)
// Request inherits from http.Request, and provides additional methods.
type Request struct {
*http.Request
// Map of parameters that have been matched in the URL Path.
PathParams map[string]string
// Environment used by middlewares to communicate.
Env map[string]interface{}
}
// PathParam provides a convenient access to the PathParams map.
func (r *Request) PathParam(name string) string {
return r.PathParams[name]
}
// DecodeJsonPayload reads the request body and decodes the JSON using json.Unmarshal.
func (r *Request) DecodeJsonPayload(v interface{}) error {
content, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
if len(content) == 0 {
return ErrJsonPayloadEmpty
}
err = json.Unmarshal(content, v)
if err != nil {
return err
}
return nil
}
// BaseUrl returns a new URL object with the Host and Scheme taken from the request.
// (without the trailing slash in the host)
func (r *Request) BaseUrl() *url.URL {
scheme := r.URL.Scheme
if scheme == "" {
scheme = "http"
}
// HTTP sometimes gives the default scheme as HTTP even when used with TLS
// Check if TLS is not nil and given back https scheme
if scheme == "http" && r.TLS != nil {
scheme = "https"
}
host := r.Host
if len(host) > 0 && host[len(host)-1] == '/' {
host = host[:len(host)-1]
}
return &url.URL{
Scheme: scheme,
Host: host,
}
}
// UrlFor returns the URL object from UriBase with the Path set to path, and the query
// string built with queryParams.
func (r *Request) UrlFor(path string, queryParams map[string][]string) *url.URL {
baseUrl := r.BaseUrl()
baseUrl.Path = path
if queryParams != nil {
query := url.Values{}
for k, v := range queryParams {
for _, vv := range v {
query.Add(k, vv)
}
}
baseUrl.RawQuery = query.Encode()
}
return baseUrl
}
// CorsInfo contains the CORS request info derived from a rest.Request.
type CorsInfo struct {
IsCors bool
IsPreflight bool
Origin string
OriginUrl *url.URL
// The header value is converted to uppercase to avoid common mistakes.
AccessControlRequestMethod string
// The header values are normalized with http.CanonicalHeaderKey.
AccessControlRequestHeaders []string
}
// GetCorsInfo derives CorsInfo from Request.
func (r *Request) GetCorsInfo() *CorsInfo {
origin := r.Header.Get("Origin")
var originUrl *url.URL
var isCors bool
if origin == "" {
isCors = false
} else if origin == "null" {
isCors = true
} else {
var err error
originUrl, err = url.ParseRequestURI(origin)
isCors = err == nil && r.Host != originUrl.Host
}
reqMethod := r.Header.Get("Access-Control-Request-Method")
reqHeaders := []string{}
rawReqHeaders := r.Header[http.CanonicalHeaderKey("Access-Control-Request-Headers")]
for _, rawReqHeader := range rawReqHeaders {
if len(rawReqHeader) == 0 {
continue
}
// net/http does not handle comma delimited headers for us
for _, reqHeader := range strings.Split(rawReqHeader, ",") {
reqHeaders = append(reqHeaders, http.CanonicalHeaderKey(strings.TrimSpace(reqHeader)))
}
}
isPreflight := isCors && r.Method == "OPTIONS" && reqMethod != ""
return &CorsInfo{
IsCors: isCors,
IsPreflight: isPreflight,
Origin: origin,
OriginUrl: originUrl,
AccessControlRequestMethod: strings.ToUpper(reqMethod),
AccessControlRequestHeaders: reqHeaders,
}
}

View File

@@ -1,127 +0,0 @@
package rest
import (
"bufio"
"encoding/json"
"net"
"net/http"
)
// A ResponseWriter interface dedicated to JSON HTTP response.
// Note, the responseWriter object instantiated by the framework also implements many other interfaces
// accessible by type assertion: http.ResponseWriter, http.Flusher, http.CloseNotifier, http.Hijacker.
type ResponseWriter interface {
// Identical to the http.ResponseWriter interface
Header() http.Header
// Use EncodeJson to generate the payload, write the headers with http.StatusOK if
// they are not already written, then write the payload.
// The Content-Type header is set to "application/json", unless already specified.
WriteJson(v interface{}) error
// Encode the data structure to JSON, mainly used to wrap ResponseWriter in
// middlewares.
EncodeJson(v interface{}) ([]byte, error)
// Similar to the http.ResponseWriter interface, with additional JSON related
// headers set.
WriteHeader(int)
}
// This allows to customize the field name used in the error response payload.
// It defaults to "Error" for compatibility reason, but can be changed before starting the server.
// eg: rest.ErrorFieldName = "errorMessage"
var ErrorFieldName = "Error"
// Error produces an error response in JSON with the following structure, '{"Error":"My error message"}'
// The standard plain text net/http Error helper can still be called like this:
// http.Error(w, "error message", code)
func Error(w ResponseWriter, error string, code int) {
w.WriteHeader(code)
err := w.WriteJson(map[string]string{ErrorFieldName: error})
if err != nil {
panic(err)
}
}
// NotFound produces a 404 response with the following JSON, '{"Error":"Resource not found"}'
// The standard plain text net/http NotFound helper can still be called like this:
// http.NotFound(w, r.Request)
func NotFound(w ResponseWriter, r *Request) {
Error(w, "Resource not found", http.StatusNotFound)
}
// Private responseWriter intantiated by the resource handler.
// It implements the following interfaces:
// ResponseWriter
// http.ResponseWriter
// http.Flusher
// http.CloseNotifier
// http.Hijacker
type responseWriter struct {
http.ResponseWriter
wroteHeader bool
}
func (w *responseWriter) WriteHeader(code int) {
if w.Header().Get("Content-Type") == "" {
// Per spec, UTF-8 is the default, and the charset parameter should not
// be necessary. But some clients (eg: Chrome) think otherwise.
// Since json.Marshal produces UTF-8, setting the charset parameter is a
// safe option.
w.Header().Set("Content-Type", "application/json; charset=utf-8")
}
w.ResponseWriter.WriteHeader(code)
w.wroteHeader = true
}
func (w *responseWriter) EncodeJson(v interface{}) ([]byte, error) {
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
return b, nil
}
// Encode the object in JSON and call Write.
func (w *responseWriter) WriteJson(v interface{}) error {
b, err := w.EncodeJson(v)
if err != nil {
return err
}
_, err = w.Write(b)
if err != nil {
return err
}
return nil
}
// Provided in order to implement the http.ResponseWriter interface.
func (w *responseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(b)
}
// Provided in order to implement the http.Flusher interface.
func (w *responseWriter) Flush() {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
flusher := w.ResponseWriter.(http.Flusher)
flusher.Flush()
}
// Provided in order to implement the http.CloseNotifier interface.
func (w *responseWriter) CloseNotify() <-chan bool {
notifier := w.ResponseWriter.(http.CloseNotifier)
return notifier.CloseNotify()
}
// Provided in order to implement the http.Hijacker interface.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker := w.ResponseWriter.(http.Hijacker)
return hijacker.Hijack()
}

View File

@@ -1,107 +0,0 @@
package rest
import (
"strings"
)
// Route defines a route as consumed by the router. It can be instantiated directly, or using one
// of the shortcut methods: rest.Get, rest.Post, rest.Put, rest.Patch and rest.Delete.
type Route struct {
// Any HTTP method. It will be used as uppercase to avoid common mistakes.
HttpMethod string
// A string like "/resource/:id.json".
// Placeholders supported are:
// :paramName that matches any char to the first '/' or '.'
// #paramName that matches any char to the first '/'
// *paramName that matches everything to the end of the string
// (placeholder names must be unique per PathExp)
PathExp string
// Code that will be executed when this route is taken.
Func HandlerFunc
}
// MakePath generates the path corresponding to this Route and the provided path parameters.
// This is used for reverse route resolution.
func (route *Route) MakePath(pathParams map[string]string) string {
path := route.PathExp
for paramName, paramValue := range pathParams {
paramPlaceholder := ":" + paramName
relaxedPlaceholder := "#" + paramName
splatPlaceholder := "*" + paramName
r := strings.NewReplacer(paramPlaceholder, paramValue, splatPlaceholder, paramValue, relaxedPlaceholder, paramValue)
path = r.Replace(path)
}
return path
}
// Head is a shortcut method that instantiates a HEAD route. See the Route object the parameters definitions.
// Equivalent to &Route{"HEAD", pathExp, handlerFunc}
func Head(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "HEAD",
PathExp: pathExp,
Func: handlerFunc,
}
}
// Get is a shortcut method that instantiates a GET route. See the Route object the parameters definitions.
// Equivalent to &Route{"GET", pathExp, handlerFunc}
func Get(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "GET",
PathExp: pathExp,
Func: handlerFunc,
}
}
// Post is a shortcut method that instantiates a POST route. See the Route object the parameters definitions.
// Equivalent to &Route{"POST", pathExp, handlerFunc}
func Post(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "POST",
PathExp: pathExp,
Func: handlerFunc,
}
}
// Put is a shortcut method that instantiates a PUT route. See the Route object the parameters definitions.
// Equivalent to &Route{"PUT", pathExp, handlerFunc}
func Put(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "PUT",
PathExp: pathExp,
Func: handlerFunc,
}
}
// Patch is a shortcut method that instantiates a PATCH route. See the Route object the parameters definitions.
// Equivalent to &Route{"PATCH", pathExp, handlerFunc}
func Patch(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "PATCH",
PathExp: pathExp,
Func: handlerFunc,
}
}
// Delete is a shortcut method that instantiates a DELETE route. Equivalent to &Route{"DELETE", pathExp, handlerFunc}
func Delete(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "DELETE",
PathExp: pathExp,
Func: handlerFunc,
}
}
// Options is a shortcut method that instantiates an OPTIONS route. See the Route object the parameters definitions.
// Equivalent to &Route{"OPTIONS", pathExp, handlerFunc}
func Options(pathExp string, handlerFunc HandlerFunc) *Route {
return &Route{
HttpMethod: "OPTIONS",
PathExp: pathExp,
Func: handlerFunc,
}
}

View File

@@ -1,194 +0,0 @@
package rest
import (
"errors"
"github.com/ant0ine/go-json-rest/rest/trie"
"net/http"
"net/url"
"strings"
)
type router struct {
Routes []*Route
disableTrieCompression bool
index map[*Route]int
trie *trie.Trie
}
// MakeRouter returns the router app. Given a set of Routes, it dispatches the request to the
// HandlerFunc of the first route that matches. The order of the Routes matters.
func MakeRouter(routes ...*Route) (App, error) {
r := &router{
Routes: routes,
}
err := r.start()
if err != nil {
return nil, err
}
return r, nil
}
// Handle the REST routing and run the user code.
func (rt *router) AppFunc() HandlerFunc {
return func(writer ResponseWriter, request *Request) {
// find the route
route, params, pathMatched := rt.findRouteFromURL(request.Method, request.URL)
if route == nil {
if pathMatched {
// no route found, but path was matched: 405 Method Not Allowed
Error(writer, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// no route found, the path was not matched: 404 Not Found
NotFound(writer, request)
return
}
// a route was found, set the PathParams
request.PathParams = params
// run the user code
handler := route.Func
handler(writer, request)
}
}
// This is run for each new request, perf is important.
func escapedPath(urlObj *url.URL) string {
// the escape method of url.URL should be public
// that would avoid this split.
parts := strings.SplitN(urlObj.RequestURI(), "?", 2)
return parts[0]
}
var preEscape = strings.NewReplacer("*", "__SPLAT_PLACEHOLDER__", "#", "__RELAXED_PLACEHOLDER__")
var postEscape = strings.NewReplacer("__SPLAT_PLACEHOLDER__", "*", "__RELAXED_PLACEHOLDER__", "#")
// This is run at init time only.
func escapedPathExp(pathExp string) (string, error) {
// PathExp validation
if pathExp == "" {
return "", errors.New("empty PathExp")
}
if pathExp[0] != '/' {
return "", errors.New("PathExp must start with /")
}
if strings.Contains(pathExp, "?") {
return "", errors.New("PathExp must not contain the query string")
}
// Get the right escaping
// XXX a bit hacky
pathExp = preEscape.Replace(pathExp)
urlObj, err := url.Parse(pathExp)
if err != nil {
return "", err
}
// get the same escaping as find requests
pathExp = urlObj.RequestURI()
pathExp = postEscape.Replace(pathExp)
return pathExp, nil
}
// This validates the Routes and prepares the Trie data structure.
// It must be called once the Routes are defined and before trying to find Routes.
// The order matters, if multiple Routes match, the first defined will be used.
func (rt *router) start() error {
rt.trie = trie.New()
rt.index = map[*Route]int{}
for i, route := range rt.Routes {
// work with the PathExp urlencoded.
pathExp, err := escapedPathExp(route.PathExp)
if err != nil {
return err
}
// insert in the Trie
err = rt.trie.AddRoute(
strings.ToUpper(route.HttpMethod), // work with the HttpMethod in uppercase
pathExp,
route,
)
if err != nil {
return err
}
// index
rt.index[route] = i
}
if rt.disableTrieCompression == false {
rt.trie.Compress()
}
return nil
}
// return the result that has the route defined the earliest
func (rt *router) ofFirstDefinedRoute(matches []*trie.Match) *trie.Match {
minIndex := -1
var bestMatch *trie.Match
for _, result := range matches {
route := result.Route.(*Route)
routeIndex := rt.index[route]
if minIndex == -1 || routeIndex < minIndex {
minIndex = routeIndex
bestMatch = result
}
}
return bestMatch
}
// Return the first matching Route and the corresponding parameters for a given URL object.
func (rt *router) findRouteFromURL(httpMethod string, urlObj *url.URL) (*Route, map[string]string, bool) {
// lookup the routes in the Trie
matches, pathMatched := rt.trie.FindRoutesAndPathMatched(
strings.ToUpper(httpMethod), // work with the httpMethod in uppercase
escapedPath(urlObj), // work with the path urlencoded
)
// short cuts
if len(matches) == 0 {
// no route found
return nil, nil, pathMatched
}
if len(matches) == 1 {
// one route found
return matches[0].Route.(*Route), matches[0].Params, pathMatched
}
// multiple routes found, pick the first defined
result := rt.ofFirstDefinedRoute(matches)
return result.Route.(*Route), result.Params, pathMatched
}
// Parse the url string (complete or just the path) and return the first matching Route and the corresponding parameters.
func (rt *router) findRoute(httpMethod, urlStr string) (*Route, map[string]string, bool, error) {
// parse the url
urlObj, err := url.Parse(urlStr)
if err != nil {
return nil, nil, false, err
}
route, params, pathMatched := rt.findRouteFromURL(httpMethod, urlObj)
return route, params, pathMatched, nil
}

View File

@@ -1,129 +0,0 @@
package rest
import (
"fmt"
"log"
"os"
"sync"
"time"
)
// StatusMiddleware keeps track of various stats about the processed requests.
// It depends on request.Env["STATUS_CODE"] and request.Env["ELAPSED_TIME"],
// recorderMiddleware and timerMiddleware must be in the wrapped middlewares.
type StatusMiddleware struct {
lock sync.RWMutex
start time.Time
pid int
responseCounts map[string]int
totalResponseTime time.Time
}
// MiddlewareFunc makes StatusMiddleware implement the Middleware interface.
func (mw *StatusMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
mw.start = time.Now()
mw.pid = os.Getpid()
mw.responseCounts = map[string]int{}
mw.totalResponseTime = time.Time{}
return func(w ResponseWriter, r *Request) {
// call the handler
h(w, r)
if r.Env["STATUS_CODE"] == nil {
log.Fatal("StatusMiddleware: Env[\"STATUS_CODE\"] is nil, " +
"RecorderMiddleware may not be in the wrapped Middlewares.")
}
statusCode := r.Env["STATUS_CODE"].(int)
if r.Env["ELAPSED_TIME"] == nil {
log.Fatal("StatusMiddleware: Env[\"ELAPSED_TIME\"] is nil, " +
"TimerMiddleware may not be in the wrapped Middlewares.")
}
responseTime := r.Env["ELAPSED_TIME"].(*time.Duration)
mw.lock.Lock()
mw.responseCounts[fmt.Sprintf("%d", statusCode)]++
mw.totalResponseTime = mw.totalResponseTime.Add(*responseTime)
mw.lock.Unlock()
}
}
// Status contains stats and status information. It is returned by GetStatus.
// These information can be made available as an API endpoint, see the "status"
// example to install the following status route.
// GET /.status returns something like:
//
// {
// "Pid": 21732,
// "UpTime": "1m15.926272s",
// "UpTimeSec": 75.926272,
// "Time": "2013-03-04 08:00:27.152986 +0000 UTC",
// "TimeUnix": 1362384027,
// "StatusCodeCount": {
// "200": 53,
// "404": 11
// },
// "TotalCount": 64,
// "TotalResponseTime": "16.777ms",
// "TotalResponseTimeSec": 0.016777,
// "AverageResponseTime": "262.14us",
// "AverageResponseTimeSec": 0.00026214
// }
type Status struct {
Pid int
UpTime string
UpTimeSec float64
Time string
TimeUnix int64
StatusCodeCount map[string]int
TotalCount int
TotalResponseTime string
TotalResponseTimeSec float64
AverageResponseTime string
AverageResponseTimeSec float64
}
// GetStatus computes and returns a Status object based on the request informations accumulated
// since the start of the process.
func (mw *StatusMiddleware) GetStatus() *Status {
mw.lock.RLock()
now := time.Now()
uptime := now.Sub(mw.start)
totalCount := 0
for _, count := range mw.responseCounts {
totalCount += count
}
totalResponseTime := mw.totalResponseTime.Sub(time.Time{})
averageResponseTime := time.Duration(0)
if totalCount > 0 {
avgNs := int64(totalResponseTime) / int64(totalCount)
averageResponseTime = time.Duration(avgNs)
}
status := &Status{
Pid: mw.pid,
UpTime: uptime.String(),
UpTimeSec: uptime.Seconds(),
Time: now.String(),
TimeUnix: now.Unix(),
StatusCodeCount: mw.responseCounts,
TotalCount: totalCount,
TotalResponseTime: totalResponseTime.String(),
TotalResponseTimeSec: totalResponseTime.Seconds(),
AverageResponseTime: averageResponseTime.String(),
AverageResponseTimeSec: averageResponseTime.Seconds(),
}
mw.lock.RUnlock()
return status
}

View File

@@ -1,26 +0,0 @@
package rest
import (
"time"
)
// TimerMiddleware computes the elapsed time spent during the execution of the wrapped handler.
// The result is available to the wrapping handlers as request.Env["ELAPSED_TIME"].(*time.Duration),
// and as request.Env["START_TIME"].(*time.Time)
type TimerMiddleware struct{}
// MiddlewareFunc makes TimerMiddleware implement the Middleware interface.
func (mw *TimerMiddleware) MiddlewareFunc(h HandlerFunc) HandlerFunc {
return func(w ResponseWriter, r *Request) {
start := time.Now()
r.Env["START_TIME"] = &start
// call the handler
h(w, r)
end := time.Now()
elapsed := end.Sub(start)
r.Env["ELAPSED_TIME"] = &elapsed
}
}

View File

@@ -1,426 +0,0 @@
// Special Trie implementation for HTTP routing.
//
// This Trie implementation is designed to support strings that includes
// :param and *splat parameters. Strings that are commonly used to represent
// the Path in HTTP routing. This implementation also maintain for each Path
// a map of HTTP Methods associated with the Route.
//
// You probably don't need to use this package directly.
//
package trie
import (
"errors"
"fmt"
)
func splitParam(remaining string) (string, string) {
i := 0
for len(remaining) > i && remaining[i] != '/' && remaining[i] != '.' {
i++
}
return remaining[:i], remaining[i:]
}
func splitRelaxed(remaining string) (string, string) {
i := 0
for len(remaining) > i && remaining[i] != '/' {
i++
}
return remaining[:i], remaining[i:]
}
type node struct {
HttpMethodToRoute map[string]interface{}
Children map[string]*node
ChildrenKeyLen int
ParamChild *node
ParamName string
RelaxedChild *node
RelaxedName string
SplatChild *node
SplatName string
}
func (n *node) addRoute(httpMethod, pathExp string, route interface{}, usedParams []string) error {
if len(pathExp) == 0 {
// end of the path, leaf node, update the map
if n.HttpMethodToRoute == nil {
n.HttpMethodToRoute = map[string]interface{}{
httpMethod: route,
}
return nil
} else {
if n.HttpMethodToRoute[httpMethod] != nil {
return errors.New("node.Route already set, duplicated path and method")
}
n.HttpMethodToRoute[httpMethod] = route
return nil
}
}
token := pathExp[0:1]
remaining := pathExp[1:]
var nextNode *node
if token[0] == ':' {
// :param case
var name string
name, remaining = splitParam(remaining)
// Check param name is unique
for _, e := range usedParams {
if e == name {
return errors.New(
fmt.Sprintf("A route can't have two placeholders with the same name: %s", name),
)
}
}
usedParams = append(usedParams, name)
if n.ParamChild == nil {
n.ParamChild = &node{}
n.ParamName = name
} else {
if n.ParamName != name {
return errors.New(
fmt.Sprintf(
"Routes sharing a common placeholder MUST name it consistently: %s != %s",
n.ParamName,
name,
),
)
}
}
nextNode = n.ParamChild
} else if token[0] == '#' {
// #param case
var name string
name, remaining = splitRelaxed(remaining)
// Check param name is unique
for _, e := range usedParams {
if e == name {
return errors.New(
fmt.Sprintf("A route can't have two placeholders with the same name: %s", name),
)
}
}
usedParams = append(usedParams, name)
if n.RelaxedChild == nil {
n.RelaxedChild = &node{}
n.RelaxedName = name
} else {
if n.RelaxedName != name {
return errors.New(
fmt.Sprintf(
"Routes sharing a common placeholder MUST name it consistently: %s != %s",
n.RelaxedName,
name,
),
)
}
}
nextNode = n.RelaxedChild
} else if token[0] == '*' {
// *splat case
name := remaining
remaining = ""
// Check param name is unique
for _, e := range usedParams {
if e == name {
return errors.New(
fmt.Sprintf("A route can't have two placeholders with the same name: %s", name),
)
}
}
if n.SplatChild == nil {
n.SplatChild = &node{}
n.SplatName = name
}
nextNode = n.SplatChild
} else {
// general case
if n.Children == nil {
n.Children = map[string]*node{}
n.ChildrenKeyLen = 1
}
if n.Children[token] == nil {
n.Children[token] = &node{}
}
nextNode = n.Children[token]
}
return nextNode.addRoute(httpMethod, remaining, route, usedParams)
}
func (n *node) compress() {
// *splat branch
if n.SplatChild != nil {
n.SplatChild.compress()
}
// :param branch
if n.ParamChild != nil {
n.ParamChild.compress()
}
// #param branch
if n.RelaxedChild != nil {
n.RelaxedChild.compress()
}
// main branch
if len(n.Children) == 0 {
return
}
// compressable ?
canCompress := true
for _, node := range n.Children {
if node.HttpMethodToRoute != nil || node.SplatChild != nil || node.ParamChild != nil || node.RelaxedChild != nil {
canCompress = false
}
}
// compress
if canCompress {
merged := map[string]*node{}
for key, node := range n.Children {
for gdKey, gdNode := range node.Children {
mergedKey := key + gdKey
merged[mergedKey] = gdNode
}
}
n.Children = merged
n.ChildrenKeyLen++
n.compress()
// continue
} else {
for _, node := range n.Children {
node.compress()
}
}
}
func printFPadding(padding int, format string, a ...interface{}) {
for i := 0; i < padding; i++ {
fmt.Print(" ")
}
fmt.Printf(format, a...)
}
// Private function for now
func (n *node) printDebug(level int) {
level++
// *splat branch
if n.SplatChild != nil {
printFPadding(level, "*splat\n")
n.SplatChild.printDebug(level)
}
// :param branch
if n.ParamChild != nil {
printFPadding(level, ":param\n")
n.ParamChild.printDebug(level)
}
// #param branch
if n.RelaxedChild != nil {
printFPadding(level, "#relaxed\n")
n.RelaxedChild.printDebug(level)
}
// main branch
for key, node := range n.Children {
printFPadding(level, "\"%s\"\n", key)
node.printDebug(level)
}
}
// utility for the node.findRoutes recursive method
type paramMatch struct {
name string
value string
}
type findContext struct {
paramStack []paramMatch
matchFunc func(httpMethod, path string, node *node)
}
func newFindContext() *findContext {
return &findContext{
paramStack: []paramMatch{},
}
}
func (fc *findContext) pushParams(name, value string) {
fc.paramStack = append(
fc.paramStack,
paramMatch{name, value},
)
}
func (fc *findContext) popParams() {
fc.paramStack = fc.paramStack[:len(fc.paramStack)-1]
}
func (fc *findContext) paramsAsMap() map[string]string {
r := map[string]string{}
for _, param := range fc.paramStack {
if r[param.name] != "" {
// this is checked at addRoute time, and should never happen.
panic(fmt.Sprintf(
"placeholder %s already found, placeholder names should be unique per route",
param.name,
))
}
r[param.name] = param.value
}
return r
}
type Match struct {
// Same Route as in AddRoute
Route interface{}
// map of params matched for this result
Params map[string]string
}
func (n *node) find(httpMethod, path string, context *findContext) {
if n.HttpMethodToRoute != nil && path == "" {
context.matchFunc(httpMethod, path, n)
}
if len(path) == 0 {
return
}
// *splat branch
if n.SplatChild != nil {
context.pushParams(n.SplatName, path)
n.SplatChild.find(httpMethod, "", context)
context.popParams()
}
// :param branch
if n.ParamChild != nil {
value, remaining := splitParam(path)
context.pushParams(n.ParamName, value)
n.ParamChild.find(httpMethod, remaining, context)
context.popParams()
}
// #param branch
if n.RelaxedChild != nil {
value, remaining := splitRelaxed(path)
context.pushParams(n.RelaxedName, value)
n.RelaxedChild.find(httpMethod, remaining, context)
context.popParams()
}
// main branch
length := n.ChildrenKeyLen
if len(path) < length {
return
}
token := path[0:length]
remaining := path[length:]
if n.Children[token] != nil {
n.Children[token].find(httpMethod, remaining, context)
}
}
type Trie struct {
root *node
}
// Instanciate a Trie with an empty node as the root.
func New() *Trie {
return &Trie{
root: &node{},
}
}
// Insert the route in the Trie following or creating the nodes corresponding to the path.
func (t *Trie) AddRoute(httpMethod, pathExp string, route interface{}) error {
return t.root.addRoute(httpMethod, pathExp, route, []string{})
}
// Reduce the size of the tree, must be done after the last AddRoute.
func (t *Trie) Compress() {
t.root.compress()
}
// Private function for now.
func (t *Trie) printDebug() {
fmt.Print("<trie>\n")
t.root.printDebug(0)
fmt.Print("</trie>\n")
}
// Given a path and an http method, return all the matching routes.
func (t *Trie) FindRoutes(httpMethod, path string) []*Match {
context := newFindContext()
matches := []*Match{}
context.matchFunc = func(httpMethod, path string, node *node) {
if node.HttpMethodToRoute[httpMethod] != nil {
// path and method match, found a route !
matches = append(
matches,
&Match{
Route: node.HttpMethodToRoute[httpMethod],
Params: context.paramsAsMap(),
},
)
}
}
t.root.find(httpMethod, path, context)
return matches
}
// Same as FindRoutes, but return in addition a boolean indicating if the path was matched.
// Useful to return 405
func (t *Trie) FindRoutesAndPathMatched(httpMethod, path string) ([]*Match, bool) {
context := newFindContext()
pathMatched := false
matches := []*Match{}
context.matchFunc = func(httpMethod, path string, node *node) {
pathMatched = true
if node.HttpMethodToRoute[httpMethod] != nil {
// path and method match, found a route !
matches = append(
matches,
&Match{
Route: node.HttpMethodToRoute[httpMethod],
Params: context.paramsAsMap(),
},
)
}
}
t.root.find(httpMethod, path, context)
return matches, pathMatched
}
// Given a path, and whatever the http method, return all the matching routes.
func (t *Trie) FindRoutesForPath(path string) []*Match {
context := newFindContext()
matches := []*Match{}
context.matchFunc = func(httpMethod, path string, node *node) {
params := context.paramsAsMap()
for _, route := range node.HttpMethodToRoute {
matches = append(
matches,
&Match{
Route: route,
Params: params,
},
)
}
}
t.root.find("", path, context)
return matches
}

View File

@@ -1,15 +0,0 @@
ISC License
Copyright (c) 2012-2016 Dave Collins <dave@davec.name>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

View File

@@ -1,145 +0,0 @@
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when the code is not running on Google App Engine, compiled by GopherJS, and
// "-tags safe" is not added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// Go versions prior to 1.4 are disabled because they use a different layout
// for interfaces which make the implementation of unsafeReflectValue more complex.
// +build !js,!appengine,!safe,!disableunsafe,go1.4
package spew
import (
"reflect"
"unsafe"
)
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = false
// ptrSize is the size of a pointer on the current arch.
ptrSize = unsafe.Sizeof((*byte)(nil))
)
type flag uintptr
var (
// flagRO indicates whether the value field of a reflect.Value
// is read-only.
flagRO flag
// flagAddr indicates whether the address of the reflect.Value's
// value may be taken.
flagAddr flag
)
// flagKindMask holds the bits that make up the kind
// part of the flags field. In all the supported versions,
// it is in the lower 5 bits.
const flagKindMask = flag(0x1f)
// Different versions of Go have used different
// bit layouts for the flags type. This table
// records the known combinations.
var okFlags = []struct {
ro, addr flag
}{{
// From Go 1.4 to 1.5
ro: 1 << 5,
addr: 1 << 7,
}, {
// Up to Go tip.
ro: 1<<5 | 1<<6,
addr: 1 << 8,
}}
var flagValOffset = func() uintptr {
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
return field.Offset
}()
// flagField returns a pointer to the flag field of a reflect.Value.
func flagField(v *reflect.Value) *flag {
return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset))
}
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
// the typical safety restrictions preventing access to unaddressable and
// unexported data. It works by digging the raw pointer to the underlying
// value out of the protected value and generating a new unprotected (unsafe)
// reflect.Value to it.
//
// This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) reflect.Value {
if !v.IsValid() || (v.CanInterface() && v.CanAddr()) {
return v
}
flagFieldPtr := flagField(&v)
*flagFieldPtr &^= flagRO
*flagFieldPtr |= flagAddr
return v
}
// Sanity checks against future reflect package changes
// to the type or semantics of the Value.flag field.
func init() {
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() {
panic("reflect.Value flag field has changed kind")
}
type t0 int
var t struct {
A t0
// t0 will have flagEmbedRO set.
t0
// a will have flagStickyRO set
a t0
}
vA := reflect.ValueOf(t).FieldByName("A")
va := reflect.ValueOf(t).FieldByName("a")
vt0 := reflect.ValueOf(t).FieldByName("t0")
// Infer flagRO from the difference between the flags
// for the (otherwise identical) fields in t.
flagPublic := *flagField(&vA)
flagWithRO := *flagField(&va) | *flagField(&vt0)
flagRO = flagPublic ^ flagWithRO
// Infer flagAddr from the difference between a value
// taken from a pointer and not.
vPtrA := reflect.ValueOf(&t).Elem().FieldByName("A")
flagNoPtr := *flagField(&vA)
flagPtr := *flagField(&vPtrA)
flagAddr = flagNoPtr ^ flagPtr
// Check that the inferred flags tally with one of the known versions.
for _, f := range okFlags {
if flagRO == f.ro && flagAddr == f.addr {
return
}
}
panic("reflect.Value read-only flag has changed semantics")
}

View File

@@ -1,38 +0,0 @@
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when the code is running on Google App Engine, compiled by GopherJS, or
// "-tags safe" is added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// +build js appengine safe disableunsafe !go1.4
package spew
import "reflect"
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = true
)
// unsafeReflectValue typically converts the passed reflect.Value into a one
// that bypasses the typical safety restrictions preventing access to
// unaddressable and unexported data. However, doing this relies on access to
// the unsafe package. This is a stub version which simply returns the passed
// reflect.Value when the unsafe package is not available.
func unsafeReflectValue(v reflect.Value) reflect.Value {
return v
}

View File

@@ -1,341 +0,0 @@
/*
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"io"
"reflect"
"sort"
"strconv"
)
// Some constants in the form of bytes to avoid string overhead. This mirrors
// the technique used in the fmt package.
var (
panicBytes = []byte("(PANIC=")
plusBytes = []byte("+")
iBytes = []byte("i")
trueBytes = []byte("true")
falseBytes = []byte("false")
interfaceBytes = []byte("(interface {})")
commaNewlineBytes = []byte(",\n")
newlineBytes = []byte("\n")
openBraceBytes = []byte("{")
openBraceNewlineBytes = []byte("{\n")
closeBraceBytes = []byte("}")
asteriskBytes = []byte("*")
colonBytes = []byte(":")
colonSpaceBytes = []byte(": ")
openParenBytes = []byte("(")
closeParenBytes = []byte(")")
spaceBytes = []byte(" ")
pointerChainBytes = []byte("->")
nilAngleBytes = []byte("<nil>")
maxNewlineBytes = []byte("<max depth reached>\n")
maxShortBytes = []byte("<max>")
circularBytes = []byte("<already shown>")
circularShortBytes = []byte("<shown>")
invalidAngleBytes = []byte("<invalid>")
openBracketBytes = []byte("[")
closeBracketBytes = []byte("]")
percentBytes = []byte("%")
precisionBytes = []byte(".")
openAngleBytes = []byte("<")
closeAngleBytes = []byte(">")
openMapBytes = []byte("map[")
closeMapBytes = []byte("]")
lenEqualsBytes = []byte("len=")
capEqualsBytes = []byte("cap=")
)
// hexDigits is used to map a decimal value to a hex digit.
var hexDigits = "0123456789abcdef"
// catchPanic handles any panics that might occur during the handleMethods
// calls.
func catchPanic(w io.Writer, v reflect.Value) {
if err := recover(); err != nil {
w.Write(panicBytes)
fmt.Fprintf(w, "%v", err)
w.Write(closeParenBytes)
}
}
// handleMethods attempts to call the Error and String methods on the underlying
// type the passed reflect.Value represents and outputes the result to Writer w.
//
// It handles panics in any called methods by catching and displaying the error
// as the formatted value.
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
// We need an interface to check if the type implements the error or
// Stringer interface. However, the reflect package won't give us an
// interface on certain things like unexported struct fields in order
// to enforce visibility rules. We use unsafe, when it's available,
// to bypass these restrictions since this package does not mutate the
// values.
if !v.CanInterface() {
if UnsafeDisabled {
return false
}
v = unsafeReflectValue(v)
}
// Choose whether or not to do error and Stringer interface lookups against
// the base type or a pointer to the base type depending on settings.
// Technically calling one of these methods with a pointer receiver can
// mutate the value, however, types which choose to satisify an error or
// Stringer interface with a pointer receiver should not be mutating their
// state inside these interface methods.
if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() {
v = unsafeReflectValue(v)
}
if v.CanAddr() {
v = v.Addr()
}
// Is it an error or Stringer?
switch iface := v.Interface().(type) {
case error:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.Error()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.Error()))
return true
case fmt.Stringer:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.String()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.String()))
return true
}
return false
}
// printBool outputs a boolean value as true or false to Writer w.
func printBool(w io.Writer, val bool) {
if val {
w.Write(trueBytes)
} else {
w.Write(falseBytes)
}
}
// printInt outputs a signed integer value to Writer w.
func printInt(w io.Writer, val int64, base int) {
w.Write([]byte(strconv.FormatInt(val, base)))
}
// printUint outputs an unsigned integer value to Writer w.
func printUint(w io.Writer, val uint64, base int) {
w.Write([]byte(strconv.FormatUint(val, base)))
}
// printFloat outputs a floating point value using the specified precision,
// which is expected to be 32 or 64bit, to Writer w.
func printFloat(w io.Writer, val float64, precision int) {
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
}
// printComplex outputs a complex value using the specified float precision
// for the real and imaginary parts to Writer w.
func printComplex(w io.Writer, c complex128, floatPrecision int) {
r := real(c)
w.Write(openParenBytes)
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
i := imag(c)
if i >= 0 {
w.Write(plusBytes)
}
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
w.Write(iBytes)
w.Write(closeParenBytes)
}
// printHexPtr outputs a uintptr formatted as hexadecimal with a leading '0x'
// prefix to Writer w.
func printHexPtr(w io.Writer, p uintptr) {
// Null pointer.
num := uint64(p)
if num == 0 {
w.Write(nilAngleBytes)
return
}
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
buf := make([]byte, 18)
// It's simpler to construct the hex string right to left.
base := uint64(16)
i := len(buf) - 1
for num >= base {
buf[i] = hexDigits[num%base]
num /= base
i--
}
buf[i] = hexDigits[num]
// Add '0x' prefix.
i--
buf[i] = 'x'
i--
buf[i] = '0'
// Strip unused leading bytes.
buf = buf[i:]
w.Write(buf)
}
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
// elements to be sorted.
type valuesSorter struct {
values []reflect.Value
strings []string // either nil or same len and values
cs *ConfigState
}
// newValuesSorter initializes a valuesSorter instance, which holds a set of
// surrogate keys on which the data should be sorted. It uses flags in
// ConfigState to decide if and how to populate those surrogate keys.
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
vs := &valuesSorter{values: values, cs: cs}
if canSortSimply(vs.values[0].Kind()) {
return vs
}
if !cs.DisableMethods {
vs.strings = make([]string, len(values))
for i := range vs.values {
b := bytes.Buffer{}
if !handleMethods(cs, &b, vs.values[i]) {
vs.strings = nil
break
}
vs.strings[i] = b.String()
}
}
if vs.strings == nil && cs.SpewKeys {
vs.strings = make([]string, len(values))
for i := range vs.values {
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
}
}
return vs
}
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
// directly, or whether it should be considered for sorting by surrogate keys
// (if the ConfigState allows it).
func canSortSimply(kind reflect.Kind) bool {
// This switch parallels valueSortLess, except for the default case.
switch kind {
case reflect.Bool:
return true
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return true
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.String:
return true
case reflect.Uintptr:
return true
case reflect.Array:
return true
}
return false
}
// Len returns the number of values in the slice. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Len() int {
return len(s.values)
}
// Swap swaps the values at the passed indices. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Swap(i, j int) {
s.values[i], s.values[j] = s.values[j], s.values[i]
if s.strings != nil {
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
}
}
// valueSortLess returns whether the first value should sort before the second
// value. It is used by valueSorter.Less as part of the sort.Interface
// implementation.
func valueSortLess(a, b reflect.Value) bool {
switch a.Kind() {
case reflect.Bool:
return !a.Bool() && b.Bool()
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return a.Int() < b.Int()
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return a.Uint() < b.Uint()
case reflect.Float32, reflect.Float64:
return a.Float() < b.Float()
case reflect.String:
return a.String() < b.String()
case reflect.Uintptr:
return a.Uint() < b.Uint()
case reflect.Array:
// Compare the contents of both arrays.
l := a.Len()
for i := 0; i < l; i++ {
av := a.Index(i)
bv := b.Index(i)
if av.Interface() == bv.Interface() {
continue
}
return valueSortLess(av, bv)
}
}
return a.String() < b.String()
}
// Less returns whether the value at index i should sort before the
// value at index j. It is part of the sort.Interface implementation.
func (s *valuesSorter) Less(i, j int) bool {
if s.strings == nil {
return valueSortLess(s.values[i], s.values[j])
}
return s.strings[i] < s.strings[j]
}
// sortValues is a sort function that handles both native types and any type that
// can be converted to error or Stringer. Other inputs are sorted according to
// their Value.String() value to ensure display stability.
func sortValues(values []reflect.Value, cs *ConfigState) {
if len(values) == 0 {
return
}
sort.Sort(newValuesSorter(values, cs))
}

View File

@@ -1,306 +0,0 @@
/*
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"io"
"os"
)
// ConfigState houses the configuration options used by spew to format and
// display values. There is a global instance, Config, that is used to control
// all top-level Formatter and Dump functionality. Each ConfigState instance
// provides methods equivalent to the top-level functions.
//
// The zero value for ConfigState provides no indentation. You would typically
// want to set it to a space or a tab.
//
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
// with default settings. See the documentation of NewDefaultConfig for default
// values.
type ConfigState struct {
// Indent specifies the string to use for each indentation level. The
// global config instance that all top-level functions use set this to a
// single space by default. If you would like more indentation, you might
// set this to a tab with "\t" or perhaps two spaces with " ".
Indent string
// MaxDepth controls the maximum number of levels to descend into nested
// data structures. The default, 0, means there is no limit.
//
// NOTE: Circular data structures are properly detected, so it is not
// necessary to set this value unless you specifically want to limit deeply
// nested data structures.
MaxDepth int
// DisableMethods specifies whether or not error and Stringer interfaces are
// invoked for types that implement them.
DisableMethods bool
// DisablePointerMethods specifies whether or not to check for and invoke
// error and Stringer interfaces on types which only accept a pointer
// receiver when the current type is not a pointer.
//
// NOTE: This might be an unsafe action since calling one of these methods
// with a pointer receiver could technically mutate the value, however,
// in practice, types which choose to satisify an error or Stringer
// interface with a pointer receiver should not be mutating their state
// inside these interface methods. As a result, this option relies on
// access to the unsafe package, so it will not have any effect when
// running in environments without access to the unsafe package such as
// Google App Engine or with the "safe" build tag specified.
DisablePointerMethods bool
// DisablePointerAddresses specifies whether to disable the printing of
// pointer addresses. This is useful when diffing data structures in tests.
DisablePointerAddresses bool
// DisableCapacities specifies whether to disable the printing of capacities
// for arrays, slices, maps and channels. This is useful when diffing
// data structures in tests.
DisableCapacities bool
// ContinueOnMethod specifies whether or not recursion should continue once
// a custom error or Stringer interface is invoked. The default, false,
// means it will print the results of invoking the custom error or Stringer
// interface and return immediately instead of continuing to recurse into
// the internals of the data type.
//
// NOTE: This flag does not have any effect if method invocation is disabled
// via the DisableMethods or DisablePointerMethods options.
ContinueOnMethod bool
// SortKeys specifies map keys should be sorted before being printed. Use
// this to have a more deterministic, diffable output. Note that only
// native types (bool, int, uint, floats, uintptr and string) and types
// that support the error or Stringer interfaces (if methods are
// enabled) are supported, with other types sorted according to the
// reflect.Value.String() output which guarantees display stability.
SortKeys bool
// SpewKeys specifies that, as a last resort attempt, map keys should
// be spewed to strings and sorted by those strings. This is only
// considered if SortKeys is true.
SpewKeys bool
}
// Config is the active configuration of the top-level functions.
// The configuration can be changed by modifying the contents of spew.Config.
var Config = ConfigState{Indent: " "}
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the formatted string as a value that satisfies error. See NewFormatter
// for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
return fmt.Errorf(format, c.convertArgs(a)...)
}
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprint(w, c.convertArgs(a)...)
}
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
return fmt.Fprintf(w, format, c.convertArgs(a)...)
}
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
// passed with a Formatter interface returned by c.NewFormatter. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprintln(w, c.convertArgs(a)...)
}
// Print is a wrapper for fmt.Print that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
return fmt.Print(c.convertArgs(a)...)
}
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
return fmt.Printf(format, c.convertArgs(a)...)
}
// Println is a wrapper for fmt.Println that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
return fmt.Println(c.convertArgs(a)...)
}
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprint(a ...interface{}) string {
return fmt.Sprint(c.convertArgs(a)...)
}
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
return fmt.Sprintf(format, c.convertArgs(a)...)
}
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
// were passed with a Formatter interface returned by c.NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprintln(a ...interface{}) string {
return fmt.Sprintln(c.convertArgs(a)...)
}
/*
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
interface. As a result, it integrates cleanly with standard fmt package
printing functions. The formatter is useful for inline printing of smaller data
types similar to the standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Typically this function shouldn't be called directly. It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
c.Printf, c.Println, or c.Printf.
*/
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
return newFormatter(c, v)
}
// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
fdump(c, w, a...)
}
/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output
The configuration options are controlled by modifying the public members
of c. See ConfigState for options documentation.
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func (c *ConfigState) Dump(a ...interface{}) {
fdump(c, os.Stdout, a...)
}
// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func (c *ConfigState) Sdump(a ...interface{}) string {
var buf bytes.Buffer
fdump(c, &buf, a...)
return buf.String()
}
// convertArgs accepts a slice of arguments and returns a slice of the same
// length with each argument converted to a spew Formatter interface using
// the ConfigState associated with s.
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
formatters = make([]interface{}, len(args))
for index, arg := range args {
formatters[index] = newFormatter(c, arg)
}
return formatters
}
// NewDefaultConfig returns a ConfigState with the following default settings.
//
// Indent: " "
// MaxDepth: 0
// DisableMethods: false
// DisablePointerMethods: false
// ContinueOnMethod: false
// SortKeys: false
func NewDefaultConfig() *ConfigState {
return &ConfigState{Indent: " "}
}

View File

@@ -1,211 +0,0 @@
/*
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
/*
Package spew implements a deep pretty printer for Go data structures to aid in
debugging.
A quick overview of the additional features spew provides over the built-in
printing facilities for Go data types are as follows:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output (only when using
Dump style)
There are two different approaches spew allows for dumping Go data structures:
* Dump style which prints with newlines, customizable indentation,
and additional debug information such as types and all pointer addresses
used to indirect to the final value
* A custom Formatter interface that integrates cleanly with the standard fmt
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
similar to the default %v while providing the additional functionality
outlined above and passing unsupported format verbs such as %x and %q
along to fmt
Quick Start
This section demonstrates how to quickly get started with spew. See the
sections below for further details on formatting and configuration options.
To dump a variable with full newlines, indentation, type, and pointer
information use Dump, Fdump, or Sdump:
spew.Dump(myVar1, myVar2, ...)
spew.Fdump(someWriter, myVar1, myVar2, ...)
str := spew.Sdump(myVar1, myVar2, ...)
Alternatively, if you would prefer to use format strings with a compacted inline
printing style, use the convenience wrappers Printf, Fprintf, etc with
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
%#+v (adds types and pointer addresses):
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
Configuration Options
Configuration of spew is handled by fields in the ConfigState type. For
convenience, all of the top-level functions use a global state available
via the spew.Config global.
It is also possible to create a ConfigState instance that provides methods
equivalent to the top-level functions. This allows concurrent configuration
options. See the ConfigState documentation for more details.
The following configuration options are available:
* Indent
String to use for each indentation level for Dump functions.
It is a single space by default. A popular alternative is "\t".
* MaxDepth
Maximum number of levels to descend into nested data structures.
There is no limit by default.
* DisableMethods
Disables invocation of error and Stringer interface methods.
Method invocation is enabled by default.
* DisablePointerMethods
Disables invocation of error and Stringer interface methods on types
which only accept pointer receivers from non-pointer variables.
Pointer method invocation is enabled by default.
* DisablePointerAddresses
DisablePointerAddresses specifies whether to disable the printing of
pointer addresses. This is useful when diffing data structures in tests.
* DisableCapacities
DisableCapacities specifies whether to disable the printing of
capacities for arrays, slices, maps and channels. This is useful when
diffing data structures in tests.
* ContinueOnMethod
Enables recursion into types after invoking error and Stringer interface
methods. Recursion after method invocation is disabled by default.
* SortKeys
Specifies map keys should be sorted before being printed. Use
this to have a more deterministic, diffable output. Note that
only native types (bool, int, uint, floats, uintptr and string)
and types which implement error or Stringer interfaces are
supported with other types sorted according to the
reflect.Value.String() output which guarantees display
stability. Natural map order is used by default.
* SpewKeys
Specifies that, as a last resort attempt, map keys should be
spewed to strings and sorted by those strings. This is only
considered if SortKeys is true.
Dump Usage
Simply call spew.Dump with a list of variables you want to dump:
spew.Dump(myVar1, myVar2, ...)
You may also call spew.Fdump if you would prefer to output to an arbitrary
io.Writer. For example, to dump to standard error:
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
A third option is to call spew.Sdump to get the formatted output as a string:
str := spew.Sdump(myVar1, myVar2, ...)
Sample Dump Output
See the Dump example for details on the setup of the types and variables being
shown here.
(main.Foo) {
unexportedField: (*main.Bar)(0xf84002e210)({
flag: (main.Flag) flagTwo,
data: (uintptr) <nil>
}),
ExportedField: (map[interface {}]interface {}) (len=1) {
(string) (len=3) "one": (bool) true
}
}
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
command as shown.
([]uint8) (len=32 cap=32) {
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
00000020 31 32 |12|
}
Custom Formatter
Spew provides a custom formatter that implements the fmt.Formatter interface
so that it integrates cleanly with standard fmt package printing functions. The
formatter is useful for inline printing of smaller data types similar to the
standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Custom Formatter Usage
The simplest way to make use of the spew custom formatter is to call one of the
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
functions have syntax you are most likely already familiar with:
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
spew.Println(myVar, myVar2)
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
See the Index for the full list convenience functions.
Sample Formatter Output
Double pointer to a uint8:
%v: <**>5
%+v: <**>(0xf8400420d0->0xf8400420c8)5
%#v: (**uint8)5
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
Pointer to circular struct with a uint8 field and a pointer to itself:
%v: <*>{1 <*><shown>}
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
See the Printf example for details on the setup of variables being shown
here.
Errors
Since it is possible for custom Stringer/error interfaces to panic, spew
detects them and handles them internally by printing the panic information
inline with the output. Since spew is intended to provide deep pretty printing
capabilities on structures, it intentionally does not return any errors.
*/
package spew

View File

@@ -1,509 +0,0 @@
/*
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"os"
"reflect"
"regexp"
"strconv"
"strings"
)
var (
// uint8Type is a reflect.Type representing a uint8. It is used to
// convert cgo types to uint8 slices for hexdumping.
uint8Type = reflect.TypeOf(uint8(0))
// cCharRE is a regular expression that matches a cgo char.
// It is used to detect character arrays to hexdump them.
cCharRE = regexp.MustCompile(`^.*\._Ctype_char$`)
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
// char. It is used to detect unsigned character arrays to hexdump
// them.
cUnsignedCharRE = regexp.MustCompile(`^.*\._Ctype_unsignedchar$`)
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
// It is used to detect uint8_t arrays to hexdump them.
cUint8tCharRE = regexp.MustCompile(`^.*\._Ctype_uint8_t$`)
)
// dumpState contains information about the state of a dump operation.
type dumpState struct {
w io.Writer
depth int
pointers map[uintptr]int
ignoreNextType bool
ignoreNextIndent bool
cs *ConfigState
}
// indent performs indentation according to the depth level and cs.Indent
// option.
func (d *dumpState) indent() {
if d.ignoreNextIndent {
d.ignoreNextIndent = false
return
}
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
}
// unpackValue returns values inside of non-nil interfaces when possible.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface && !v.IsNil() {
v = v.Elem()
}
return v
}
// dumpPtr handles formatting of pointers by indirecting them as necessary.
func (d *dumpState) dumpPtr(v reflect.Value) {
// Remove pointers at or below the current depth from map used to detect
// circular refs.
for k, depth := range d.pointers {
if depth >= d.depth {
delete(d.pointers, k)
}
}
// Keep list of all dereferenced pointers to show later.
pointerChain := make([]uintptr, 0)
// Figure out how many levels of indirection there are by dereferencing
// pointers and unpacking interfaces down the chain while detecting circular
// references.
nilFound := false
cycleFound := false
indirects := 0
ve := v
for ve.Kind() == reflect.Ptr {
if ve.IsNil() {
nilFound = true
break
}
indirects++
addr := ve.Pointer()
pointerChain = append(pointerChain, addr)
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
cycleFound = true
indirects--
break
}
d.pointers[addr] = d.depth
ve = ve.Elem()
if ve.Kind() == reflect.Interface {
if ve.IsNil() {
nilFound = true
break
}
ve = ve.Elem()
}
}
// Display type information.
d.w.Write(openParenBytes)
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
d.w.Write([]byte(ve.Type().String()))
d.w.Write(closeParenBytes)
// Display pointer information.
if !d.cs.DisablePointerAddresses && len(pointerChain) > 0 {
d.w.Write(openParenBytes)
for i, addr := range pointerChain {
if i > 0 {
d.w.Write(pointerChainBytes)
}
printHexPtr(d.w, addr)
}
d.w.Write(closeParenBytes)
}
// Display dereferenced value.
d.w.Write(openParenBytes)
switch {
case nilFound:
d.w.Write(nilAngleBytes)
case cycleFound:
d.w.Write(circularBytes)
default:
d.ignoreNextType = true
d.dump(ve)
}
d.w.Write(closeParenBytes)
}
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
// reflection) arrays and slices are dumped in hexdump -C fashion.
func (d *dumpState) dumpSlice(v reflect.Value) {
// Determine whether this type should be hex dumped or not. Also,
// for types which should be hexdumped, try to use the underlying data
// first, then fall back to trying to convert them to a uint8 slice.
var buf []uint8
doConvert := false
doHexDump := false
numEntries := v.Len()
if numEntries > 0 {
vt := v.Index(0).Type()
vts := vt.String()
switch {
// C types that need to be converted.
case cCharRE.MatchString(vts):
fallthrough
case cUnsignedCharRE.MatchString(vts):
fallthrough
case cUint8tCharRE.MatchString(vts):
doConvert = true
// Try to use existing uint8 slices and fall back to converting
// and copying if that fails.
case vt.Kind() == reflect.Uint8:
// We need an addressable interface to convert the type
// to a byte slice. However, the reflect package won't
// give us an interface on certain things like
// unexported struct fields in order to enforce
// visibility rules. We use unsafe, when available, to
// bypass these restrictions since this package does not
// mutate the values.
vs := v
if !vs.CanInterface() || !vs.CanAddr() {
vs = unsafeReflectValue(vs)
}
if !UnsafeDisabled {
vs = vs.Slice(0, numEntries)
// Use the existing uint8 slice if it can be
// type asserted.
iface := vs.Interface()
if slice, ok := iface.([]uint8); ok {
buf = slice
doHexDump = true
break
}
}
// The underlying data needs to be converted if it can't
// be type asserted to a uint8 slice.
doConvert = true
}
// Copy and convert the underlying type if needed.
if doConvert && vt.ConvertibleTo(uint8Type) {
// Convert and copy each element into a uint8 byte
// slice.
buf = make([]uint8, numEntries)
for i := 0; i < numEntries; i++ {
vv := v.Index(i)
buf[i] = uint8(vv.Convert(uint8Type).Uint())
}
doHexDump = true
}
}
// Hexdump the entire slice as needed.
if doHexDump {
indent := strings.Repeat(d.cs.Indent, d.depth)
str := indent + hex.Dump(buf)
str = strings.Replace(str, "\n", "\n"+indent, -1)
str = strings.TrimRight(str, d.cs.Indent)
d.w.Write([]byte(str))
return
}
// Recursively call dump for each item.
for i := 0; i < numEntries; i++ {
d.dump(d.unpackValue(v.Index(i)))
if i < (numEntries - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
// dump is the main workhorse for dumping a value. It uses the passed reflect
// value to figure out what kind of object we are dealing with and formats it
// appropriately. It is a recursive function, however circular data structures
// are detected and handled properly.
func (d *dumpState) dump(v reflect.Value) {
// Handle invalid reflect values immediately.
kind := v.Kind()
if kind == reflect.Invalid {
d.w.Write(invalidAngleBytes)
return
}
// Handle pointers specially.
if kind == reflect.Ptr {
d.indent()
d.dumpPtr(v)
return
}
// Print type information unless already handled elsewhere.
if !d.ignoreNextType {
d.indent()
d.w.Write(openParenBytes)
d.w.Write([]byte(v.Type().String()))
d.w.Write(closeParenBytes)
d.w.Write(spaceBytes)
}
d.ignoreNextType = false
// Display length and capacity if the built-in len and cap functions
// work with the value's kind and the len/cap itself is non-zero.
valueLen, valueCap := 0, 0
switch v.Kind() {
case reflect.Array, reflect.Slice, reflect.Chan:
valueLen, valueCap = v.Len(), v.Cap()
case reflect.Map, reflect.String:
valueLen = v.Len()
}
if valueLen != 0 || !d.cs.DisableCapacities && valueCap != 0 {
d.w.Write(openParenBytes)
if valueLen != 0 {
d.w.Write(lenEqualsBytes)
printInt(d.w, int64(valueLen), 10)
}
if !d.cs.DisableCapacities && valueCap != 0 {
if valueLen != 0 {
d.w.Write(spaceBytes)
}
d.w.Write(capEqualsBytes)
printInt(d.w, int64(valueCap), 10)
}
d.w.Write(closeParenBytes)
d.w.Write(spaceBytes)
}
// Call Stringer/error interfaces if they exist and the handle methods flag
// is enabled
if !d.cs.DisableMethods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(d.cs, d.w, v); handled {
return
}
}
}
switch kind {
case reflect.Invalid:
// Do nothing. We should never get here since invalid has already
// been handled above.
case reflect.Bool:
printBool(d.w, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
printInt(d.w, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
printUint(d.w, v.Uint(), 10)
case reflect.Float32:
printFloat(d.w, v.Float(), 32)
case reflect.Float64:
printFloat(d.w, v.Float(), 64)
case reflect.Complex64:
printComplex(d.w, v.Complex(), 32)
case reflect.Complex128:
printComplex(d.w, v.Complex(), 64)
case reflect.Slice:
if v.IsNil() {
d.w.Write(nilAngleBytes)
break
}
fallthrough
case reflect.Array:
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
d.dumpSlice(v)
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.String:
d.w.Write([]byte(strconv.Quote(v.String())))
case reflect.Interface:
// The only time we should get here is for nil interfaces due to
// unpackValue calls.
if v.IsNil() {
d.w.Write(nilAngleBytes)
}
case reflect.Ptr:
// Do nothing. We should never get here since pointers have already
// been handled above.
case reflect.Map:
// nil maps should be indicated as different than empty maps
if v.IsNil() {
d.w.Write(nilAngleBytes)
break
}
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
numEntries := v.Len()
keys := v.MapKeys()
if d.cs.SortKeys {
sortValues(keys, d.cs)
}
for i, key := range keys {
d.dump(d.unpackValue(key))
d.w.Write(colonSpaceBytes)
d.ignoreNextIndent = true
d.dump(d.unpackValue(v.MapIndex(key)))
if i < (numEntries - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.Struct:
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
vt := v.Type()
numFields := v.NumField()
for i := 0; i < numFields; i++ {
d.indent()
vtf := vt.Field(i)
d.w.Write([]byte(vtf.Name))
d.w.Write(colonSpaceBytes)
d.ignoreNextIndent = true
d.dump(d.unpackValue(v.Field(i)))
if i < (numFields - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.Uintptr:
printHexPtr(d.w, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
printHexPtr(d.w, v.Pointer())
// There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it in case any new
// types are added.
default:
if v.CanInterface() {
fmt.Fprintf(d.w, "%v", v.Interface())
} else {
fmt.Fprintf(d.w, "%v", v.String())
}
}
}
// fdump is a helper function to consolidate the logic from the various public
// methods which take varying writers and config states.
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
for _, arg := range a {
if arg == nil {
w.Write(interfaceBytes)
w.Write(spaceBytes)
w.Write(nilAngleBytes)
w.Write(newlineBytes)
continue
}
d := dumpState{w: w, cs: cs}
d.pointers = make(map[uintptr]int)
d.dump(reflect.ValueOf(arg))
d.w.Write(newlineBytes)
}
}
// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func Fdump(w io.Writer, a ...interface{}) {
fdump(&Config, w, a...)
}
// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func Sdump(a ...interface{}) string {
var buf bytes.Buffer
fdump(&Config, &buf, a...)
return buf.String()
}
/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output
The configuration options are controlled by an exported package global,
spew.Config. See ConfigState for options documentation.
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func Dump(a ...interface{}) {
fdump(&Config, os.Stdout, a...)
}

View File

@@ -1,419 +0,0 @@
/*
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"reflect"
"strconv"
"strings"
)
// supportedFlags is a list of all the character flags supported by fmt package.
const supportedFlags = "0-+# "
// formatState implements the fmt.Formatter interface and contains information
// about the state of a formatting operation. The NewFormatter function can
// be used to get a new Formatter which can be used directly as arguments
// in standard fmt package printing calls.
type formatState struct {
value interface{}
fs fmt.State
depth int
pointers map[uintptr]int
ignoreNextType bool
cs *ConfigState
}
// buildDefaultFormat recreates the original format string without precision
// and width information to pass in to fmt.Sprintf in the case of an
// unrecognized type. Unless new types are added to the language, this
// function won't ever be called.
func (f *formatState) buildDefaultFormat() (format string) {
buf := bytes.NewBuffer(percentBytes)
for _, flag := range supportedFlags {
if f.fs.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
buf.WriteRune('v')
format = buf.String()
return format
}
// constructOrigFormat recreates the original format string including precision
// and width information to pass along to the standard fmt package. This allows
// automatic deferral of all format strings this package doesn't support.
func (f *formatState) constructOrigFormat(verb rune) (format string) {
buf := bytes.NewBuffer(percentBytes)
for _, flag := range supportedFlags {
if f.fs.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
if width, ok := f.fs.Width(); ok {
buf.WriteString(strconv.Itoa(width))
}
if precision, ok := f.fs.Precision(); ok {
buf.Write(precisionBytes)
buf.WriteString(strconv.Itoa(precision))
}
buf.WriteRune(verb)
format = buf.String()
return format
}
// unpackValue returns values inside of non-nil interfaces when possible and
// ensures that types for values which have been unpacked from an interface
// are displayed when the show types flag is also set.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface {
f.ignoreNextType = false
if !v.IsNil() {
v = v.Elem()
}
}
return v
}
// formatPtr handles formatting of pointers by indirecting them as necessary.
func (f *formatState) formatPtr(v reflect.Value) {
// Display nil if top level pointer is nil.
showTypes := f.fs.Flag('#')
if v.IsNil() && (!showTypes || f.ignoreNextType) {
f.fs.Write(nilAngleBytes)
return
}
// Remove pointers at or below the current depth from map used to detect
// circular refs.
for k, depth := range f.pointers {
if depth >= f.depth {
delete(f.pointers, k)
}
}
// Keep list of all dereferenced pointers to possibly show later.
pointerChain := make([]uintptr, 0)
// Figure out how many levels of indirection there are by derferencing
// pointers and unpacking interfaces down the chain while detecting circular
// references.
nilFound := false
cycleFound := false
indirects := 0
ve := v
for ve.Kind() == reflect.Ptr {
if ve.IsNil() {
nilFound = true
break
}
indirects++
addr := ve.Pointer()
pointerChain = append(pointerChain, addr)
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
cycleFound = true
indirects--
break
}
f.pointers[addr] = f.depth
ve = ve.Elem()
if ve.Kind() == reflect.Interface {
if ve.IsNil() {
nilFound = true
break
}
ve = ve.Elem()
}
}
// Display type or indirection level depending on flags.
if showTypes && !f.ignoreNextType {
f.fs.Write(openParenBytes)
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
f.fs.Write([]byte(ve.Type().String()))
f.fs.Write(closeParenBytes)
} else {
if nilFound || cycleFound {
indirects += strings.Count(ve.Type().String(), "*")
}
f.fs.Write(openAngleBytes)
f.fs.Write([]byte(strings.Repeat("*", indirects)))
f.fs.Write(closeAngleBytes)
}
// Display pointer information depending on flags.
if f.fs.Flag('+') && (len(pointerChain) > 0) {
f.fs.Write(openParenBytes)
for i, addr := range pointerChain {
if i > 0 {
f.fs.Write(pointerChainBytes)
}
printHexPtr(f.fs, addr)
}
f.fs.Write(closeParenBytes)
}
// Display dereferenced value.
switch {
case nilFound:
f.fs.Write(nilAngleBytes)
case cycleFound:
f.fs.Write(circularShortBytes)
default:
f.ignoreNextType = true
f.format(ve)
}
}
// format is the main workhorse for providing the Formatter interface. It
// uses the passed reflect value to figure out what kind of object we are
// dealing with and formats it appropriately. It is a recursive function,
// however circular data structures are detected and handled properly.
func (f *formatState) format(v reflect.Value) {
// Handle invalid reflect values immediately.
kind := v.Kind()
if kind == reflect.Invalid {
f.fs.Write(invalidAngleBytes)
return
}
// Handle pointers specially.
if kind == reflect.Ptr {
f.formatPtr(v)
return
}
// Print type information unless already handled elsewhere.
if !f.ignoreNextType && f.fs.Flag('#') {
f.fs.Write(openParenBytes)
f.fs.Write([]byte(v.Type().String()))
f.fs.Write(closeParenBytes)
}
f.ignoreNextType = false
// Call Stringer/error interfaces if they exist and the handle methods
// flag is enabled.
if !f.cs.DisableMethods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(f.cs, f.fs, v); handled {
return
}
}
}
switch kind {
case reflect.Invalid:
// Do nothing. We should never get here since invalid has already
// been handled above.
case reflect.Bool:
printBool(f.fs, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
printInt(f.fs, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
printUint(f.fs, v.Uint(), 10)
case reflect.Float32:
printFloat(f.fs, v.Float(), 32)
case reflect.Float64:
printFloat(f.fs, v.Float(), 64)
case reflect.Complex64:
printComplex(f.fs, v.Complex(), 32)
case reflect.Complex128:
printComplex(f.fs, v.Complex(), 64)
case reflect.Slice:
if v.IsNil() {
f.fs.Write(nilAngleBytes)
break
}
fallthrough
case reflect.Array:
f.fs.Write(openBracketBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
numEntries := v.Len()
for i := 0; i < numEntries; i++ {
if i > 0 {
f.fs.Write(spaceBytes)
}
f.ignoreNextType = true
f.format(f.unpackValue(v.Index(i)))
}
}
f.depth--
f.fs.Write(closeBracketBytes)
case reflect.String:
f.fs.Write([]byte(v.String()))
case reflect.Interface:
// The only time we should get here is for nil interfaces due to
// unpackValue calls.
if v.IsNil() {
f.fs.Write(nilAngleBytes)
}
case reflect.Ptr:
// Do nothing. We should never get here since pointers have already
// been handled above.
case reflect.Map:
// nil maps should be indicated as different than empty maps
if v.IsNil() {
f.fs.Write(nilAngleBytes)
break
}
f.fs.Write(openMapBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
keys := v.MapKeys()
if f.cs.SortKeys {
sortValues(keys, f.cs)
}
for i, key := range keys {
if i > 0 {
f.fs.Write(spaceBytes)
}
f.ignoreNextType = true
f.format(f.unpackValue(key))
f.fs.Write(colonBytes)
f.ignoreNextType = true
f.format(f.unpackValue(v.MapIndex(key)))
}
}
f.depth--
f.fs.Write(closeMapBytes)
case reflect.Struct:
numFields := v.NumField()
f.fs.Write(openBraceBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
vt := v.Type()
for i := 0; i < numFields; i++ {
if i > 0 {
f.fs.Write(spaceBytes)
}
vtf := vt.Field(i)
if f.fs.Flag('+') || f.fs.Flag('#') {
f.fs.Write([]byte(vtf.Name))
f.fs.Write(colonBytes)
}
f.format(f.unpackValue(v.Field(i)))
}
}
f.depth--
f.fs.Write(closeBraceBytes)
case reflect.Uintptr:
printHexPtr(f.fs, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
printHexPtr(f.fs, v.Pointer())
// There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it if any get added.
default:
format := f.buildDefaultFormat()
if v.CanInterface() {
fmt.Fprintf(f.fs, format, v.Interface())
} else {
fmt.Fprintf(f.fs, format, v.String())
}
}
}
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
// details.
func (f *formatState) Format(fs fmt.State, verb rune) {
f.fs = fs
// Use standard formatting for verbs that are not v.
if verb != 'v' {
format := f.constructOrigFormat(verb)
fmt.Fprintf(fs, format, f.value)
return
}
if f.value == nil {
if fs.Flag('#') {
fs.Write(interfaceBytes)
}
fs.Write(nilAngleBytes)
return
}
f.format(reflect.ValueOf(f.value))
}
// newFormatter is a helper function to consolidate the logic from the various
// public methods which take varying config states.
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
fs := &formatState{value: v, cs: cs}
fs.pointers = make(map[uintptr]int)
return fs
}
/*
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
interface. As a result, it integrates cleanly with standard fmt package
printing functions. The formatter is useful for inline printing of smaller data
types similar to the standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Typically this function shouldn't be called directly. It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
Printf, Println, or Fprintf.
*/
func NewFormatter(v interface{}) fmt.Formatter {
return newFormatter(&Config, v)
}

View File

@@ -1,148 +0,0 @@
/*
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"fmt"
"io"
)
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the formatted string as a value that satisfies error. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Errorf(format string, a ...interface{}) (err error) {
return fmt.Errorf(format, convertArgs(a)...)
}
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprint(w, convertArgs(a)...)
}
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
return fmt.Fprintf(w, format, convertArgs(a)...)
}
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
// passed with a default Formatter interface returned by NewFormatter. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprintln(w, convertArgs(a)...)
}
// Print is a wrapper for fmt.Print that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
func Print(a ...interface{}) (n int, err error) {
return fmt.Print(convertArgs(a)...)
}
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Printf(format string, a ...interface{}) (n int, err error) {
return fmt.Printf(format, convertArgs(a)...)
}
// Println is a wrapper for fmt.Println that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
func Println(a ...interface{}) (n int, err error) {
return fmt.Println(convertArgs(a)...)
}
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
func Sprint(a ...interface{}) string {
return fmt.Sprint(convertArgs(a)...)
}
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Sprintf(format string, a ...interface{}) string {
return fmt.Sprintf(format, convertArgs(a)...)
}
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
// were passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
func Sprintln(a ...interface{}) string {
return fmt.Sprintln(convertArgs(a)...)
}
// convertArgs accepts a slice of arguments and returns a slice of the same
// length with each argument converted to a default spew Formatter interface.
func convertArgs(args []interface{}) (formatters []interface{}) {
formatters = make([]interface{}, len(args))
for index, arg := range args {
formatters[index] = NewFormatter(arg)
}
return formatters
}

View File

@@ -1,14 +0,0 @@
freebsd_task:
name: 'FreeBSD'
freebsd_instance:
image_family: freebsd-14-1
install_script:
- pkg update -f
- pkg install -y go
test_script:
# run tests as user "cirrus" instead of root
- pw useradd cirrus -m
- chown -R cirrus:cirrus .
- FSNOTIFY_BUFFER=4096 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./...
- sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./...
- FSNOTIFY_DEBUG=1 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race -v ./...

View File

@@ -1,10 +0,0 @@
# go test -c output
*.test
*.test.exe
# Output of go build ./cmd/fsnotify
/fsnotify
/fsnotify.exe
/test/kqueue
/test/a.out

Some files were not shown because too many files have changed in this diff Show More