Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lnd: Add CORS support to the WalletUnlocker proxy #4551

Merged
merged 1 commit into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lnd.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr,
return nil, err
}

srv := &http.Server{Handler: mux}
srv := &http.Server{Handler: allowCORS(mux, cfg.RestCORS)}

for _, restEndpoint := range restEndpoints {
lis, err := lncfg.TLSListenOnAddress(restEndpoint, tlsConf)
Expand Down
15 changes: 8 additions & 7 deletions rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,6 @@ func (r *rpcServer) Start() error {
// Wrap the default grpc-gateway handler with the WebSocket handler.
restHandler := lnrpc.NewWebSocketProxy(restMux, rpcsLog)

// Set the CORS headers if configured. This wraps the HTTP handler with
// another handler.
if len(r.cfg.RestCORS) > 0 {
restHandler = allowCORS(restHandler, r.cfg.RestCORS)
}

// With our custom REST proxy mux created, register our main RPC and
// give all subservers a chance to register as well.
err := lnrpc.RegisterLightningHandlerFromEndpoint(
Expand Down Expand Up @@ -871,7 +865,8 @@ func (r *rpcServer) Start() error {
// through the following chain:
// req ---> CORS handler --> WS proxy --->
// REST proxy --> gRPC endpoint
err := http.Serve(lis, restHandler)
corsHandler := allowCORS(restHandler, r.cfg.RestCORS)
err := http.Serve(lis, corsHandler)
if err != nil && !lnrpc.IsClosedConnError(err) {
rpcsLog.Error(err)
}
Expand Down Expand Up @@ -944,6 +939,12 @@ func allowCORS(handler http.Handler, origins []string) http.Handler {
allowMethods := "Access-Control-Allow-Methods"
allowOrigin := "Access-Control-Allow-Origin"

// If the user didn't supply any origins that means CORS is disabled
// and we should return the original handler.
if len(origins) == 0 {
return handler
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")

Expand Down