Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Fix CSV injection issue #447

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,80 @@

package com.amazon.opendistroforelasticsearch.sql.executor.csv;

import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Created by Eliran on 27/12/2015.
*/
public class CSVResult {

private static final Set<String> SENSITIVE_CHAR = ImmutableSet.of("=", "+", "-", "@");

private final List<String> headers;
private final List<String> lines;

/**
* Skip sanitizing if string line provided. This constructor is basically used by
* assertion in test code.
*/
public CSVResult(List<String> headers, List<String> lines) {
this.headers = headers;
this.lines = lines;
}

public CSVResult(String separator, List<String> headers, List<List<String>> lines) {
this.headers = sanitizeHeaders(headers);
this.lines = sanitizeLines(separator, lines);
}

/**
* Return CSV header names which are sanitized because Elasticsearch allows
* special character present in field name too.
* @return CSV header name list after sanitized
*/
public List<String> getHeaders() {
return headers;
}

/**
* Return CSV lines in which each cell is sanitized to avoid CSV injection.
* @return CSV lines after sanitized
*/
public List<String> getLines() {
return lines;
}

private List<String> sanitizeHeaders(List<String> headers) {
return headers.stream().
map(this::sanitizeCell).
collect(Collectors.toList());
}

private List<String> sanitizeLines(String separator, List<List<String>> lines) {
List<String> result = new ArrayList<>();
for (List<String> line : lines) {
result.add(line.stream().
map(this::sanitizeCell).
collect(Collectors.joining(separator)));
}
return result;
}

private String sanitizeCell(String cell) {
if (isStartWithSensitiveChar(cell)) {
return "'" + cell;
}
return cell;
}

private boolean isStartWithSensitiveChar(String cell) {
return SENSITIVE_CHAR.stream().
anyMatch(cell::startsWith);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.amazon.opendistroforelasticsearch.sql.expression.domain.BindingTuple;
import com.amazon.opendistroforelasticsearch.sql.expression.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.utils.Util;
import com.google.common.base.Joiner;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
Expand Down Expand Up @@ -66,44 +65,33 @@ public CSVResult extractResults(Object queryResult, boolean flat, String separat
SearchHit[] hits = ((SearchHits) queryResult).getHits();
List<Map<String, Object>> docsAsMap = new ArrayList<>();
List<String> headers = createHeadersAndFillDocsMap(flat, hits, docsAsMap, fieldNames);
List<String> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers);
return new CSVResult(headers, csvLines);
List<List<String>> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers);
return new CSVResult(separator, headers, csvLines);
}
if (queryResult instanceof Aggregations) {
List<String> headers = new ArrayList<>();
List<List<String>> lines = new ArrayList<>();
lines.add(new ArrayList<String>());
handleAggregations((Aggregations) queryResult, headers, lines);

List<String> csvLines = new ArrayList<>();
for (List<String> simpleLine : lines) {
csvLines.add(Joiner.on(separator).join(simpleLine));
}

//todo: need to handle more options for aggregations:
//Aggregations that inhrit from base
//ScriptedMetric

return new CSVResult(headers, csvLines);

return new CSVResult(separator, headers, lines);
}
// Handle List<BindingTuple> result.
if (queryResult instanceof List) {
List<BindingTuple> bindingTuples = (List<BindingTuple>) queryResult;
List<String> csvLines = bindingTuples.stream().map(tuple -> {
List<List<String>> csvLines = bindingTuples.stream().map(tuple -> {
Map<String, ExprValue> bindingMap = tuple.getBindingMap();
List<Object> rowValues = new ArrayList<>();
List<String> rowValues = new ArrayList<>();
for (String fieldName : fieldNames) {
if (bindingMap.containsKey(fieldName)) {
rowValues.add(bindingMap.get(fieldName).value());
rowValues.add(String.valueOf(bindingMap.get(fieldName).value()));
} else {
rowValues.add("");
}
}
return Joiner.on(separator).join(rowValues);
return rowValues;
}).collect(Collectors.toList());

return new CSVResult(fieldNames, csvLines);
return new CSVResult(separator, fieldNames, csvLines);
}
return null;
}
Expand Down Expand Up @@ -283,15 +271,16 @@ private Aggregation getFirstAggregation(Aggregations aggregations) {
return aggregations.asList().get(0);
}

private List<String> createCSVLinesFromDocs(boolean flat, String separator, List<Map<String, Object>> docsAsMap,
List<String> headers) {
List<String> csvLines = new ArrayList<>();
private List<List<String>> createCSVLinesFromDocs(boolean flat, String separator,
List<Map<String, Object>> docsAsMap,
List<String> headers) {
List<List<String>> csvLines = new ArrayList<>();
for (Map<String, Object> doc : docsAsMap) {
String line = "";
List<String> line = new ArrayList<>();
for (String header : headers) {
line += findFieldValue(header, doc, flat, separator);
line.add(findFieldValue(header, doc, flat, separator));
}
csvLines.add(line.substring(0, line.lastIndexOf(separator)));
csvLines.add(line);
}
return csvLines;
}
Expand Down Expand Up @@ -335,11 +324,11 @@ private String findFieldValue(String header, Map<String, Object> doc, boolean fl

for (String innerField : split) {
if (!(innerDoc instanceof Map)) {
return separator;
return "";
}
innerDoc = ((Map<String, Object>) innerDoc).get(innerField);
if (innerDoc == null) {
return separator;
return "";
}
}
return quoteValueIfRequired(innerDoc.toString(), separator);
Expand All @@ -348,14 +337,14 @@ private String findFieldValue(String header, Map<String, Object> doc, boolean fl
return quoteValueIfRequired(String.valueOf(doc.get(header)), separator);
}
}
return separator;
return "";
}

private String quoteValueIfRequired(final String input, final String separator) {
final String quote = "\"";

return input.contains(separator)
? quote + input.replaceAll("\"", "\"\"") + quote + separator : input + separator;
? quote + input.replaceAll("\"", "\"\"") + quote : input;
}

private void mergeHeaders(Set<String> headers, Map<String, Object> doc, boolean flat) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,30 @@ public void includeIdAndTypeButNoScore() throws Exception {
}
//endregion Tests migrated from CSVResultsExtractorTests

@Test
public void sensitiveCharacterSanitizeTest() throws IOException {
String requestBody =
"{" +
" \"=cmd|' /C notepad'!_xlbgnm.A1\": \"+cmd|' /C notepad'!_xlbgnm.A1\",\n" +
" \"-cmd|' /C notepad'!_xlbgnm.A1\": \"@cmd|' /C notepad'!_xlbgnm.A1\"\n" +
"}";

Request request = new Request("PUT", "/userdata/_doc/1?refresh=true");
request.setJsonEntity(requestBody);
TestUtils.performRequest(client(), request);

CSVResult csvResult = executeCsvRequest("SELECT * FROM userdata", false, false, false, false);
List<String> headers = csvResult.getHeaders();
Assert.assertEquals(2, headers.size());
Assert.assertTrue(headers.contains("'=cmd|' /C notepad'!_xlbgnm.A1"));
Assert.assertTrue(headers.contains("'-cmd|' /C notepad'!_xlbgnm.A1"));

List<String> lines = csvResult.getLines();
Assert.assertEquals(1, lines.size());
Assert.assertTrue(lines.get(0).contains("'+cmd|' /C notepad'!_xlbgnm.A1"));
Assert.assertTrue(lines.get(0).contains("'@cmd|' /C notepad'!_xlbgnm.A1"));
}

private void verifyFieldOrder(final String[] expectedFields) throws IOException {

final String fields = String.join(", ", expectedFields);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static Response performRequest(RestClient client, Request request) {
try {
Response response = client.performRequest(request);
int status = response.getStatusLine().getStatusCode();
if (status != 200) {
if (status >= 400) {
throw new IllegalStateException("Failed to perform request. Error code: " + status);
}
return response;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.sql.executor.csv;

import org.junit.Test;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;

/**
* Unit tests for {@link CSVResult}
*/
public class CSVResultTest {

private static final String SEPARATOR = ",";

@Test
public void getHeadersShouldReturnHeadersSanitized() {
CSVResult csv = csv(headers("name", "=age"), lines(line("John", "30")));
assertEquals(
headers("name", "'=age"),
csv.getHeaders()
);
}

@Test
public void getLinesShouldReturnLinesSanitized() {
CSVResult csv = csv(
headers("name", "city"),
lines(
line("John", "Seattle"),
line("John", "=Seattle"),
line("John", "+Seattle"),
line("-John", "Seattle"),
line("@John", "Seattle"),
line("John", "Seattle=")
)
);

assertEquals(
line(
"John,Seattle",
"John,'=Seattle",
"John,'+Seattle",
"'-John,Seattle",
"'@John,Seattle",
"John,Seattle="
),
csv.getLines()
);
}

private CSVResult csv(List<String> headers, List<List<String>> lines) {
return new CSVResult(SEPARATOR, headers, lines);
}

private List<String> headers(String... headers) {
return Arrays.stream(headers).collect(Collectors.toList());
}

private List<String> line(String... line) {
return Arrays.stream(line).collect(Collectors.toList());
}

@SafeVarargs
private final List<List<String>> lines(List<String>... lines) {
return Arrays.stream(lines).collect(Collectors.toList());
}

}