Skip to content

Commit

Permalink
fix: stop propagating request to backend if not valid
Browse files Browse the repository at this point in the history
  • Loading branch information
ytvnr authored and phiz71 committed Mar 28, 2022
1 parent 0366770 commit 926e820
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 23 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.gravitee.gateway</groupId>
<artifactId>gravitee-gateway-buffer</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import io.gravitee.gateway.api.buffer.Buffer;
import io.gravitee.gateway.api.http.HttpHeaderNames;
import io.gravitee.gateway.api.http.stream.TransformableRequestStreamBuilder;
import io.gravitee.gateway.api.stream.BufferedReadWriteStream;
import io.gravitee.gateway.api.stream.ReadWriteStream;
import io.gravitee.gateway.api.stream.SimpleReadWriteStream;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyConfiguration;
import io.gravitee.policy.api.PolicyResult;
Expand Down Expand Up @@ -172,12 +174,24 @@ public ReadWriteStream<Buffer> onRequestContent(Request request, PolicyChain pol
.anyMatch(ct -> ct.endsWith(MEDIA_TEXT_XML.getSubtype()))
) {
// The policy is only applicable to json content type.
return TransformableRequestStreamBuilder
.on(request)
.chain(policyChain)
.transform(buffer -> {
return new BufferedReadWriteStream() {
final Buffer buffer = Buffer.buffer();

@Override
public SimpleReadWriteStream<Buffer> write(Buffer content) {
buffer.appendBuffer(content);
return this;
}

@Override
public void end() {
try {
validateXml(buffer.toString());

if (buffer.length() > 0) {
super.write(buffer);
}
super.end();
} catch (XmlException e) {
policyChain.streamFailWith(
PolicyResult.failure(e.getKey(), HttpStatusCode.BAD_REQUEST_400, BAD_REQUEST, MediaType.TEXT_PLAIN)
Expand All @@ -187,12 +201,9 @@ public ReadWriteStream<Buffer> onRequestContent(Request request, PolicyChain pol
PolicyResult.failure(HttpStatusCode.INTERNAL_SERVER_ERROR_500, SERVER_ERROR, MediaType.TEXT_PLAIN)
);
}

return buffer;
})
.build();
}
};
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@
*/
package io.gravitee.policy.threatprotection.xml;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;

import io.gravitee.common.http.MediaType;
import io.gravitee.gateway.api.Request;
import io.gravitee.gateway.api.buffer.Buffer;
import io.gravitee.gateway.api.http.HttpHeaderNames;
import io.gravitee.gateway.api.http.HttpHeaders;
import io.gravitee.gateway.api.stream.BufferedReadWriteStream;
import io.gravitee.gateway.api.stream.ReadWriteStream;
import io.gravitee.gateway.api.stream.SimpleReadWriteStream;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyResult;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -75,19 +77,23 @@ public void shouldAcceptAllWhenContentTypeIsNotXml() {
when(request.headers()).thenReturn(HttpHeaders.create());
ReadWriteStream<?> readWriteStream = cut.onRequestContent(request, policyChain);

assertNull(readWriteStream);
assertThat(readWriteStream).isNull();
}

@Test
public void shouldAcceptValidXml() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("<test valid=\"true\">value</test>"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isTrue();

verifyZeroInteractions(policyChain);
}

Expand All @@ -96,11 +102,15 @@ public void shouldRejectInvalidXml() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("Invalid"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -110,11 +120,15 @@ public void shouldRejectWhenMaxTextValueLengthExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("<test valid=\"true\">value</test>"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -124,11 +138,14 @@ public void shouldRejectWhenMaxLengthExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("<test valid=\"false\">1234</test>"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -138,7 +155,8 @@ public void shouldRejectWhenMaxAttributesPerElementExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(
Buffer.buffer(
Expand All @@ -147,6 +165,8 @@ public void shouldRejectWhenMaxAttributesPerElementExceeded() {
);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -156,11 +176,14 @@ public void shouldRejectWhenMaxChildrenExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("<test><child1></child1><child2></child2></test>"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -170,11 +193,14 @@ public void shouldRejectWhenMaxDepthExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("<test><child><subChild></subChild></child></test>"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -184,7 +210,8 @@ public void shouldRejectWhenEntityMaxDepthExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(
Buffer.buffer(
Expand All @@ -200,6 +227,8 @@ public void shouldRejectWhenEntityMaxDepthExceeded() {
);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -209,7 +238,8 @@ public void shouldRejectWhenMaxEntitiesExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(
Buffer.buffer(
Expand All @@ -218,6 +248,8 @@ public void shouldRejectWhenMaxEntitiesExceeded() {
);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -227,11 +259,14 @@ public void shouldRejectWhenMaxElementsExceeded() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(Buffer.buffer("<test><element1 /><element2 /><element3 /></test>"));
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

Expand All @@ -241,7 +276,8 @@ public void shouldRejectExternalEntities() {
cut = new XmlThreatProtectionPolicy(configuration);
ReadWriteStream<Buffer> readWriteStream = cut.onRequestContent(request, policyChain);

assertNotNull(readWriteStream);
assertThat(readWriteStream).isNotNull();
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

// Perform an injection of logback xml file.
String path = getClass().getResource("/logback-test.xml").getPath();
Expand All @@ -257,6 +293,23 @@ public void shouldRejectExternalEntities() {
);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

verify(policyChain, times(1)).streamFailWith(any(PolicyResult.class));
}

/**
* Replace the endHandler of the resulting ReadWriteStream of the policy execution.
* This endHandler will set an {@link AtomicBoolean} to {@code true} if its called.
* It will allow us to verify if super.end() has been called on {@link BufferedReadWriteStream#end()}
* @param readWriteStream: the {@link ReadWriteStream} to modify
* @return an AtomicBoolean set to {@code true} if {@link SimpleReadWriteStream#end()}, else {@code false}
*/
private AtomicBoolean spyEndHandler(ReadWriteStream readWriteStream) {
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = new AtomicBoolean(false);
readWriteStream.endHandler(__ -> {
hasCalledEndOnReadWriteStreamParentClass.set(true);
});
return hasCalledEndOnReadWriteStreamParentClass;
}
}

0 comments on commit 926e820

Please sign in to comment.