Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc: add SetHeader method to Client #21392

Merged
merged 12 commits into from
Aug 3, 2020
16 changes: 15 additions & 1 deletion rpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type Client struct {

// writeConn is used for writing to the connection on the caller's goroutine. It should
// only be accessed outside of dispatch, with the write lock held. The write lock is
// taken by sending on requestOp and released by sending on sendDone.
// taken by sending on reqInit and released by sending on reqSent.
writeConn jsonWriter

// for dispatch
Expand Down Expand Up @@ -260,6 +260,20 @@ func (c *Client) Close() {
}
}

// SetHeader sets the given key to the given value in the header of
// the http request of a client's http connection. // TODO improve
func (c *Client) SetHeader(key, value string) error {
conn := c.writeConn.(*httpConn)
if conn == nil {
return fmt.Errorf("client is not http") // TODO revise err?
}

conn.Lock()
conn.headers[key] = value
conn.Unlock()
return nil
}

// Call performs a JSON-RPC call with the given arguments and unmarshals into
// result if no error occurred.
//
Expand Down
36 changes: 36 additions & 0 deletions rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/log"
"github.com/stretchr/testify/assert"
)

func TestClientRequest(t *testing.T) {
Expand Down Expand Up @@ -429,6 +430,41 @@ func TestClientNotificationStorm(t *testing.T) {
doTest(23000, true)
}

// TestClientSetHeader tests whether an http header has been properly set
// to the given key and value on a client's http request.
func TestClientSetHeader(t *testing.T) {
renaynay marked this conversation as resolved.
Show resolved Hide resolved
server := newTestServer()
defer server.Stop()

client, hs := httpTestClient(server, "http", nil)
defer hs.Close()
defer client.Close()

headers := []struct {
key string
val string
}{
{ key: "test1", val: "success" },
{ key: "test2", val: "success" },
{ key: "test3", val: "success" },
}

for _, header := range headers {
if err := client.SetHeader(header.key, header.val); err != nil {
t.Fatal(err)
}
}

conn := client.writeConn.(*httpConn)
if conn == nil {
t.Fatal("client is not HTTP")
}

for _, header := range headers {
assert.Equal(t, header.val, conn.headers[header.key])
}
}

func TestClientHTTP(t *testing.T) {
server := newTestServer()
defer server.Stop()
Expand Down
21 changes: 17 additions & 4 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ const (
var acceptedContentTypes = []string{contentType, "application/json-rpc", "application/jsonrequest"}

type httpConn struct {
sync.Mutex
renaynay marked this conversation as resolved.
Show resolved Hide resolved

client *http.Client
headers map[string]string
renaynay marked this conversation as resolved.
Show resolved Hide resolved
req *http.Request
closeOnce sync.Once
closeCh chan interface{}
Expand Down Expand Up @@ -110,8 +113,9 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
req.Header.Set("Accept", contentType)

initctx := context.Background()
headers := map[string]string{"Accept": "application/json", "Content-Type":"application/json"}
return newClient(initctx, func(context.Context) (ServerCodec, error) {
return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
return &httpConn{client: client, headers: headers, req: req, closeCh: make(chan interface{})}, nil
})
}

Expand Down Expand Up @@ -166,10 +170,19 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
if err != nil {
return nil, err
}
req := hc.req.WithContext(ctx)
req.Body = ioutil.NopCloser(bytes.NewReader(body))
req, err := http.NewRequestWithContext(ctx, hc.req.Method, hc.req.URL.String(), ioutil.NopCloser(bytes.NewReader(body)))
if err != nil {
return nil, err
}
req.Host = hc.req.Host
req.ContentLength = int64(len(body))

// set headers
hc.Lock()
for key, val := range hc.headers {
req.Header.Set(key, val)
}
hc.Unlock()
// do request
resp, err := hc.client.Do(req)
if err != nil {
return nil, err
Expand Down