Skip to content

Commit

Permalink
rpc: fix staticcheck warning SA1029 by add PeerInfo (ethereum#24255)
Browse files Browse the repository at this point in the history
  • Loading branch information
gzliudan committed Oct 25, 2024
1 parent e52b8a2 commit 2844dbc
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 18 deletions.
5 changes: 3 additions & 2 deletions rpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ type clientConn struct {

func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
handler := newHandler(ctx, conn, c.idgen, c.services)
return &clientConn{conn, handler}
}
Expand Down Expand Up @@ -473,7 +474,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
// Check type of channel first.
chanVal := reflect.ValueOf(channel)
if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 {
panic("first argument to Subscribe must be a writable channel")
panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel))
}
if chanVal.IsNil() {
panic("channel given to Subscribe must not be nil")
Expand Down Expand Up @@ -532,8 +533,8 @@ func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error
}

func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
// The previous write failed. Try to establish a new connection.
if c.writeConn == nil {
// The previous write failed. Try to establish a new connection.
if err := c.reconnect(ctx); err != nil {
return err
}
Expand Down
30 changes: 18 additions & 12 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,18 @@ type httpConn struct {
headers http.Header
}

// httpConn is treated specially by Client.
// httpConn implements ServerCodec, but it is treated specially by Client
// and some methods don't work. The panic() stubs here exist to ensure
// this special treatment is correct.

func (hc *httpConn) writeJSON(context.Context, interface{}) error {
panic("writeJSON called on httpConn")
}

func (hc *httpConn) peerInfo() PeerInfo {
panic("peerInfo called on httpConn")
}

func (hc *httpConn) remoteAddr() string {
return hc.url
}
Expand Down Expand Up @@ -258,20 +265,19 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), code)
return
}

// Create request-scoped context.
connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr}
connInfo.HTTP.Version = r.Proto
connInfo.HTTP.Host = r.Host
connInfo.HTTP.Origin = r.Header.Get("Origin")
connInfo.HTTP.UserAgent = r.Header.Get("User-Agent")
ctx := r.Context()
ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo)

// All checks passed, create a codec that reads directly from the request body
// until EOF, writes the response to w, and orders the server to process a
// single request.
ctx := r.Context()
ctx = context.WithValue(ctx, "remote", r.RemoteAddr)
ctx = context.WithValue(ctx, "scheme", r.Proto)
ctx = context.WithValue(ctx, "local", r.Host)
if ua := r.Header.Get("User-Agent"); ua != "" {
ctx = context.WithValue(ctx, "User-Agent", ua)
}
if origin := r.Header.Get("Origin"); origin != "" {
ctx = context.WithValue(ctx, "Origin", origin)
}

w.Header().Set("content-type", contentType)
codec := newHTTPServerConn(r, w)
defer codec.close()
Expand Down
36 changes: 36 additions & 0 deletions rpc/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,39 @@ func TestHTTPRespBodyUnlimited(t *testing.T) {
t.Fatalf("response has wrong length %d, want %d", len(r), respLength)
}
}

func TestHTTPPeerInfo(t *testing.T) {
s := newTestServer()
defer s.Stop()
ts := httptest.NewServer(s)
defer ts.Close()

c, err := Dial(ts.URL)
if err != nil {
t.Fatal(err)
}
c.SetHeader("user-agent", "ua-testing")
c.SetHeader("origin", "origin.example.com")

// Request peer information.
var info PeerInfo
if err := c.Call(&info, "test_peerInfo"); err != nil {
t.Fatal(err)
}

if info.RemoteAddr == "" {
t.Error("RemoteAddr not set")
}
if info.Transport != "http" {
t.Errorf("wrong Transport %q", info.Transport)
}
if info.HTTP.Version != "HTTP/1.1" {
t.Errorf("wrong HTTP.Version %q", info.HTTP.Version)
}
if info.HTTP.UserAgent != "ua-testing" {
t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent)
}
if info.HTTP.Origin != "origin.example.com" {
t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent)
}
}
5 changes: 5 additions & 0 deletions rpc/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ func NewCodec(conn Conn) ServerCodec {
return NewFuncCodec(conn, enc.Encode, dec.Decode)
}

func (c *jsonCodec) peerInfo() PeerInfo {
// This returns "ipc" because all other built-in transports have a separate codec type.
return PeerInfo{Transport: "ipc", RemoteAddr: c.remote}
}

