diff --git a/client/request.go b/client/request.go index d7dc564e..4c00ed3a 100644 --- a/client/request.go +++ b/client/request.go @@ -161,15 +161,22 @@ func (r *request) buildHTTP(mediaType, basePath string, producers map[string]run }() for fn, f := range r.fileFields { for _, fi := range f { - // Need to read the data so that we can detect the content type - buf := make([]byte, 512) - size, err := fi.Read(buf) - if err != nil { - logClose(err, pw) - return + var fileContentType string + if p, ok := fi.(interface { + ContentType() string + }); ok { + fileContentType = p.ContentType() + } else { + // Need to read the data so that we can detect the content type + buf := make([]byte, 512) + size, err := fi.Read(buf) + if err != nil { + logClose(err, pw) + return + } + fileContentType = http.DetectContentType(buf) + fi = runtime.NamedReader(fi.Name(), io.MultiReader(bytes.NewReader(buf[:size]), fi)) } - fileContentType := http.DetectContentType(buf) - newFi := runtime.NamedReader(fi.Name(), io.MultiReader(bytes.NewReader(buf[:size]), fi)) // Create the MIME headers for the new part h := make(textproto.MIMEHeader) @@ -183,7 +190,7 @@ func (r *request) buildHTTP(mediaType, basePath string, producers map[string]run logClose(err, pw) return } - if _, err := io.Copy(wrtr, newFi); err != nil { + if _, err := io.Copy(wrtr, fi); err != nil { logClose(err, pw) } } diff --git a/client/request_test.go b/client/request_test.go index e6a06988..4a5e68ef 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -430,6 +430,7 @@ func TestBuildRequest_BuildHTTP_Files(t *testing.T) { } } } + func TestBuildRequest_BuildHTTP_Files_URLEncoded(t *testing.T) { cont, _ := os.ReadFile("./runtime.go") cont2, _ := os.ReadFile("./request.go") @@ -475,6 +476,57 @@ func TestBuildRequest_BuildHTTP_Files_URLEncoded(t *testing.T) { } } +type contentTypeProvider struct { + runtime.NamedReadCloser + contentType string +} + +func (p contentTypeProvider) ContentType() string { + return p.contentType +} + +func TestBuildRequest_BuildHTTP_File_ContentType(t *testing.T) { + cont, _ := os.ReadFile("./runtime.go") + cont2, _ := os.ReadFile("./request.go") + reqWrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error { + _ = req.SetPathParam("id", "1234") + _ = req.SetFileParam("file1", contentTypeProvider{ + NamedReadCloser: mustGetFile("./runtime.go"), + contentType: "application/octet-stream", + }) + _ = req.SetFileParam("file2", mustGetFile("./request.go")) + + return nil + }) + r, _ := newRequest("GET", "/flats/{id}/", reqWrtr) + _ = r.SetHeaderParam(runtime.HeaderContentType, runtime.JSONMime) + req, err := r.BuildHTTP(runtime.JSONMime, "", testProducers, nil) + if assert.NoError(t, err) && assert.NotNil(t, req) { + assert.Equal(t, "/flats/1234/", req.URL.Path) + mediaType, params, err := mime.ParseMediaType(req.Header.Get(runtime.HeaderContentType)) + if assert.NoError(t, err) { + assert.Equal(t, runtime.MultipartFormMime, mediaType) + boundary := params["boundary"] + mr := multipart.NewReader(req.Body, boundary) + defer req.Body.Close() + frm, err := mr.ReadForm(1 << 20) + if assert.NoError(t, err) { + fileverifier := func(name string, index int, filename string, content []byte, contentType string) { + mpff := frm.File[name][index] + mpf, _ := mpff.Open() + defer mpf.Close() + assert.Equal(t, filename, mpff.Filename) + actual, _ := io.ReadAll(mpf) + assert.Equal(t, content, actual) + assert.Equal(t, mpff.Header.Get("Content-Type"), contentType) + } + fileverifier("file1", 0, "runtime.go", cont, "application/octet-stream") + fileverifier("file2", 0, "request.go", cont2, "text/plain; charset=utf-8") + } + } + } +} + func TestBuildRequest_BuildHTTP_BasePath(t *testing.T) { reqWrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error { _ = req.SetBodyParam(nil)