Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

snippet inject support like <head lang="en"> #8736

Merged
merged 12 commits into from
Aug 9, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ public boolean isContentTypeTextHtml() {
if (contentType == null) {
contentType = super.getHeader("content-type");
}
return contentType != null && contentType.startsWith("text/html");
return contentType != null
&& (contentType.startsWith("text/html") || "application/xhtml+xml".equals(contentType));
oliver-zhang marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,32 @@ abstract class AbstractServlet3Test<SERVER, CONTEXT> extends HttpServerTest<SERV
+ "<p>test works</p>\n"
+ "</body>\n"
+ "</html>")
public static final ServerEndpoint HTML_PRINT_WRITER_WITH_OTHER_HEAD =
new ServerEndpoint("HTML_PRINT_WRITER_WITH_OTHER_HEAD", "htmlPrintWriterWithOtherHead",
200,
"<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"
+ "<head lang=\"en\">\n"
+ " <meta charset=\"UTF-8\">\n"
+ " <title>Title</title>\n"
+ "</head>\n"
+ "<body>\n"
+ "<p>test works</p>\n"
+ "</body>\n"
+ "</html>")
public static final ServerEndpoint HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD =
new ServerEndpoint("HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD", "htmlServletOutputStreamWithOtherHead",
200,
"<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"
+ "<head lang=\"en\">\n"
+ " <meta charset=\"UTF-8\">\n"
+ " <title>Title</title>\n"
+ "</head>\n"
+ "<body>\n"
+ "<p>test works</p>\n"
+ "</body>\n"
+ "</html>")
protected void setupServlets(CONTEXT context) {
def servlet = servlet()

Expand All @@ -80,6 +106,8 @@ abstract class AbstractServlet3Test<SERVER, CONTEXT> extends HttpServerTest<SERV
addServlet(context, CAPTURE_PARAMETERS.path, servlet)
addServlet(context, HTML_PRINT_WRITER.path, servlet)
addServlet(context, HTML_SERVLET_OUTPUT_STREAM.path, servlet)
addServlet(context, HTML_PRINT_WRITER_WITH_OTHER_HEAD.path,servlet)
addServlet(context, HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD.path,servlet)
}

protected ServerEndpoint lastRequest
Expand Down Expand Up @@ -215,4 +243,89 @@ abstract class AbstractServlet3Test<SERVER, CONTEXT> extends HttpServerTest<SERV
}
}
}

def "snippet injection with PrintWriterWithOtherHead"() {
setup:
ExperimentalSnippetHolder.setSnippet("\n <script type=\"text/javascript\"> Test </script>")
def request = request(HTML_PRINT_WRITER_WITH_OTHER_HEAD, "GET")
def response = client.execute(request).aggregate().join()

expect:
response.status().code() == HTML_PRINT_WRITER_WITH_OTHER_HEAD.status
String result = "<!DOCTYPE html>\n" +
"<html lang=\"en\">\n" +
"<head lang=\"en\">\n" +
" <script type=\"text/javascript\"> Test </script>\n" +
" <meta charset=\"UTF-8\">\n" +
" <title>Title</title>\n" +
"</head>\n" +
"<body>\n" +
"<p>test works</p>\n" +
"</body>\n" +
"</html>"
println "result from call ${response.contentUtf8()}"
response.contentUtf8() == result
response.headers().contentLength() == result.length()

cleanup:
ExperimentalSnippetHolder.setSnippet("")

def expectedRoute = expectedHttpRoute(HTML_PRINT_WRITER_WITH_OTHER_HEAD)
assertTraces(1) {
trace(0, 2) {
span(0) {
name "GET" + (expectedRoute != null ? " " + expectedRoute : "")
kind SpanKind.SERVER
hasNoParent()
}
span(1) {
name "controller"
kind SpanKind.INTERNAL
childOf span(0)
}
}
}
}