func (c *jsonCodec) remoteAddr() string {
return c.remote
}
Expand Down
35 changes: 35 additions & 0 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,38 @@ func (s *RPCService) Modules() map[string]string {
}
return modules
}

// PeerInfo contains information about the remote end of the network connection.
//
// This is available within RPC method handlers through the context. Call
// PeerInfoFromContext to get information about the client connection related to
// the current method call.
type PeerInfo struct {
// Transport is name of the protocol used by the client.
// This can be "http", "ws" or "ipc".
Transport string

// Address of client. This will usually contain the IP address and port.
RemoteAddr string

// Addditional information for HTTP and WebSocket connections.
HTTP struct {
// Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket.
Version string
// Header values sent by the client.
UserAgent string
Origin string
Host string
}
}

type peerInfoContextKey struct{}

// PeerInfoFromContext returns information about the client's network connection.
// Use this with the context passed to RPC method handler functions.
//
// The zero value is returned if no connection info is present in ctx.
func PeerInfoFromContext(ctx context.Context) PeerInfo {
info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo)
return info
}
2 changes: 1 addition & 1 deletion rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) {
t.Fatalf("Expected service calc to be registered")
}

wantCallbacks := 10
wantCallbacks := 11
if len(svc.callbacks) != wantCallbacks {
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
}
Expand Down
4 changes: 4 additions & 0 deletions rpc/testservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *
return echoResult{str, i, args}
}

func (s *testService) PeerInfo(ctx context.Context) PeerInfo {
return PeerInfoFromContext(ctx)
}

func (s *testService) Sleep(ctx context.Context, duration time.Duration) {
time.Sleep(duration)
}
Expand Down
2 changes: 2 additions & 0 deletions rpc/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ type DataError interface {
// a RPC session. Implementations must be go-routine safe since the codec can be called in
// multiple go-routines concurrently.
type ServerCodec interface {
peerInfo() PeerInfo
readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error)
close()

jsonWriter
}

Expand Down
20 changes: 17 additions & 3 deletions rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn)
codec := newWebsocketCodec(conn, r.Host, r.Header)
s.ServeCodec(codec, 0)
})
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
}
return nil, hErr
}
return newWebsocketCodec(conn), nil
return newWebsocketCodec(conn, endpoint, header), nil
})
}

Expand Down Expand Up @@ -241,18 +241,28 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
type websocketCodec struct {
*jsonCodec
conn *websocket.Conn
info PeerInfo

wg sync.WaitGroup
pingReset chan struct{}
}

func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
conn.SetReadLimit(wsMessageSizeLimit)
wc := &websocketCodec{
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
conn: conn,
pingReset: make(chan struct{}, 1),
info: PeerInfo{
Transport: "ws",
RemoteAddr: conn.RemoteAddr().String(),
},
}
// Fill in connection details.
wc.info.HTTP.Host = host
wc.info.HTTP.Origin = req.Get("Origin")
wc.info.HTTP.UserAgent = req.Get("User-Agent")
// Start pinger.
wc.wg.Add(1)
go wc.pingLoop()
return wc
Expand All @@ -263,6 +273,10 @@ func (wc *websocketCodec) close() {
wc.wg.Wait()
}

func (wc *websocketCodec) peerInfo() PeerInfo {
return wc.info
}

func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
err := wc.jsonCodec.writeJSON(ctx, v)
if err == nil {
Expand Down
35 changes: 35 additions & 0 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,41 @@ func TestWebsocketLargeCall(t *testing.T) {
}
}

func TestWebsocketPeerInfo(t *testing.T) {
var (
s = newTestServer()
ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}))
tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:")
)
defer s.Stop()
defer ts.Close()

ctx := context.Background()
c, err := DialWebsocket(ctx, tsurl, "origin.example.com")
if err != nil {
t.Fatal(err)
}

// Request peer information.
var connInfo PeerInfo
if err := c.Call(&connInfo, "test_peerInfo"); err != nil {
t.Fatal(err)
}

if connInfo.RemoteAddr == "" {
t.Error("RemoteAddr not set")
}
if connInfo.Transport != "ws" {
t.Errorf("wrong Transport %q", connInfo.Transport)
}
if connInfo.HTTP.UserAgent != "Go-http-client/1.1" {
t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent)
}
if connInfo.HTTP.Origin != "origin.example.com" {
t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent)
}
}

// This test checks that client handles WebSocket ping frames correctly.
func TestClientWebsocketPing(t *testing.T) {
t.Parallel()
Expand Down

0 comments on commit 2844dbc

Please sign in to comment.