Skip to content

Commit

Permalink
add test and fix bugs for websocket (#379)
Browse files Browse the repository at this point in the history
* add test for websocket

* fix websocket inherit panic

* add comment for websocket proxy

* add more test

* fix proxy close bug

* fix tls bug

* add more test

* update github action

* update websocket test

* copy more headers for websocket proxy

* add more test

* add more test

* update comments
  • Loading branch information
suchen-sci authored Nov 23, 2021
1 parent d21c8ec commit 0e61435
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 25 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ require (
github.com/spf13/cobra v1.2.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.8.1
github.com/stretchr/testify v1.7.0
github.com/tcnksm/go-httpstat v0.2.1-0.20191008022543-e866bb274419
github.com/tidwall/gjson v1.11.0
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce
Expand Down
84 changes: 59 additions & 25 deletions pkg/object/websocketserver/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@ import (
"github.com/megaease/easegress/pkg/supervisor"
)

const (
xForwardedFor = "X-Forwarded-For"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
)

var (
// headersToSkip are gorilla library's request headers, our websocket proxy should not set.
headersToSkip = map[string]struct{}{
"Upgrade": {},
"Connection": {},
"Sec-Websocket-Key": {},
"Sec-Websocket-Version": {},
"Sec-Websocket-Extensions": {},
"Sec-Websocket-Protocol": {},
}
)

var (
// defaultUpgrader specifies the parameters for upgrading an HTTP
// connection to a WebSocket connection.
Expand Down Expand Up @@ -149,57 +167,73 @@ func (p *Proxy) run() {
p.dialer = dialer
p.upgrader = defaultUpgrader

http.HandleFunc("/", p.handle)
mux := http.NewServeMux()
mux.HandleFunc("/", p.handle)
addr := fmt.Sprintf(":%d", spec.Port)
svr := &http.Server{
Addr: addr,
Handler: nil,
}

p.server.Addr = addr
p.server.Handler = mux

if spec.HTTPS {
tlsConfig, err := spec.tlsConfig()
if err != nil {
logger.Errorf("%s gen websocketserver's httpserver tlsConfig: %#v, failed: %v",
p.superSpec.Name(), spec, err)
}
svr.TLSConfig = tlsConfig
p.server.TLSConfig = tlsConfig
}

if err := svr.ListenAndServe(); err != nil {
logger.Errorf("%s websocketserver ListenAndServe failed: %v", p.superSpec.Name(), err)
if p.server.TLSConfig != nil {
if err := p.server.ListenAndServeTLS("", ""); err != nil {
logger.Errorf("%s websocketserver ListenAndServeTLS failed: %v", p.superSpec.Name(), err)
}
} else {
if err := p.server.ListenAndServe(); err != nil {
logger.Errorf("%s websocketserver ListenAndServe failed: %v", p.superSpec.Name(), err)
}
}
}

// copyHeader copies headers from the incoming request to the dialer and forward them to
// the destination.
func (p *Proxy) copyHeader(req *http.Request) http.Header {
// Based on https://docs.oracle.com/en-us/iaas/Content/Balance/Reference/httpheaders.htm
// For load balancer, we add following key-value pairs to headers
// X-Forwarded-For: <original_client>, <proxy1>, <proxy2>
// X-Forwarded-Host: www.example.com:8080
// X-Forwarded-Proto: https

// New client connection is created using [Gorilla websocket library](https://github.com/gorilla/websocket), which takes care of some of the headers.
// Let's copy copy all headers from the incoming request, except the ones gorilla will set.

requestHeader := http.Header{}
if origin := req.Header.Get("Origin"); origin != "" {
requestHeader.Add("Origin", origin)
}
for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
requestHeader.Add("Sec-WebSocket-Protocol", prot)
}
for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] {
requestHeader.Add("Cookie", cookie)
}
if req.Host != "" {
requestHeader.Set("Host", req.Host)
for k, values := range req.Header {
if _, ok := headersToSkip[k]; ok {
continue
}
for _, v := range values {
requestHeader.Add(k, v)
}
}

xff := requestHeader.Get(xForwardedFor)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
if xff == "" {
requestHeader.Set(xForwardedFor, clientIP)
} else {
requestHeader.Set(xForwardedFor, fmt.Sprintf("%s, %s", xff, clientIP))
}
requestHeader.Set("X-Forwarded-For", clientIP)
}

requestHeader.Set("X-Forwarded-Proto", "http")
if req.TLS != nil {
requestHeader.Set("X-Forwarded-Proto", "https")
xfh := requestHeader.Get(xForwardedHost)
if xfh == "" && req.Host != "" {
requestHeader.Set(xForwardedHost, req.Host)
}

requestHeader.Set(xForwardedProto, "http")
if req.TLS != nil {
requestHeader.Set(xForwardedProto, "https")
}
return requestHeader
}

Expand Down
77 changes: 77 additions & 0 deletions pkg/object/websocketserver/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2017, MegaEase
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package websocketserver

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestProxyCopyHeader(t *testing.T) {
assert := assert.New(t)
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
require.Nil(t, err)
req.Header.Add("Origin", "origin")
req.Header.Add("Sec-WebSocket-Protocol", "protocol")
req.Header.Add("Cookie", "cookie1=1")
req.RemoteAddr = fmt.Sprintf("%s:%d", req.Host, 8888)

p := &Proxy{}
header := p.copyHeader(req)
assert.Equal("origin", header.Get("Origin"))
assert.Equal("", header.Get("Sec-WebSocket-Protocol"))
assert.Equal("cookie1=1", header.Get("Cookie"))
assert.Equal("127.0.0.1", header.Get(xForwardedFor))
assert.Equal("127.0.0.1", header.Get(xForwardedHost))
assert.Equal("http", header.Get(xForwardedProto))
fmt.Printf("header: %v\n", header)
}

func TestProxyUpgradeRspHeader(t *testing.T) {
assert := assert.New(t)
resp := &http.Response{}
resp.Header = make(http.Header)
resp.Header.Add("Sec-Websocket-Protocol", "protocol")
resp.Header.Add("Set-Cookie", "cookie=1")
resp.Header.Add("Sec-Websocket-Extensions", "extensions")

p := &Proxy{}
newHeader := p.upgradeRspHeader(resp)
assert.Equal("protocol", newHeader.Get("Sec-Websocket-Protocol"))
assert.Equal("cookie=1", newHeader.Get("Set-Cookie"))
// only copy protocol and cookie
assert.Equal("", newHeader.Get("Sec-Websocket-Extensions"))
}

func TestCopyResponse(t *testing.T) {
testSrv := getTestServer(t, "127.0.0.1:8000")
defer testSrv.Close()

req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/start", nil)
require.Nil(t, err)
resp, err := http.DefaultClient.Do(req)
require.Nil(t, err)
copyResp := httptest.NewRecorder()
copyResponse(copyResp, resp)
assert.Equal(t, copyResp.Header(), resp.Header)
}
128 changes: 128 additions & 0 deletions pkg/object/websocketserver/spec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright (c) 2017, MegaEase
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package websocketserver

import (
"encoding/base64"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSpecValidate(t *testing.T) {
tests := []struct {
spec *Spec
valid bool
}{
{
spec: &Spec{
Port: 10081,
HTTPS: false,
Backend: "127.0.0.1:8888",
},
valid: false,
},
{
spec: &Spec{
Port: 10081,
HTTPS: false,
Backend: "http://127.0.0.1:8888",
},
valid: false,
},
{
spec: &Spec{
Port: 10081,
HTTPS: true,
Backend: "ws://127.0.0.1:8888",
},
valid: false,
},
{
spec: &Spec{
Port: 10081,
HTTPS: false,
Backend: "wss://127.0.0.1:8888",
},
valid: false,
},
}
for _, testCase := range tests {
err := testCase.spec.Validate()
if testCase.valid {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}

// this certPem and keyPem come from golang crypto/tls/testdata
// with original name: example-cert.pem and example-key.pem
const certPem = `
-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----
`
const keyPem = `
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----
`

func TestSpecTLS(t *testing.T) {
cert := base64.StdEncoding.EncodeToString([]byte(certPem))
key := base64.StdEncoding.EncodeToString([]byte(keyPem))
spec := &Spec{
CertBase64: cert,
KeyBase64: key,
WssCertBase64: cert,
WssKeyBase64: key,
}
_, err := spec.wssTLSConfig()
assert.Nil(t, err)
_, err = spec.tlsConfig()
assert.Nil(t, err)

spec = &Spec{}
_, err = spec.wssTLSConfig()
assert.NotNil(t, err)
_, err = spec.tlsConfig()
assert.NotNil(t, err)

spec = &Spec{
CertBase64: "cert",
KeyBase64: "key",
WssCertBase64: "cert",
WssKeyBase64: "key",
}
_, err = spec.wssTLSConfig()
assert.NotNil(t, err)
_, err = spec.tlsConfig()
assert.NotNil(t, err)
}
Loading

0 comments on commit 0e61435

Please sign in to comment.