diff --git a/client.go b/client.go index 395c045..402e7f8 100644 --- a/client.go +++ b/client.go @@ -8,20 +8,28 @@ import ( ) type client struct { - Request *pack.Request + Request *pack.Request Response *pack.Response - Http *http.Client + Http *http.Request + HttpClient *http.Client } // 初始化一个客户端 -func Client(addr string, method string, params interface{}) *client { +func Client(addr string, method string, params interface{}) (*client, error) { + httpRequest, err := http.NewRequest(http.MethodPost, addr, nil) + if err != nil { + return nil, err + } + + httpRequest.Header.Set("User-Agent", "Go Yar Rpc-0.1") c := &client{ - Request: pack.NewRequest(addr, method, params), + Request: pack.NewRequest(addr, method, params), Response: new(pack.Response), - Http: &http.Client{Timeout: time.Second}, + Http: httpRequest, + HttpClient: &http.Client{Timeout: time.Second}, } - return c + return c, nil } // 设置返回值结构体 @@ -44,7 +52,10 @@ func (c *client) Send() error { buffer.Write(data) // 发送请求 - resp, err := c.Http.Post(c.Request.Addr, packHandler.ContentType(), buffer) + c.Http.Body = ioutil.NopCloser(buffer) + c.Http.Header.Set("Content-Type", packHandler.ContentType()) + + resp, err := c.HttpClient.Do(c.Http) if err != nil { return err } @@ -54,11 +65,15 @@ func (c *client) Send() error { // 解析处理 headerData := pack.NewHeaderWithBody(body, c.Request.Protocol) packHandler = pack.GetPackHandler(headerData.Packager) - bodyContent := body[pack.ProtocolLength + pack.PackagerLength:] + bodyContent := body[pack.ProtocolLength+pack.PackagerLength:] err = packHandler.Decode(bodyContent, c.Response) if err != nil { return err } - return c.Response.Except + if c.Response.Except != nil { + return c.Response.Except + } + + return nil } diff --git a/pack/msgpack.go b/pack/msgpack.go index cc48f18..a6974c5 100644 --- a/pack/msgpack.go +++ b/pack/msgpack.go @@ -1,18 +1,30 @@ package pack -import "github.com/vmihailenco/msgpack" +import ( + "bytes" + "github.com/vmihailenco/msgpack" +) -// msgpack处理器 +// msgpack处理器,兼容json tag定义 type EncoderMsgpack struct { } func (p *EncoderMsgpack) Encode(request *Request) ([]byte, error) { - return msgpack.Marshal(request) + var buf bytes.Buffer + encoder := msgpack.NewEncoder(&buf) + encoder.UseJSONTag(true) + err := encoder.Encode(request) + if err != nil { + return nil, err + } + return buf.Bytes(), err } func (p *EncoderMsgpack) Decode(body []byte, response *Response) error{ - response.Protocol = ProtocolMsgpack - return msgpack.Unmarshal(body, response) + reader := bytes.NewReader(body) + decoder := msgpack.NewDecoder(reader) + decoder.UseJSONTag(true) + return decoder.Decode(response) } func (p *EncoderMsgpack) ContentType() string {