Skip to content

Commit

Permalink
Fix problems in GeoIPv2 code (#71598)
Browse files Browse the repository at this point in the history
This change fixes number of problems in GeoIPv2 code:

- closes streams from Files.list in GeoIpCli, which should fix tests on Windows
- makes sure that total download time in GeoIP stats is non-negative (we serialize it as vInt which can cause problems with negative numbers and it can happen when clock was changed during operation)
- fixes handling of failed/simultaneous downloads, #69951 was meant as a way to prevent 2 persistent tasks to index chunks but it would prevent any update if single download failed mid indexing, this change uses timestamp (lastUpdate) as sort of UUID. This should still prevent 2 tasks to step on each other toes (overwriting chunks) but in the end still only single task should be able to update task state (this is handled by persistent tasks framework)
Closes #71145
  • Loading branch information
probakowski authored Apr 13, 2021
1 parent c436458 commit 46efa6a
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;

import static java.nio.file.StandardOpenOption.CREATE;
Expand All @@ -42,9 +42,8 @@ public class GeoIpCli extends Command {

private static final byte[] EMPTY_BUF = new byte[512];

// visible for testing
final OptionSpec<String> sourceDirectory;
final OptionSpec<String> targetDirectory;
private final OptionSpec<String> sourceDirectory;
private final OptionSpec<String> targetDirectory;

public GeoIpCli() {
super("A CLI tool to prepare local GeoIp database service", () -> {});
Expand All @@ -58,7 +57,7 @@ protected void execute(Terminal terminal, OptionSet options) throws Exception {
Path source = getPath(options.valueOf(sourceDirectory));
String targetString = options.valueOf(targetDirectory);
Path target = targetString != null ? getPath(targetString) : source;
copyTgzToTarget(terminal, source, target);
copyTgzToTarget(source, target);
packDatabasesToTgz(terminal, source, target);
createOverviewJson(terminal, target);
}
Expand All @@ -68,49 +67,49 @@ private Path getPath(String file) {
return PathUtils.get(file);
}

private void copyTgzToTarget(Terminal terminal, Path source, Path target) throws IOException {
private void copyTgzToTarget(Path source, Path target) throws IOException {
if (source.equals(target)) {
return;
}
List<Path> toCopy = Files.list(source).filter(p -> p.getFileName().toString().endsWith(".tgz")).collect(Collectors.toList());
for (Path path : toCopy) {
Files.copy(path, target.resolve(path.getFileName()), StandardCopyOption.REPLACE_EXISTING);
try (Stream<Path> files = Files.list(source)) {
for (Path path : files.filter(p -> p.getFileName().toString().endsWith(".tgz")).collect(Collectors.toList())) {
Files.copy(path, target.resolve(path.getFileName()), StandardCopyOption.REPLACE_EXISTING);
}
}
}

private void packDatabasesToTgz(Terminal terminal, Path source, Path target) throws IOException {
List<Path> toPack = Files.list(source).filter(p -> p.getFileName().toString().endsWith(".mmdb")).collect(Collectors.toList());
for (Path path : toPack) {
String fileName = path.getFileName().toString();
Path compressedPath = target.resolve(fileName.replaceAll("mmdb$", "") + "tgz");
terminal.println("Found " + fileName + ", will compress it to " + compressedPath.getFileName());
try (
OutputStream fos = Files.newOutputStream(compressedPath, TRUNCATE_EXISTING, CREATE);
OutputStream gos = new GZIPOutputStream(new BufferedOutputStream(fos))
) {
long size = Files.size(path);
gos.write(createTarHeader(fileName, size));
Files.copy(path, gos);
if (size % 512 != 0) {
gos.write(EMPTY_BUF, 0, (int) (512 - (size % 512)));
try (Stream<Path> files = Files.list(source)) {
for (Path path : files.filter(p -> p.getFileName().toString().endsWith(".mmdb")).collect(Collectors.toList())) {
String fileName = path.getFileName().toString();
Path compressedPath = target.resolve(fileName.replaceAll("mmdb$", "") + "tgz");
terminal.println("Found " + fileName + ", will compress it to " + compressedPath.getFileName());
try (
OutputStream fos = Files.newOutputStream(compressedPath, TRUNCATE_EXISTING, CREATE);
OutputStream gos = new GZIPOutputStream(new BufferedOutputStream(fos))
) {
long size = Files.size(path);
gos.write(createTarHeader(fileName, size));
Files.copy(path, gos);
if (size % 512 != 0) {
gos.write(EMPTY_BUF, 0, (int) (512 - (size % 512)));
}
gos.write(EMPTY_BUF);
gos.write(EMPTY_BUF);
}
gos.write(EMPTY_BUF);
gos.write(EMPTY_BUF);
}
}
}

private void createOverviewJson(Terminal terminal, Path directory) throws IOException {
List<Path> databasesPaths = Files.list(directory)
.filter(p -> p.getFileName().toString().endsWith(".tgz"))
.collect(Collectors.toList());
Path overview = directory.resolve("overview.json");
try (
Stream<Path> files = Files.list(directory);
OutputStream os = new BufferedOutputStream(Files.newOutputStream(overview, TRUNCATE_EXISTING, CREATE));
XContentGenerator generator = XContentType.JSON.xContent().createGenerator(os)
) {
generator.writeStartArray();
for (Path db : databasesPaths) {
for (Path db : files.filter(p -> p.getFileName().toString().endsWith(".tgz")).collect(Collectors.toList())) {
terminal.println("Adding " + db.getFileName() + " to overview.json");
MessageDigest md5 = MessageDigests.md5();
try (InputStream dis = new DigestInputStream(new BufferedInputStream(Files.newInputStream(db)), md5)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasKey;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/71145")
@LuceneTestCase.SuppressFileSystems(value = "ExtrasFS") // Don't randomly add 'extra' files to directory.
public class GeoIpCliTests extends LuceneTestCase {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ void retrieveDatabase(String databaseName,
// (the chance that the documents change is rare, given the low frequency of the updates for these databases)
for (int chunk = firstChunk; chunk <= lastChunk; chunk++) {
SearchRequest searchRequest = new SearchRequest(GeoIpDownloader.DATABASES_INDEX);
String id = String.format(Locale.ROOT, "%s_%d", databaseName, chunk);
String id = String.format(Locale.ROOT, "%s_%d_%d", databaseName, chunk, metadata.getLastUpdate());
searchRequest.source().query(new TermQueryBuilder("_id", id));

// At most once a day a few searches may be executed to fetch the new files,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
private volatile GeoIpDownloaderStats stats = GeoIpDownloaderStats.EMPTY;

GeoIpDownloader(Client client, HttpClient httpClient, ClusterService clusterService, ThreadPool threadPool, Settings settings,
long id, String type, String action, String description, TaskId parentTask,
Map<String, String> headers) {
long id, String type, String action, String description, TaskId parentTask, Map<String, String> headers) {
super(id, type, action, description, parentTask, headers);
this.httpClient = httpClient;
this.client = new OriginSettingClient(client, IngestService.INGEST_ORIGIN);
Expand Down Expand Up @@ -139,9 +138,9 @@ void processDatabase(Map<String, Object> databaseInfo) {
long start = System.currentTimeMillis();
try (InputStream is = httpClient.get(url)) {
int firstChunk = state.contains(name) ? state.get(name).getLastChunk() + 1 : 0;
int lastChunk = indexChunks(name, is, firstChunk, md5);
int lastChunk = indexChunks(name, is, firstChunk, md5, start);
if (lastChunk > firstChunk) {
state = state.put(name, new Metadata(System.currentTimeMillis(), firstChunk, lastChunk - 1, md5));
state = state.put(name, new Metadata(start, firstChunk, lastChunk - 1, md5));
updateTaskState();
stats = stats.successfulDownload(System.currentTimeMillis() - start).count(state.getDatabases().size());
logger.info("updated geoip database [" + name + "]");
Expand Down Expand Up @@ -180,11 +179,11 @@ void updateTaskState() {
}

//visible for testing
int indexChunks(String name, InputStream is, int chunk, String expectedMd5) throws IOException {
int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long timestamp) throws IOException {
MessageDigest md = MessageDigests.md5();
for (byte[] buf = getChunk(is); buf.length != 0; buf = getChunk(is)) {
md.update(buf);
client.prepareIndex(DATABASES_INDEX).setId(name + "_" + chunk)
client.prepareIndex(DATABASES_INDEX).setId(name + "_" + chunk + "_" + timestamp)
.setCreate(true)
.setSource(XContentType.SMILE, "name", name, "chunk", chunk, "data", buf)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ public GeoIpDownloaderStats skippedDownload() {
}

public GeoIpDownloaderStats successfulDownload(long downloadTime) {
return new GeoIpDownloaderStats(successfulDownloads + 1, failedDownloads, totalDownloadTime + downloadTime, databasesCount,
skippedDownloads);
return new GeoIpDownloaderStats(successfulDownloads + 1, failedDownloads, totalDownloadTime + Math.max(downloadTime, 0),
databasesCount, skippedDownloads);
}

public GeoIpDownloaderStats failedDownload() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand All @@ -82,7 +83,6 @@
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -122,7 +122,7 @@ public void cleanup() {
resourceWatcherService.close();
threadPool.shutdownNow();
}

public void testCheckDatabases() throws Exception {
String md5 = mockSearches("GeoIP2-City.mmdb", 5, 14);
String taskId = GeoIpDownloader.GEOIP_DOWNLOADER;
Expand Down Expand Up @@ -258,6 +258,7 @@ private String mockSearches(String databaseName, int firstChunk, int lastChunk)
List<byte[]> data = gzip(databaseName, dummyContent, lastChunk - firstChunk + 1);
assertThat(gunzip(data), equalTo(dummyContent));

Map<String, ActionFuture<SearchResponse>> requestMap = new HashMap<>();
for (int i = firstChunk; i <= lastChunk; i++) {
byte[] chunk = data.get(i - firstChunk);
SearchHit hit = new SearchHit(i);
Expand All @@ -270,17 +271,20 @@ private String mockSearches(String databaseName, int firstChunk, int lastChunk)
throw new UncheckedIOException(ex);
}

SearchHits hits = new SearchHits(new SearchHit[] {hit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f);
SearchHits hits = new SearchHits(new SearchHit[]{hit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f);
SearchResponse searchResponse =
new SearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 0), null, 1, 1, 0, 1L, null, null);
@SuppressWarnings("unchecked")
ActionFuture<SearchResponse> actionFuture = mock(ActionFuture.class);
when(actionFuture.actionGet()).thenReturn(searchResponse);
SearchRequest expectedSearchRequest = new SearchRequest(GeoIpDownloader.DATABASES_INDEX);
String id = String.format(Locale.ROOT, "%s_%d", databaseName, i);
expectedSearchRequest.source().query(new TermQueryBuilder("_id", id));
when(client.search(eq(expectedSearchRequest))).thenReturn(actionFuture);
requestMap.put(databaseName + "_" + i, actionFuture);
}
when(client.search(any())).thenAnswer(invocationOnMock -> {
SearchRequest req = (SearchRequest) invocationOnMock.getArguments()[0];
TermQueryBuilder term = (TermQueryBuilder) req.source().query();
String id = (String) term.value();
return requestMap.get(id.substring(0, id.lastIndexOf('_')));
});

MessageDigest md = MessageDigests.md5();
data.forEach(md::update);
Expand Down Expand Up @@ -322,7 +326,7 @@ private static List<byte[]> gzip(String name, String content, int chunks) throws
int chunkSize = all.length / chunks;
List<byte[]> data = new ArrayList<>();

for (int from = 0; from < all.length;) {
for (int from = 0; from < all.length; ) {
int to = from + chunkSize;
if (to > all.length) {
to = all.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -139,12 +140,13 @@ public int read() throws IOException {
}

public void testIndexChunksNoData() throws IOException {
assertEquals(0, geoIpDownloader.indexChunks("test", new ByteArrayInputStream(new byte[0]), 0, "d41d8cd98f00b204e9800998ecf8427e"));
InputStream empty = new ByteArrayInputStream(new byte[0]);
assertEquals(0, geoIpDownloader.indexChunks("test", empty, 0, "d41d8cd98f00b204e9800998ecf8427e", 0));
}

public void testIndexChunksMd5Mismatch() {
IOException exception = expectThrows(IOException.class, () -> geoIpDownloader.indexChunks("test",
new ByteArrayInputStream(new byte[0]), 0, "123123"));
new ByteArrayInputStream(new byte[0]), 0, "123123", 0));
assertEquals("md5 checksum mismatch, expected [123123], actual [d41d8cd98f00b204e9800998ecf8427e]", exception.getMessage());
}

Expand All @@ -164,7 +166,7 @@ public void testIndexChunks() throws IOException {
client.addHandler(IndexAction.INSTANCE, (IndexRequest request, ActionListener<IndexResponse> listener) -> {
int chunk = chunkIndex.getAndIncrement();
assertEquals(OpType.CREATE, request.opType());
assertEquals("test_" + (chunk + 15), request.id());
assertThat(request.id(), Matchers.startsWith("test_" + (chunk + 15) + "_"));
assertEquals(XContentType.SMILE, request.getContentType());
Map<String, Object> source = request.sourceAsMap();
assertEquals("test", source.get("name"));
Expand All @@ -173,7 +175,8 @@ public void testIndexChunks() throws IOException {
listener.onResponse(mock(IndexResponse.class));
});

assertEquals(17, geoIpDownloader.indexChunks("test", new ByteArrayInputStream(bigArray), 15, "a67563dfa8f3cba8b8cff61eb989a749"));
InputStream big = new ByteArrayInputStream(bigArray);
assertEquals(17, geoIpDownloader.indexChunks("test", big, 15, "a67563dfa8f3cba8b8cff61eb989a749", 0));

assertEquals(2, chunkIndex.get());
}
Expand All @@ -191,7 +194,7 @@ void updateTaskState() {
}

@Override
int indexChunks(String name, InputStream is, int chunk, String expectedMd5) {
int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long start) {
assertSame(bais, is);
assertEquals(0, chunk);
return 11;
Expand Down Expand Up @@ -226,7 +229,7 @@ void updateTaskState() {
}

@Override
int indexChunks(String name, InputStream is, int chunk, String expectedMd5) {
int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long start) {
assertSame(bais, is);
assertEquals(9, chunk);
return 11;
Expand Down Expand Up @@ -263,7 +266,7 @@ void updateTaskState() {
}

@Override
int indexChunks(String name, InputStream is, int chunk, String expectedMd5) {
int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long start) {
fail();
return 0;
}
Expand Down

0 comments on commit 46efa6a

Please sign in to comment.