Skip to content

Commit

Permalink
Fix in MockMultipartHttpServletRequest#getMultipartHeaders
Browse files Browse the repository at this point in the history
Previously this method returned headers only when a Content-Type part header
was present. Now it is guaranteed to return headers (possibly empty) as long
as there is a MultipartFile or Part with the given name.

Closes gh-26501
  • Loading branch information
rstoyanchev committed Feb 3, 2021
1 parent 7a329eb commit c52526a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.mock.web;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
Expand All @@ -33,6 +34,7 @@
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;

Expand Down Expand Up @@ -155,15 +157,28 @@ public HttpHeaders getRequestHeaders() {

@Override
public HttpHeaders getMultipartHeaders(String paramOrFileName) {
String contentType = getMultipartContentType(paramOrFileName);
if (contentType != null) {
MultipartFile file = getFile(paramOrFileName);
if (file != null) {
HttpHeaders headers = new HttpHeaders();
headers.add(HttpHeaders.CONTENT_TYPE, contentType);
if (file.getContentType() != null) {
headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType());
}
return headers;
}
else {
return null;
try {
Part part = getPart(paramOrFileName);
if (part != null) {
HttpHeaders headers = new HttpHeaders();
for (String headerName : part.getHeaderNames()) {
headers.put(headerName, new ArrayList<>(part.getHeaders(headerName)));
}
return headers;
}
}
catch (Throwable ex) {
throw new MultipartException("Could not access multipart servlet request", ex);
}
return null;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2011 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,9 +60,10 @@ public interface MultipartHttpServletRequest extends HttpServletRequest, Multipa
HttpHeaders getRequestHeaders();

/**
* Return the headers associated with the specified part of the multipart request.
* <p>If the underlying implementation supports access to headers, then all headers are returned.
* Otherwise, the returned headers will include a 'Content-Type' header at the very least.
* Return the headers for the specified part of the multipart request.
* <p>If the underlying implementation supports access to part headers,
* then all headers are returned. Otherwise, e.g. for a file upload, the
* returned headers may expose a 'Content-Type' if available.
*/
@Nullable
HttpHeaders getMultipartHeaders(String paramOrFileName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.web.testfixture.servlet;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
Expand All @@ -33,6 +34,7 @@
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;

Expand Down Expand Up @@ -155,15 +157,28 @@ public HttpHeaders getRequestHeaders() {

@Override
public HttpHeaders getMultipartHeaders(String paramOrFileName) {
String contentType = getMultipartContentType(paramOrFileName);
if (contentType != null) {
MultipartFile file = getFile(paramOrFileName);
if (file != null) {
HttpHeaders headers = new HttpHeaders();
headers.add(HttpHeaders.CONTENT_TYPE, contentType);
if (file.getContentType() != null) {
headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType());
}
return headers;
}
else {
return null;
try {
Part part = getPart(paramOrFileName);
if (part != null) {
HttpHeaders headers = new HttpHeaders();
for (String headerName : part.getHeaderNames()) {
headers.put(headerName, new ArrayList<>(part.getHeaders(headerName)));
}
return headers;
}
}
catch (Throwable ex) {
throw new MultipartException("Could not access multipart servlet request", ex);
}
return null;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,6 +36,7 @@
import org.springframework.http.HttpInputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.util.ReflectionUtils;
import org.springframework.validation.BindingResult;
Expand All @@ -51,6 +52,7 @@
import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.support.MissingServletRequestPartException;
import org.springframework.web.testfixture.method.ResolvableMethod;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import org.springframework.web.testfixture.servlet.MockMultipartFile;
Expand Down Expand Up @@ -311,6 +313,22 @@ public void resolveRequestPartNotRequired() throws Exception {
testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
}

@Test // gh-26501
public void resolveRequestPartWithoutContentType() throws Exception {
MockMultipartHttpServletRequest servletRequest = new MockMultipartHttpServletRequest();
servletRequest.addPart(new MockPart("requestPartString", "part value".getBytes(StandardCharsets.UTF_8)));
ServletWebRequest webRequest = new ServletWebRequest(servletRequest, new MockHttpServletResponse());

List<HttpMessageConverter<?>> converters = Collections.singletonList(new StringHttpMessageConverter());
RequestPartMethodArgumentResolver resolver = new RequestPartMethodArgumentResolver(converters);
MethodParameter parameter = ResolvableMethod.on(getClass()).named("handle").build().arg(String.class);

Object actualValue = resolver.resolveArgument(
parameter, new ModelAndViewContainer(), webRequest, new ValidatingBinderFactory());

assertThat(actualValue).isEqualTo("part value");
}

@Test
public void isMultipartRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
Expand Down Expand Up @@ -606,7 +624,8 @@ public void handle(
@RequestPart("requestPart") Optional<List<MultipartFile>> optionalMultipartFileList,
Optional<Part> optionalPart,
@RequestPart("requestPart") Optional<List<Part>> optionalPartList,
@RequestPart("requestPart") Optional<SimpleBean> optionalRequestPart) {
@RequestPart("requestPart") Optional<SimpleBean> optionalRequestPart,
@RequestPart("requestPartString") String requestPartString) {
}

}

0 comments on commit c52526a

Please sign in to comment.