Skip to content

Commit

Permalink
gateway: extract CORS to headers middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias authored Jan 24, 2024
1 parent 4c3a1f2 commit 3d57bce
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 127 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ The following emojis are used to highlight certain changes:

- `blockservice` now has `ContextWithSession` and `EmbedSessionInContext` functions, which allows to embed a session in a context. Future calls to `BlockGetter.GetBlock`, `BlockGetter.GetBlocks` and `NewSession` will use the session in the context.
- `blockservice.NewWritethrough` deprecated function has been removed, instead you can do `blockservice.New(..., ..., WriteThrough())` like previously.
- `gateway`: a new header configuration middleware has been added to replace the existing header configuration, which can be used more generically.

### Changed

### Removed

- 🛠 `gateway`: the header configuration `Config.Headers` and `AddAccessControlHeaders` has been replaced by the new middleware provided by `NewHeaders`.

### Security

## [v0.17.0]
Expand Down
11 changes: 4 additions & 7 deletions examples/gateway/common/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ import (

func NewHandler(gwAPI gateway.IPFSBackend) http.Handler {
conf := gateway.Config{
// Initialize the headers. For this example, we do not add any special headers,
// only the required ones via gateway.AddAccessControlHeaders.
Headers: map[string][]string{},

// If you set DNSLink to point at the CID from CAR, you can load it!
NoDNSLink: false,

Expand Down Expand Up @@ -58,9 +54,6 @@ func NewHandler(gwAPI gateway.IPFSBackend) http.Handler {
},
}

// Add required access control headers to the configuration.
gateway.AddAccessControlHeaders(conf.Headers)

// Creates a mux to serve the gateway paths. This is not strictly necessary
// and gwHandler could be used directly. However, on the next step we also want
// to add prometheus metrics, hence needing the mux.
Expand All @@ -86,6 +79,10 @@ func NewHandler(gwAPI gateway.IPFSBackend) http.Handler {
// http.ServeMux which does not support CONNECT by default.
handler = withConnect(handler)

// Add headers middleware that applies any headers we define to all requests
// as well as a default CORS configuration.
handler = gateway.NewHeaders(nil).ApplyCors().Wrap(handler)

// Finally, wrap with the otelhttp handler. This will allow the tracing system
// to work and for correct propagation of tracing headers. This step is optional
// and only required if you want to use tracing. Note that OTel must be correctly
Expand Down
10 changes: 3 additions & 7 deletions gateway/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@ This example shows how you can start your own gateway, assuming you have an `IPF
implementation.

```go
// Initialize your headers and apply the default headers.
headers := map[string][]string{}
gateway.AddAccessControlHeaders(headers)

conf := gateway.Config{
Headers: headers,
}
conf := gateway.Config{}

// Initialize an IPFSBackend interface for both an online and offline versions.
// The offline version should not make any network request for missing content.
Expand All @@ -29,9 +23,11 @@ ipfsBackend := ...
// Create http mux and setup path gateway handler.
mux := http.NewServeMux()
handler := gateway.NewHandler(conf, ipfsBackend)
handler = gateway.NewHeaders(nil).ApplyCors().Wrap(handler)
mux.Handle("/ipfs/", handler)
mux.Handle("/ipns/", handler)


// Start the server on :8080 and voilá! You have a basic IPFS gateway running
// in http://localhost:8080.
_ = http.ListenAndServe(":8080", mux)
Expand Down
4 changes: 2 additions & 2 deletions gateway/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestWebError(t *testing.T) {
t.Parallel()

// Create a handler to be able to test `webError`.
config := &Config{Headers: map[string][]string{}}
config := &Config{}

t.Run("429 Too Many Requests", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestWebError(t *testing.T) {
t.Run("Error is sent as plain text when 'Accept' header contains 'text/html' and config.DisableHTMLErrors is true", func(t *testing.T) {
t.Parallel()

config := &Config{Headers: map[string][]string{}, DisableHTMLErrors: true}
config := &Config{DisableHTMLErrors: true}
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/blah", nil)
r.Header.Set("Accept", "something/else, text/html")
Expand Down
80 changes: 0 additions & 80 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"errors"
"fmt"
"io"
"net/http"
"sort"
"strconv"
"strings"
"time"
Expand All @@ -20,11 +18,6 @@ import (

// Config is the configuration used when creating a new gateway handler.
type Config struct {
// Headers is a map containing all the headers that should be sent by default
// in all requests. You can define custom headers, as well as add the recommended
// headers via AddAccessControlHeaders.
Headers map[string][]string

// DeserializedResponses configures this gateway to support returning data
// in deserialized format. By default, the gateway will only support
// trustless, verifiable [application/vnd.ipld.raw] and
Expand Down Expand Up @@ -394,79 +387,6 @@ type WithContextHint interface {
WrapContextForRequest(context.Context) context.Context
}

// cleanHeaderSet is an helper function that cleans a set of headers by
// (1) canonicalizing, (2) de-duplicating and (3) sorting.
func cleanHeaderSet(headers []string) []string {
// Deduplicate and canonicalize.
m := make(map[string]struct{}, len(headers))
for _, h := range headers {
m[http.CanonicalHeaderKey(h)] = struct{}{}
}
result := make([]string, 0, len(m))
for k := range m {
result = append(result, k)
}

// Sort
sort.Strings(result)
return result
}

// AddAccessControlHeaders ensures safe default HTTP headers are used for
// controlling cross-origin requests. This function adds several values to the
// [Access-Control-Allow-Headers] and [Access-Control-Expose-Headers] entries
// to be exposed on GET and OPTIONS responses, including [CORS Preflight].
//
// If the Access-Control-Allow-Origin entry is missing, a default value of '*' is
// added, indicating that browsers should allow requesting code from any
// origin to access the resource.
//
// If the Access-Control-Allow-Methods entry is missing a value, 'GET, HEAD,
// OPTIONS' is added, indicating that browsers may use them when issuing cross
// origin requests.
//
// [Access-Control-Allow-Headers]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
// [Access-Control-Expose-Headers]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
// [CORS Preflight]: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
func AddAccessControlHeaders(headers map[string][]string) {
// Hard-coded headers.
const ACAHeadersName = "Access-Control-Allow-Headers"
const ACEHeadersName = "Access-Control-Expose-Headers"
const ACAOriginName = "Access-Control-Allow-Origin"
const ACAMethodsName = "Access-Control-Allow-Methods"

if _, ok := headers[ACAOriginName]; !ok {
// Default to *all*
headers[ACAOriginName] = []string{"*"}
}
if _, ok := headers[ACAMethodsName]; !ok {
// Default to GET, HEAD, OPTIONS
headers[ACAMethodsName] = []string{
http.MethodGet,
http.MethodHead,
http.MethodOptions,
}
}

headers[ACAHeadersName] = cleanHeaderSet(
append([]string{
"Content-Type",
"User-Agent",
"Range",
"X-Requested-With",
}, headers[ACAHeadersName]...))

headers[ACEHeadersName] = cleanHeaderSet(
append([]string{
"Content-Length",
"Content-Range",
"X-Chunked-Output",
"X-Stream-Output",
"X-Ipfs-Path",
"X-Ipfs-Roots",
}, headers[ACEHeadersName]...))
}

// RequestContextKey is a type representing a [context.Context] value key.
type RequestContextKey string

Expand Down
8 changes: 2 additions & 6 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ func TestHeaders(t *testing.T) {
headers := map[string][]string{}
headers[headerACAO] = []string{expectedACAO}

ts := newTestServerWithConfig(t, backend, Config{
Headers: headers,
ts := newTestServerWithConfigAndHeaders(t, backend, Config{
PublicGateways: map[string]*PublicGateway{
"subgw.example.com": {
Paths: []string{"/ipfs", "/ipns"},
Expand All @@ -362,7 +361,7 @@ func TestHeaders(t *testing.T) {
},
},
DeserializedResponses: true,
})
}, headers)
t.Logf("test server url: %s", ts.URL)

testCORSPreflightRequest := func(t *testing.T, path, hostHeader string, requestOriginHeader string, code int) {
Expand Down Expand Up @@ -532,7 +531,6 @@ func TestRedirects(t *testing.T) {
backend.namesys["/ipns/example.com"] = newMockNamesysItem(path.FromCid(root), 0)

ts := newTestServerWithConfig(t, backend, Config{
Headers: map[string][]string{},
NoDNSLink: false,
PublicGateways: map[string]*PublicGateway{
"example.com": {
Expand Down Expand Up @@ -590,7 +588,6 @@ func TestDeserializedResponses(t *testing.T) {
backend, root := newMockBackend(t, "fixtures.car")

ts := newTestServerWithConfig(t, backend, Config{
Headers: map[string][]string{},
NoDNSLink: false,
PublicGateways: map[string]*PublicGateway{
"trustless.com": {
Expand Down Expand Up @@ -670,7 +667,6 @@ func TestDeserializedResponses(t *testing.T) {
backend.namesys["/ipns/trusted.com"] = newMockNamesysItem(path.FromCid(root), 0)

ts := newTestServerWithConfig(t, backend, Config{
Headers: map[string][]string{},
NoDNSLink: false,
PublicGateways: map[string]*PublicGateway{
"trustless.com": {
Expand Down
8 changes: 0 additions & 8 deletions gateway/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ func (i *handler) optionsHandler(w http.ResponseWriter, r *http.Request) {
// OPTIONS is a noop request that is used by the browsers to check if server accepts
// cross-site XMLHttpRequest, which is indicated by the presence of CORS headers:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Access_control_CORS#Preflighted_requests
addCustomHeaders(w, i.config.Headers) // return all custom headers (including CORS ones, if set)
}

// addAllowHeader sets Allow header with supported HTTP methods
Expand Down Expand Up @@ -264,7 +263,6 @@ func (i *handler) getOrHeadHandler(w http.ResponseWriter, r *http.Request) {
trace.SpanFromContext(r.Context()).SetAttributes(attribute.String("ResponseFormat", responseFormat))
i.requestTypeMetric.WithLabelValues(contentPath.Namespace(), responseFormat).Inc()

addCustomHeaders(w, i.config.Headers) // ok, _now_ write user's headers.
w.Header().Set("X-Ipfs-Path", contentPath.String())

// Fail fast if unsupported request type was sent to a Trustless Gateway.
Expand Down Expand Up @@ -340,12 +338,6 @@ func (i *handler) getOrHeadHandler(w http.ResponseWriter, r *http.Request) {
}
}

func addCustomHeaders(w http.ResponseWriter, headers map[string][]string) {
for k, v := range headers {
w.Header()[http.CanonicalHeaderKey(k)] = v
}
}

// isDeserializedResponsePossible returns true if deserialized responses
// are allowed on the specified hostname, or globally. Host-specific rules
// override global config.
Expand Down
1 change: 0 additions & 1 deletion gateway/handler_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ func TestDagJsonCborPreview(t *testing.T) {
backend, root := newMockBackend(t, "fixtures.car")

ts := newTestServerWithConfig(t, backend, Config{
Headers: map[string][]string{},
NoDNSLink: false,
PublicGateways: map[string]*PublicGateway{
"example.com": {
Expand Down
112 changes: 112 additions & 0 deletions gateway/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package gateway

import (
"net/http"
"sort"
)

// Headers is an HTTP middleware that sets the configured headers in all requests.
type Headers struct {
headers map[string][]string
}

// NewHeaders creates a new [Headers] middleware that applies the given headers
// to all requests. If you call [Headers.ApplyCors], the default CORS configuration
// will also be applied, if any of the CORS headers is missing.
func NewHeaders(headers map[string][]string) *Headers {
h := &Headers{
headers: map[string][]string{},
}

for k, v := range headers {
h.headers[http.CanonicalHeaderKey(k)] = v
}

return h
}

// ApplyCors applies safe default HTTP headers for controlling cross-origin
// requests. This function adds several values to the [Access-Control-Allow-Headers]
// and [Access-Control-Expose-Headers] entries to be exposed on GET and OPTIONS
// responses, including [CORS Preflight].
//
// If the Access-Control-Allow-Origin entry is missing, a default value of '*' is
// added, indicating that browsers should allow requesting code from any
// origin to access the resource.
//
// If the Access-Control-Allow-Methods entry is missing a value, 'GET, HEAD,
// OPTIONS' is added, indicating that browsers may use them when issuing cross
// origin requests.
//
// [Access-Control-Allow-Headers]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
// [Access-Control-Expose-Headers]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
// [CORS Preflight]: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
func (h *Headers) ApplyCors() *Headers {
// Hard-coded headers.
const ACAHeadersName = "Access-Control-Allow-Headers"
const ACEHeadersName = "Access-Control-Expose-Headers"
const ACAOriginName = "Access-Control-Allow-Origin"
const ACAMethodsName = "Access-Control-Allow-Methods"

if _, ok := h.headers[ACAOriginName]; !ok {
// Default to *all*
h.headers[ACAOriginName] = []string{"*"}
}
if _, ok := h.headers[ACAMethodsName]; !ok {
// Default to GET, HEAD, OPTIONS
h.headers[ACAMethodsName] = []string{
http.MethodGet,
http.MethodHead,
http.MethodOptions,
}
}

h.headers[ACAHeadersName] = cleanHeaderSet(
append([]string{
"Content-Type",
"User-Agent",
"Range",
"X-Requested-With",
}, h.headers[ACAHeadersName]...))

h.headers[ACEHeadersName] = cleanHeaderSet(
append([]string{
"Content-Length",
"Content-Range",
"X-Chunked-Output",
"X-Stream-Output",
"X-Ipfs-Path",
"X-Ipfs-Roots",
}, h.headers[ACEHeadersName]...))

return h
}

// Wrap wraps the given [http.Handler] with the headers middleware.
func (h *Headers) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for k, v := range h.headers {
w.Header()[k] = v
}

next.ServeHTTP(w, r)
})
}

// cleanHeaderSet is an helper function that cleans a set of headers by
// (1) canonicalizing, (2) de-duplicating and (3) sorting.
func cleanHeaderSet(headers []string) []string {
// Deduplicate and canonicalize.
m := make(map[string]struct{}, len(headers))
for _, h := range headers {
m[http.CanonicalHeaderKey(h)] = struct{}{}
}
result := make([]string, 0, len(m))
for k := range m {
result = append(result, k)
}

// Sort
sort.Strings(result)
return result
}
Loading

0 comments on commit 3d57bce

Please sign in to comment.