From b3503e0d04281b4cfd9e370e602fa6067b30bb61 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Wed, 8 Jun 2016 16:12:06 -0700 Subject: [PATCH] respect contexts while reading messages in dht License: MIT Signed-off-by: Jeromy --- routing/dht/dht_net.go | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/routing/dht/dht_net.go b/routing/dht/dht_net.go index 0152dab4a911..4e416f632476 100644 --- a/routing/dht/dht_net.go +++ b/routing/dht/dht_net.go @@ -1,6 +1,7 @@ package dht import ( + "fmt" "sync" "time" @@ -214,7 +215,7 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb log.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes) mes := new(pb.Message) - if err := ms.r.ReadMsg(mes); err != nil { + if err := ms.ctxReadMsg(ctx, mes); err != nil { ms.s.Close() ms.s = nil return nil, err @@ -227,3 +228,23 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb return mes, nil } + +func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { + t := time.NewTimer(time.Second * 30) + defer t.Stop() + + errc := make(chan error, 1) + go func() { + errc <- ms.r.ReadMsg(mes) + }() + + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + log.Warning("dht context read timeout") + return fmt.Errorf("reading message timed out") + } +}