Files
Patrick Nagurny e2dd29259f initial commit
2018-10-19 15:31:41 -04:00

149 lines
3.6 KiB
Go

package rest
import (
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
// ErrJsonPayloadEmpty is the sentinel error DecodeJsonPayload returns when the
// request carries no JSON bytes to decode. Compare against it with errors.Is.
var ErrJsonPayloadEmpty = errors.New("JSON payload is empty")
// Request embeds *http.Request, so every field and method of the standard
// request is promoted onto it, and adds routing and middleware state plus the
// helper methods defined in this file.
type Request struct {
	*http.Request

	// PathParams maps the names of parameters matched in the URL path to the
	// values captured for them. Read it through PathParam.
	PathParams map[string]string

	// Env is an environment shared along the handler chain, used by
	// middlewares to communicate with each other.
	Env map[string]interface{}
}
// PathParam looks up a parameter captured from the URL path by name.
//
// It returns the matched value, or "" when no parameter with that name was
// matched (a nil PathParams map behaves the same as an empty one).
func (r *Request) PathParam(name string) string {
	if value, found := r.PathParams[name]; found {
		return value
	}
	return ""
}
// DecodeJsonPayload reads the entire request body, closes it, and decodes the
// JSON it contains into v using json.Unmarshal.
//
// It returns ErrJsonPayloadEmpty when the request has no body or the body is
// empty. It returns the read error if the body cannot be read. Otherwise it
// returns the result of json.Unmarshal (nil on success).
func (r *Request) DecodeJsonPayload(v interface{}) error {
	// Client-constructed requests (e.g. http.NewRequest with a nil body) may
	// have a nil Body; reading from it would panic, so treat it as empty.
	if r.Body == nil {
		return ErrJsonPayloadEmpty
	}

	content, err := ioutil.ReadAll(r.Body)
	// The body has been consumed (or failed); release it before inspecting the
	// result. The Close error is ignored, matching the previous behavior.
	r.Body.Close()
	if err != nil {
		return err
	}
	if len(content) == 0 {
		return ErrJsonPayloadEmpty
	}
	return json.Unmarshal(content, v)
}
// BaseUrl returns a freshly allocated URL holding only the scheme and host of
// the request, with one trailing "/" (if any) removed from the host.
//
// The scheme is taken from r.URL.Scheme and defaults to "http" when empty.
// When that yields "http" but r.TLS is non-nil, "https" is used instead. The
// standard server may report an empty or "http" scheme even for TLS requests.
func (r *Request) BaseUrl() *url.URL {
	scheme := r.URL.Scheme
	switch {
	case (scheme == "" || scheme == "http") && r.TLS != nil:
		scheme = "https"
	case scheme == "":
		scheme = "http"
	}

	return &url.URL{
		Scheme: scheme,
		Host:   strings.TrimSuffix(r.Host, "/"),
	}
}
// UrlFor builds an absolute URL on this request's scheme and host (see
// BaseUrl) with Path set to path.
//
// When queryParams is non-nil, RawQuery is set to its URL-encoded form: keys
// are sorted, and each key's values appear in slice order. When queryParams is
// nil, RawQuery is left empty.
func (r *Request) UrlFor(path string, queryParams map[string][]string) *url.URL {
	u := r.BaseUrl()
	u.Path = path
	if queryParams == nil {
		return u
	}
	// url.Values has map[string][]string as its underlying type, so the
	// parameters can be encoded directly without copying them first.
	u.RawQuery = url.Values(queryParams).Encode()
	return u
}
// CorsInfo contains the CORS request info derived from a rest.Request by
// GetCorsInfo.
type CorsInfo struct {
	// IsCors reports whether the request has an Origin header that is either
	// the literal "null" or a parseable URL whose host differs from the
	// request's Host.
	IsCors bool
	// IsPreflight reports whether the request is CORS, uses the OPTIONS method,
	// and carries a non-empty Access-Control-Request-Method header.
	IsPreflight bool
	// Origin is the raw value of the Origin header ("" when absent).
	Origin string
	// OriginUrl is Origin parsed with url.ParseRequestURI. It is nil when
	// Origin is empty, is "null", or fails to parse.
	OriginUrl *url.URL

	// The header value is converted to uppercase to avoid common mistakes.
	AccessControlRequestMethod string

	// The header values are split on commas, trimmed, and normalized with
	// http.CanonicalHeaderKey. Never nil; empty when no names were sent.
	AccessControlRequestHeaders []string
}
// GetCorsInfo derives CorsInfo from the request's Origin,
// Access-Control-Request-Method and Access-Control-Request-Headers headers.
//
// A request is considered CORS when its Origin header is the literal "null",
// or is a URL accepted by url.ParseRequestURI whose host differs from r.Host.
// It is a preflight when it is CORS, uses the OPTIONS method, and names a
// method in Access-Control-Request-Method.
//
// Access-Control-Request-Headers values are split on commas, trimmed and
// canonicalized; empty entries (blank values, trailing or doubled commas) are
// dropped rather than reported as an "" header name.
func (r *Request) GetCorsInfo() *CorsInfo {
	origin := r.Header.Get("Origin")

	var originUrl *url.URL
	var isCors bool
	switch origin {
	case "":
		// No Origin header: not treated as CORS.
	case "null":
		// The literal "null" origin counts as CORS but has no URL to parse.
		isCors = true
	default:
		// An unparseable origin leaves originUrl nil and isCors false.
		if parsed, err := url.ParseRequestURI(origin); err == nil {
			originUrl = parsed
			isCors = r.Host != parsed.Host
		}
	}

	reqMethod := r.Header.Get("Access-Control-Request-Method")

	// Kept non-nil so callers always receive an empty (not nil) slice.
	reqHeaders := []string{}
	rawReqHeaders := r.Header[http.CanonicalHeaderKey("Access-Control-Request-Headers")]
	for _, rawReqHeader := range rawReqHeaders {
		// net/http does not split comma-delimited header values for us.
		for _, name := range strings.Split(rawReqHeader, ",") {
			name = strings.TrimSpace(name)
			if name == "" {
				continue
			}
			reqHeaders = append(reqHeaders, http.CanonicalHeaderKey(name))
		}
	}

	isPreflight := isCors && r.Method == http.MethodOptions && reqMethod != ""

	return &CorsInfo{
		IsCors:                      isCors,
		IsPreflight:                 isPreflight,
		Origin:                      origin,
		OriginUrl:                   originUrl,
		AccessControlRequestMethod:  strings.ToUpper(reqMethod),
		AccessControlRequestHeaders: reqHeaders,
	}
}