From e84570c32cdd207458c03de589007783d11e9b7a Mon Sep 17 00:00:00 2001 From: Jack Li Date: Mon, 7 Aug 2017 13:39:34 -0400 Subject: [PATCH] Set peer before sending request (#1423) --- call.go | 6 ++--- test/end2end_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/call.go b/call.go index f0b459125611..797190f1471c 100644 --- a/call.go +++ b/call.go @@ -74,9 +74,6 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran dopts.copts.StatsHandler.HandleRPC(ctx, inPayload) } c.trailerMD = stream.Trailer() - if peer, ok := peer.FromContext(stream.Context()); ok { - c.peer = peer - } return nil } @@ -262,6 +259,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } + if peer, ok := peer.FromContext(stream.Context()); ok { + c.peer = peer + } err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, stream, t, args, topts) if err != nil { if put != nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index f226bdcf21c0..2b9c9c971103 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2264,6 +2264,62 @@ func testPeerNegative(t *testing.T, e env) { tc.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(peer)) } +func TestPeerFailedRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testPeerFailedRPC(t, e) + } +} + +func testPeerFailedRPC(t *testing.T, e env) { + te := newTest(t, e) + te.maxServerReceiveMsgSize = newInt(1 * 1024) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + // first make a successful request to the server + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + + // make a second request that will be rejected by the server + const largeSize = 5 * 1024 + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + Payload: largePayload, + } + + peer := new(peer.Peer) + if _, err := tc.UnaryCall(context.Background(), req, grpc.Peer(peer)); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } else { + pa := peer.Addr.String() + if e.network == "unix" { + if pa != te.srvAddr { + t.Fatalf("peer.Addr = %v, want %v", pa, te.srvAddr) + } + return + } + _, pp, err := net.SplitHostPort(pa) + if err != nil { + t.Fatalf("Failed to parse address from peer.") + } + _, sp, err := net.SplitHostPort(te.srvAddr) + if err != nil { + t.Fatalf("Failed to parse address of test server.") + } + if pp != sp { + t.Fatalf("peer.Addr = localhost:%v, want localhost:%v", pp, sp) + } + } +} + func TestMetadataUnaryRPC(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() {