def "snippet injection with ServletOutputStreamWithOtherHead"() {
setup:
ExperimentalSnippetHolder.setSnippet("\n <script type=\"text/javascript\"> Test Test</script>")
def request = request(HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD, "GET")
def response = client.execute(request).aggregate().join()

expect:
response.status().code() == HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD.status
String result = "<!DOCTYPE html>\n" +
"<html lang=\"en\">\n" +
"<head lang=\"en\">\n" +
" <script type=\"text/javascript\"> Test Test</script>\n" +
" <meta charset=\"UTF-8\">\n" +
" <title>Title</title>\n" +
"</head>\n" +
"<body>\n" +
"<p>test works</p>\n" +
"</body>\n" +
"</html>"
response.contentUtf8() == result
response.headers().contentLength() == result.length()

cleanup:
ExperimentalSnippetHolder.setSnippet("")

def expectedRoute = expectedHttpRoute(HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD)
assertTraces(1) {
trace(0, 2) {
span(0) {
name "GET" + (expectedRoute != null ? " " + expectedRoute : "")
kind SpanKind.SERVER
hasNoParent()
}
span(1) {
name "controller"
kind SpanKind.INTERNAL
childOf span(0)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ class JettyServlet3TestDispatchAsync extends JettyDispatchTest {
protected void setupServlets(ServletContextHandler context) {
super.setupServlets(context)
addServlet(context, "/dispatch" + HTML_PRINT_WRITER.path, TestServlet3.DispatchAsync)
addServlet(context, "/dispatch" + HTML_PRINT_WRITER_WITH_OTHER_HEAD.path, TestServlet3.DispatchAsync)
addServlet(context, "/dispatch" + HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD.path, TestServlet3.DispatchAsync)
addServlet(context, "/dispatch" + HTML_SERVLET_OUTPUT_STREAM.path, TestServlet3.DispatchAsync)
addServlet(context, "/dispatch" + SUCCESS.path, TestServlet3.DispatchAsync)
addServlet(context, "/dispatch" + QUERY_PARAM.path, TestServlet3.DispatchAsync)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ class TestServlet3 {
byte[] body = endpoint.body.getBytes()
resp.getOutputStream().write(body, 0, body.length)
break
case AbstractServlet3Test.HTML_PRINT_WRITER_WITH_OTHER_HEAD:
resp.contentType = "text/html"
resp.status = endpoint.status
resp.setContentLengthLong(endpoint.body.length())
resp.writer.print(endpoint.body)
break
case AbstractServlet3Test.HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD:
resp.contentType = "text/html"
resp.status = endpoint.status
resp.setContentLength(endpoint.body.length())
byte[] body = endpoint.body.getBytes()
resp.getOutputStream().write(body, 0, body.length)
break
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ class TomcatServlet3TestForward extends TomcatDispatchTest {
addServlet(context, "/dispatch" + INDEXED_CHILD.path, RequestDispatcherServlet.Forward)
addServlet(context, "/dispatch" + HTML_PRINT_WRITER.path, RequestDispatcherServlet.Forward)
addServlet(context, "/dispatch" + HTML_SERVLET_OUTPUT_STREAM.path, RequestDispatcherServlet.Forward)
addServlet(context, "/dispatch" + HTML_SERVLET_OUTPUT_STREAM_WITH_OTHER_HEAD.path, RequestDispatcherServlet.Forward)
addServlet(context, "/dispatch" + HTML_PRINT_WRITER_WITH_OTHER_HEAD.path, RequestDispatcherServlet.Forward)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public boolean isContentTypeTextHtml() {
if (contentType == null) {
contentType = super.getHeader("content-type");
}
return contentType != null && contentType.startsWith("text/html");
return contentType != null
&& (contentType.startsWith("text/html") || "application/xhtml+xml".equals(contentType));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// this is shared by both ServletOutputStream and PrintWriter injection
public class InjectionState {
private static final int HEAD_TAG_WRITTEN_FAKE_VALUE = -1;
private static final int HEAD_TAG_LENGTH = "<head>".length();
private static final int HEAD_TAG_PREFIX_LENGTH = "<head".length();
private final SnippetInjectingResponseWrapper wrapper;
private int headTagBytesSeen = 0;

Expand Down Expand Up @@ -40,12 +40,12 @@ public boolean processByte(int b) {
if (isHeadTagWritten()) {
return false;
}
if (inHeadTag(b)) {
if (inHeadTag(b) && headTagBytesSeen < HEAD_TAG_PREFIX_LENGTH) {
oliver-zhang marked this conversation as resolved.
Show resolved Hide resolved
headTagBytesSeen++;
} else {
} else if (headTagBytesSeen != HEAD_TAG_PREFIX_LENGTH) {
headTagBytesSeen = 0;
}
if (headTagBytesSeen == HEAD_TAG_LENGTH) {
if (headTagBytesSeen == HEAD_TAG_PREFIX_LENGTH && b == '>') {
setHeadTagWritten();
return true;
} else {
Expand All @@ -64,10 +64,9 @@ private boolean inHeadTag(int b) {
return true;
} else if (headTagBytesSeen == 4 && b == 'd') {
return true;
} else if (headTagBytesSeen == 5 && b == '>') {
return true;
} else {
return headTagBytesSeen == HEAD_TAG_PREFIX_LENGTH;
}
return false;
}

public SnippetInjectingResponseWrapper getWrapper() {
Expand Down