Skip to content

Commit

Permalink
transport: support stats.Handler in serverHandlerTransport (#1840)
Browse files Browse the repository at this point in the history
  • Loading branch information
etdub authored and menghanl committed Feb 6, 2018
1 parent 104054a commit 7646b53
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) {
// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
// and subject to change.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r)
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
27 changes: 25 additions & 2 deletions transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)

// NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server
// supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2")
}
Expand All @@ -73,6 +74,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
writes: make(chan func()),
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
}

if v := r.Header.Get("grpc-timeout"); v != "" {
Expand Down Expand Up @@ -137,6 +139,8 @@ type serverHandlerTransport struct {
// we store both contentType and contentSubtype so we don't keep recreating them
// TODO make sure this is consistent across handler_server and http2_server
contentSubtype string

stats stats.Handler
}

func (ht *serverHandlerTransport) Close() error {
Expand Down Expand Up @@ -230,6 +234,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
})

if err == nil { // transport has not been closed
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
ht.Close()
close(ht.writes)
}
Expand Down Expand Up @@ -274,7 +281,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts
}

func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return ht.do(func() {
err := ht.do(func() {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
Expand All @@ -290,6 +297,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
})

if err == nil {
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
}
}
return err
}

func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
Expand Down Expand Up @@ -342,6 +356,15 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s)
if ht.stats != nil {
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: ht.RemoteAddr(),
Compression: s.recvCompress,
}
ht.stats.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, recv: s.buf},
windowHandler: func(int) {},
Expand Down
6 changes: 3 additions & 3 deletions transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
if tt.modrw != nil {
rw = tt.modrw(rw)
}
got, gotErr := NewServerHandlerTransport(rw, tt.req)
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
continue
Expand Down Expand Up @@ -272,7 +272,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req)
ht, err := NewServerHandlerTransport(rw, req, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -357,7 +357,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req)
ht, err := NewServerHandlerTransport(rw, req, nil)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 7646b53

Please sign in to comment.