Skip to content

Commit

Permalink
Refactor CrossOriginFilter
Browse files Browse the repository at this point in the history
Signed-off-by: Denny Abraham Cheriyan <[email protected]>
  • Loading branch information
dennyac committed Mar 14, 2020
1 parent b1d30fc commit 3c4ab14
Showing 1 changed file with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter;
Expand Down Expand Up @@ -157,8 +159,10 @@ public class CrossOriginFilter implements Filter
private boolean anyOriginAllowed;
private boolean anyTimingOriginAllowed;
private boolean anyHeadersAllowed;
private List<String> allowedOrigins = new ArrayList<String>();
private List<String> allowedTimingOrigins = new ArrayList<String>();
private Set<String> allowedOrigins = new HashSet<String>();
private List<Pattern> allowedOriginPatterns = new ArrayList<Pattern>();
private Set<String> allowedTimingOrigins = new HashSet<String>();
private List<Pattern> allowedTimingOriginPatterns = new ArrayList<Pattern>();
private List<String> allowedMethods = new ArrayList<String>();
private List<String> allowedHeaders = new ArrayList<String>();
private List<String> exposedHeaders = new ArrayList<String>();
Expand All @@ -172,8 +176,8 @@ public void init(FilterConfig config) throws ServletException
String allowedOriginsConfig = config.getInitParameter(ALLOWED_ORIGINS_PARAM);
String allowedTimingOriginsConfig = config.getInitParameter(ALLOWED_TIMING_ORIGINS_PARAM);

anyOriginAllowed = generateAllowedOrigins(allowedOrigins, allowedOriginsConfig, DEFAULT_ALLOWED_ORIGINS);
anyTimingOriginAllowed = generateAllowedOrigins(allowedTimingOrigins, allowedTimingOriginsConfig, DEFAULT_ALLOWED_TIMING_ORIGINS);
anyOriginAllowed = generateAllowedOrigins(allowedOrigins, allowedOriginPatterns, allowedOriginsConfig, DEFAULT_ALLOWED_ORIGINS);
anyTimingOriginAllowed = generateAllowedOrigins(allowedTimingOrigins, allowedTimingOriginPatterns, allowedTimingOriginsConfig, DEFAULT_ALLOWED_TIMING_ORIGINS);

String allowedMethodsConfig = config.getInitParameter(ALLOWED_METHODS_PARAM);
if (allowedMethodsConfig == null)
Expand Down Expand Up @@ -235,7 +239,7 @@ else if ("*".equals(allowedHeadersConfig))
}
}

private boolean generateAllowedOrigins(List<String> allowedOriginStore, String allowedOriginsConfig, String defaultOrigin)
private boolean generateAllowedOrigins(Set<String> allowedOriginStore, List<Pattern> allowedOriginPatternStore, String allowedOriginsConfig, String defaultOrigin)
{
if (allowedOriginsConfig == null)
allowedOriginsConfig = defaultOrigin;
Expand All @@ -247,8 +251,12 @@ private boolean generateAllowedOrigins(List<String> allowedOriginStore, String a
if (ANY_ORIGIN.equals(allowedOrigin))
{
allowedOriginStore.clear();
allowedOriginPatternStore.clear();
return true;
}
else if (allowedOrigin.contains("*")) {
allowedOriginPatternStore.add(Pattern.compile(parseAllowedWildcardOriginToRegex(allowedOrigin)));
}
else
{
allowedOriginStore.add(allowedOrigin);
Expand All @@ -270,7 +278,7 @@ private void handle(HttpServletRequest request, HttpServletResponse response, Fi
// Is it a cross origin request ?
if (origin != null && isEnabled(request))
{
if (anyOriginAllowed || originMatches(allowedOrigins, origin))
if (anyOriginAllowed || originMatches(allowedOrigins, allowedOriginPatterns, origin))
{
if (isSimpleRequest(request))
{
Expand All @@ -292,7 +300,7 @@ else if (isPreflightRequest(request))
handleSimpleResponse(request, response, origin);
}

if (anyTimingOriginAllowed || originMatches(allowedTimingOrigins, origin))
if (anyTimingOriginAllowed || originMatches(allowedTimingOrigins, allowedTimingOriginPatterns, origin))
{
response.setHeader(TIMING_ALLOW_ORIGIN_HEADER, origin);
}
Expand Down Expand Up @@ -330,7 +338,7 @@ protected boolean isEnabled(HttpServletRequest request)
return true;
}

private boolean originMatches(List<String> allowedOrigins, String originList)
private boolean originMatches(Set<String> allowedOrigins, List<Pattern> allowedOriginPatterns, String originList)
{
if (originList.trim().length() == 0)
return false;
Expand All @@ -341,30 +349,18 @@ private boolean originMatches(List<String> allowedOrigins, String originList)
if (origin.trim().length() == 0)
continue;

for (String allowedOrigin : allowedOrigins)
if (allowedOrigins.contains(origin))
return true;

for (Pattern allowedOrigin : allowedOriginPatterns)
{
if (allowedOrigin.contains("*"))
{
Matcher matcher = createMatcher(origin, allowedOrigin);
if (matcher.matches())
return true;
}
else if (allowedOrigin.equals(origin))
{
if (allowedOrigin.matcher(origin).matches())
return true;
}
}
}
return false;
}

private Matcher createMatcher(String origin, String allowedOrigin)
{
String regex = parseAllowedWildcardOriginToRegex(allowedOrigin);
Pattern pattern = Pattern.compile(regex);
return pattern.matcher(origin);
}

private String parseAllowedWildcardOriginToRegex(String allowedOrigin)
{
String regex = StringUtil.replace(allowedOrigin, ".", "\\.");
Expand Down Expand Up @@ -507,7 +503,11 @@ private String commify(List<String> strings)
public void destroy()
{
anyOriginAllowed = false;
anyTimingOriginAllowed = false;
allowedOrigins.clear();
allowedOriginPatterns.clear();
allowedTimingOrigins.clear();
allowedTimingOriginPatterns.clear();
allowedMethods.clear();
allowedHeaders.clear();
preflightMaxAge = 0;
Expand Down

0 comments on commit 3c4ab14

Please sign in to comment.