Skip to content

Commit

Permalink
Rest filter support pattern & some bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
oxsean committed Jul 31, 2024
1 parent 98e96d9 commit e40f3c1
Show file tree
Hide file tree
Showing 12 changed files with 273 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

public final class RestConstants {

public static final String REST = "rest";

public static final String REST_FILTER_KEY = "rest.filter";
public static final String EXTENSION_KEY = "extension";
public static final String EXTENSIONS_ATTRIBUTE_KEY = "restExtensionsAttributeKey";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.dubbo.rpc.protocol.tri.rest.filter;

import org.apache.dubbo.common.utils.ArrayUtils;
import org.apache.dubbo.rpc.protocol.tri.rest.util.RestUtils;

import java.util.Arrays;
Expand Down Expand Up @@ -47,7 +48,7 @@ public String toString() {
sb.append(", priority=").append(priority);
}
String[] patterns = getPatterns();
if (patterns != null) {
if (ArrayUtils.isNotEmpty(patterns)) {
sb.append(", patterns=").append(Arrays.toString(patterns));
}
return sb.append('}').toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import org.apache.dubbo.common.constants.CommonConstants;
import org.apache.dubbo.common.extension.Activate;
import org.apache.dubbo.common.extension.ExtensionAccessorAware;
import org.apache.dubbo.common.logger.Logger;
import org.apache.dubbo.common.logger.LoggerFactory;
import org.apache.dubbo.common.utils.ArrayUtils;
import org.apache.dubbo.common.utils.CollectionUtils;
import org.apache.dubbo.common.utils.StringUtils;
import org.apache.dubbo.remoting.http12.HttpRequest;
import org.apache.dubbo.remoting.http12.HttpResponse;
Expand All @@ -35,20 +39,27 @@
import org.apache.dubbo.rpc.protocol.tri.rest.Messages;
import org.apache.dubbo.rpc.protocol.tri.rest.RestConstants;
import org.apache.dubbo.rpc.protocol.tri.rest.RestInitializeException;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RadixTree;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RadixTree.Match;
import org.apache.dubbo.rpc.protocol.tri.rest.util.RestUtils;
import org.apache.dubbo.rpc.protocol.tri.rest.util.TypeUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;

@Activate(group = CommonConstants.PROVIDER, order = 1000)
public class RestExtensionExecutionFilter extends RestFilterAdapter {

private static final Logger LOGGER = LoggerFactory.getLogger(RestExtensionExecutionFilter.class);
private static final String KEY = RestExtensionExecutionFilter.class.getSimpleName();

private final Map<RestFilter, RadixTree<Boolean>> filterTreeCache = CollectionUtils.newConcurrentHashMap();
private final ApplicationModel applicationModel;
private final List<RestExtensionAdapter<Object>> extensionAdapters;

Expand All @@ -61,7 +72,7 @@ public RestExtensionExecutionFilter(ApplicationModel applicationModel) {
@Override
protected Result invoke(Invoker<?> invoker, Invocation invocation, HttpRequest request, HttpResponse response)
throws RpcException {
RestFilter[] filters = getFilters(invoker);
RestFilter[] filters = matchFilters(getFilters(invoker), request.path());
DefaultFilterChain chain = new DefaultFilterChain(filters, invocation, () -> invoker.invoke(invocation));
invocation.put(KEY, chain);
try {
Expand Down Expand Up @@ -127,13 +138,65 @@ protected void onError(
chain.onError(t, request, response);
}

private RestFilter[] matchFilters(RestFilter[] filters, String path) {
int len = filters.length;
BitSet bitSet = new BitSet(len);
out:
for (int i = 0; i < len; i++) {
RestFilter filter = filters[i];
String[] patterns = filter.getPatterns();
if (ArrayUtils.isEmpty(patterns)) {
continue;
}
RadixTree<Boolean> filterTree = filterTreeCache.computeIfAbsent(filter, f -> {
RadixTree<Boolean> tree = new RadixTree<>();
for (String pattern : patterns) {
if (StringUtils.isNotEmpty(pattern)) {
if (pattern.charAt(0) == '!') {
tree.addPath(pattern.substring(1), false);
} else {
tree.addPath(pattern, true);
}
}
}
return tree;
});

List<Match<Boolean>> matches = filterTree.match(path);
int size = matches.size();
if (size == 0) {
bitSet.set(i);
continue;
}
for (int j = 0; j < size; j++) {
if (!matches.get(j).getValue()) {
bitSet.set(i);
continue out;
}
}
}
if (bitSet.isEmpty()) {
return filters;
}
RestFilter[] matched = new RestFilter[len - bitSet.cardinality()];
for (int i = 0, j = 0; i < len; i++) {
if (!bitSet.get(i)) {
matched[j++] = filters[i];
}
}
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Matched filters for path '{}' is {}", path, Arrays.toString(matched));
}
return matched;
}

@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
private RestFilter[] getFilters(Invoker<?> invoker) {
URL url = invoker.getUrl();
RestFilter[] filters = getFilters(url);
if (filters != null) {
return filters;
}
//noinspection SynchronizationOnLocalVariableOrMethodParameter
synchronized (invoker) {
filters = getFilters(url);
if (filters != null) {
Expand All @@ -150,6 +213,7 @@ private RestFilter[] getFilters(URL url) {
}

private RestFilter[] loadFilters(URL url) {
LOGGER.info("Loading rest filters for {}", url);
List<RestFilter> extensions = new ArrayList<>();

// 1. load from extension config
Expand Down Expand Up @@ -186,13 +250,29 @@ private void adaptExtension(Object extension, List<RestFilter> extensions) {
extension = ((Supplier<?>) extension).get();
}
if (extension instanceof RestFilter) {
extensions.add((RestFilter) extension);
addRestFilter(extension, (RestFilter) extension, extensions);
return;
}
for (RestExtensionAdapter<Object> adapter : extensionAdapters) {
if (adapter.accept(extension)) {
extensions.add(adapter.adapt(extension));
addRestFilter(extension, adapter.adapt(extension), extensions);
}
}
}

private void addRestFilter(Object extension, RestFilter filter, List<RestFilter> extensions) {
extensions.add(filter);
if (!LOGGER.isInfoEnabled()) {
return;
}
StringBuilder sb = new StringBuilder(64);
sb.append("Rest filter [").append(extension).append("] loaded");
if (filter.getPriority() != 0) {
sb.append(", priority=").append(filter.getPriority());
}
if (ArrayUtils.isNotEmpty(filter.getPatterns())) {
sb.append(", patterns=").append(Arrays.toString(filter.getPatterns()));
}
LOGGER.info(sb.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ public void match(KeyString path, List<Match<T>> matches) {
for (int i = 0, size = directMatches.size(); i < size; i++) {
matches.add(directMatches.get(i));
}
return;
}

matchRecursive(root, path, 1, new HashMap<>(), matches);
Expand All @@ -157,11 +156,12 @@ public void match(String path, List<Match<T>> matches) {

public List<Match<T>> match(KeyString path) {
List<Match<T>> matches = directPathMap.get(path);
if (matches != null) {
return new ArrayList<>(matches);
if (matches == null) {
matches = new ArrayList<>();
} else {
matches = new ArrayList<>(matches);
}

matches = new ArrayList<>();
matchRecursive(root, path, 1, new HashMap<>(), matches);
return matches;
}
Expand All @@ -175,8 +175,10 @@ private void matchRecursive(
int end = path.indexOf('/', start);
Node<T> node = current.children.get(new KeyString(path, start, end));
if (node != null) {
if (node.isLeaf()) {
addMatch(node, variableMap, matches);
if (end == -1) {
if (node.isLeaf()) {
addMatch(node, variableMap, matches);
}
return;
}
matchRecursive(node, path, end + 1, variableMap, matches);
Expand All @@ -191,13 +193,19 @@ private void matchRecursive(
if (segment.match(path, start, end, workVariableMap)) {
workVariableMap.putAll(variableMap);
Node<T> child = entry.getValue();
if (segment.isTailMatching() || child.isLeaf()) {
if (segment.isTailMatching()) {
addMatch(child, workVariableMap, matches);
} else {
matchRecursive(child, path, end + 1, workVariableMap, matches);
if (end == -1) {
if (child.isLeaf()) {
addMatch(child, workVariableMap, matches);
}
} else {
matchRecursive(child, path, end + 1, workVariableMap, matches);
}
}
if (!workVariableMap.isEmpty()) {
workVariableMap = new HashMap<>();
workVariableMap = new LinkedHashMap<>();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,21 @@ public Map<String, String> match(@Nonnull String path) {
}
Map<String, String> variableMap = new LinkedHashMap<>();
int start, end = 0;
for (PathSegment segment : segments) {
for (int i = 0, len = segments.length; i < len; i++) {
PathSegment segment = segments[i];
if (end != -1) {
start = end + 1;
end = path.indexOf('/', start);
if (segment.match(new KeyString(path), start, end, variableMap)) {
if (i == len - 1 && segment.isTailMatching()) {
return variableMap;
}
continue;
}
}
return null;
}
return variableMap;
return end == -1 ? variableMap : null;
}

public int compareTo(PathExpression other, String lookupPath) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ public boolean equals(Object obj) {

@Override
public String toString() {
return value.substring(offset, length - offset);
return value.substring(offset, offset + length);
}

public int indexOf(char ch, int start) {
int index = value.indexOf(ch, offset + start);
return index == -1 ? -1 : index - offset;
return index == -1 || index >= offset + length ? -1 : index - offset;
}

public boolean regionMatches(int start, String value, int i, int length) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.dubbo.rpc.protocol.tri.rest.filter

import org.apache.dubbo.common.URL
import org.apache.dubbo.rpc.Invoker
import org.apache.dubbo.rpc.model.ApplicationModel

import spock.lang.Specification

class RestFilterTest extends Specification {

@SuppressWarnings('GroovyAccessibility')
def "test filter patterns"() {
given:
Invoker invoker = Mock(Invoker)
invoker.getUrl() >> URL.valueOf("tri://127.0.0.1/test?extension=org.apache.dubbo.rpc.protocol.tri.rest.filter.TestRestFilter")

var filter = new RestExtensionExecutionFilter(ApplicationModel.defaultModel())
expect:
filter.matchFilters(filter.getFilters(invoker), path).length == len
where:
path | len
'/filter/one' | 1
'/filter/one/1' | 1
'/one.filter' | 2
'/filter/two' | 2
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import spock.lang.Specification

class RadixTreeTest extends Specification {

def "Match"() {
def "match"() {
given:
def tree = new RadixTree<String>()
tree.addPath('/a/*', 'abc')
Expand All @@ -33,7 +33,7 @@ class RadixTreeTest extends Specification {
!match.empty
}

def "Clear"() {
def "clear"() {
given:
def tree = new RadixTree<String>()
tree.addPath('/a/*', 'abc')
Expand All @@ -44,4 +44,18 @@ class RadixTreeTest extends Specification {
then:
tree.empty
}

def "test end match"() {
given:
def tree = new RadixTree<Boolean>()
tree.addPath('/a/*/*', true)
expect:
tree.match(path).size() == len
where:
path | len
'/a' | 0
'/a/b' | 0
'/a/b/c' | 1
'/a/b/c/d' | 0
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class PathExpressionTest extends Specification {
'/resources/**' | '/resources/a/b/c' | true
'/resources/{*path}' | '/resources/a/b/c' | [path: 'a/b/c']
'/resources/*' | '/resources/a' | true
'/resources/*' | '/resources/a/b/c' | true
'/resources/{*}' | '/resources/a/b/c' | true
'/{id:\\d+}' | '/123' | [id: '123']
'/{id:\\d+}' | '/one' | false
'/a?cd/ef' | '/abcd/ef' | true
Expand All @@ -194,8 +194,8 @@ class PathExpressionTest extends Specification {
expect:
parse(path).match(value) != null
where:
path | value
'/resources/*' | '/resources/a/b/c'
path | value
'/resources/{*}' | '/resources/a/b/c'
}

def "CompareTo"() {
Expand Down
Loading

0 comments on commit e40f3c1

Please sign in to comment.