Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the overhead of IndexInput#prefetch when data is cached in RAM. #13381

Merged
merged 3 commits into from
May 21, 2024

Conversation

jpountz
Copy link
Contributor

@jpountz jpountz commented May 17, 2024

As Robert pointed out and benchmarks confirmed, there is some (small) overhead to calling madvise via the foreign function API, benchmarks suggest it is in the order of 1-2us. This is not much for a single call, but may become non-negligible across many calls. Until now, we only looked into using prefetch() for terms, skip data and postings start pointers which are a single prefetch() operation per segment per term.

But we may want to start using it in cases that could result into more calls to madvise, e.g. if we start using it for stored fields and a user requests 10k documents. In #13337, Robert wondered if we could take advantage of mincore() to reduce the overhead of IndexInput#prefetch(), which is what this PR is doing.

For now, this is trying to not add new APIs. Instead, IndexInput#prefetch tracks consecutive hits on the page cache and calls madvise less and less frequently under the hood as the number of cache hits increases.

As Robert pointed out and benchmarks confirmed, there is some (small) overhead
to calling `madvise` via the foreign function API, benchmarks suggest it is in
the order of 1-2us. This is not much for a single call, but may become
non-negligible across many calls. Until now, we only looked into using
prefetch() for terms, skip data and postings start pointers which are a single
prefetch() operation per segment per term.

But we may want to start using it in cases that could result into more calls to
`madvise`, e.g. if we start using it for stored fields and a user requests 10k
documents. In apache#13337, Robert wondered if we could take advantage of `mincore()`
to reduce the overhead of `IndexInput#prefetch()`, which is what this PR is
doing.

For now, this is trying to not add new APIs. Instead, `IndexInput#prefetch`
tracks consecutive hits on the page cache and calls `madvise` less and less
frequently under the hood as the number of cache hits increases.
@jpountz jpountz requested review from rmuir and uschindler May 17, 2024 20:10
@jpountz
Copy link
Contributor Author

jpountz commented May 17, 2024

I slightly modified the benchmark from #13337
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;

public class PrefetchBench {

  private static final int NUM_TERMS = 3;
  private static final long FILE_SIZE = 100L * 1024 * 1024 * 1024; // 100GB
  private static final int NUM_BYTES = 16;
  public static int DUMMY;

  public static void main(String[] args) throws IOException {
    Path filePath = Paths.get(args[0]);
    Path dirPath = filePath.getParent();
    String fileName = filePath.getFileName().toString();
    Random r = ThreadLocalRandom.current();

    try (Directory dir = new MMapDirectory(dirPath)) {
      if (Arrays.asList(dir.listAll()).contains(fileName) == false) {
        try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
          byte[] buf = new byte[8196];
          for (long i = 0; i < FILE_SIZE; i += buf.length) {
            r.nextBytes(buf);
            out.writeBytes(buf, buf.length);
          }
        }
      }

      for (boolean dataFitsInCache : new boolean[] { false, true}) {
        try (IndexInput i0 = dir.openInput("file", IOContext.DEFAULT)) {
          byte[][] b = new byte[NUM_TERMS][];
          for (int i = 0; i < NUM_TERMS; ++i) {
            b[i] = new byte[NUM_BYTES];
          }
          IndexInput[] inputs = new IndexInput[NUM_TERMS];
          if (dataFitsInCache) {
            // 16MB slice that should easily fit in the page cache
            inputs[0] = i0.slice("slice", 0, 16 * 1024 * 1024);
          } else {
            inputs[0] = i0;
          }
          for (int i = 1; i < NUM_TERMS; ++i) {
            inputs[i] = inputs[0].clone();
          }
          final long length = inputs[0].length();
          List<Long>[] latencies = new List[2];
          latencies[0] = new ArrayList<>();
          latencies[1] = new ArrayList<>();
          for (int iter = 0; iter < 100_000; ++iter) {
            final boolean prefetch = (iter & 1) == 0;

            final long start = System.nanoTime();

            for (IndexInput ii : inputs) {
              final long offset = r.nextLong(length - NUM_BYTES);
              ii.seek(offset);
              if (prefetch) {
                ii.prefetch(offset, 1);
              }
            }

            for (int i = 0; i < NUM_TERMS; ++i) {
              inputs[i].readBytes(b[i], 0, b[i].length);
            }

            final long end = System.nanoTime();

            // Prevent the JVM from optimizing away the reads
            DUMMY = Arrays.stream(b).mapToInt(Arrays::hashCode).sum();

            latencies[iter & 1].add(end - start);
          }

          latencies[0].sort(null);
          latencies[1].sort(null);

          System.out.println("Data " + (dataFitsInCache ? "fits" : "does not fit") + " in the page cache");
          long prefetchP50 = latencies[0].get(latencies[0].size() / 2);
          long prefetchP90 = latencies[0].get(latencies[0].size() * 9 / 10);
          long prefetchP99 = latencies[0].get(latencies[0].size() * 99 / 100);
          long noPrefetchP50 = latencies[1].get(latencies[1].size() / 2);
          long noPrefetchP90 = latencies[1].get(latencies[1].size() * 9 / 10);
          long noPrefetchP99 = latencies[1].get(latencies[1].size() * 99 / 100);

          System.out.println("  With prefetching:    P50=" + prefetchP50 + "ns P90=" + prefetchP90 + "ns P99=" + prefetchP99 + "ns");
          System.out.println("  Without prefetching: P50=" + noPrefetchP50 + "ns P90=" + noPrefetchP90 + "ns P99=" + noPrefetchP99 + "ns");
        }
      }
    }
  }

}

It gives the following results. Before the change:

Data does not fit in the page cache
  With prefetching:    P50=88080ns P90=122970ns P99=157420ns
  Without prefetching: P50=224040ns P90=242320ns P99=297470ns
Data fits in the page cache
  With prefetching:    P50=880ns P90=1060ns P99=1370ns
  Without prefetching: P50=190ns P90=280ns P99=580ns

After the change:

Data does not fit in the page cache
  With prefetching:    P50=89710ns P90=124780ns P99=159400ns
  Without prefetching: P50=224271ns P90=242940ns P99=297371ns
Data fits in the page cache
  With prefetching:    P50=210ns P90=300ns P99=630ns
  Without prefetching: P50=200ns P90=290ns P99=580ns

Copy link
Member

@rmuir rmuir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numbers look great. I like the simple solution here to lower the overhead for when things fit in RAM. Let's try MemorySegment.isLoaded() and if performance is similar, we can avoid maintaining our own native mincore plumbing.

}
return true;
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be replaced with MemorySegment.isLoaded() which does the exact same thing in the openJDK via C code?

// on the next power of two of the counter.
return;
}

if (NATIVE_ACCESS.isEmpty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move the native access check to the top.

@@ -344,7 +354,11 @@ public void prefetch(long offset, long length) throws IOException {
}

final MemorySegment prefetchSlice = segment.asSlice(offset, length);
nativeAccess.madviseWillNeed(prefetchSlice);
if (nativeAccess.mincore(prefetchSlice) == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we look into replacing our native code with a call to MemorySegment#load() in a virtual thread?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we look into replacing our native code with a call to MemorySegment#load() in a virtual thread?

Let's keep it with pure madvise. A virtual thread is not a good idea because the "touch every page" code inside the JVM is not suitable for a virtual thread as it is cpu bound.

@@ -32,6 +32,9 @@ abstract class NativeAccess {
*/
public abstract void madviseWillNeed(MemorySegment segment) throws IOException;

/** Returns {@code true} if pages from the given {@link MemorySegment} are resident in RAM. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert

@@ -17,6 +17,7 @@
package org.apache.lucene.store;

import java.io.IOException;
import java.lang.foreign.Arena;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert

@uschindler
Copy link
Contributor

uschindler commented May 18, 2024

P.S.: Actually when looking at the code, the MemorySegment#load() method calls madvise(MADV_WILLNEED). So we could also implement prefetch using load(). You can follow that through the same chain of classes/call like in my previous review comment.

The only problem with that is: After doing the madvise, it touches a byte in each page to actually trigger the load synchronously. So we have to stay with our direct native call here.

@rmuir
Copy link
Member

rmuir commented May 19, 2024

somewhat related: i was playing around with the new cachestat syscall (it isn't relevant to us here directly, takes fd, etc), but the background did bring up the opposite concern of this PR:

such an application can learn whether the pages it is prefetching into the cache are still there by the time it gets around to using them. If those pages are being evicted, the prefetching is overloading the page cache and causing more work overall; in such situations, the application can back off and get better performance.

https://lwn.net/Articles/917096/

You can play around with it easily on linux 6.x from the commandline:

$ fincore --output-all myindexdir/*

@rmuir
Copy link
Member

rmuir commented May 19, 2024

Maybe if we didn't close the fd in mmapdir we could eventually think about making use of this on modern linux. it doesn't have a glibc wrapper yet... here is minimal sample code, but maybe just look at fincore for a more functional example: https://github.com/util-linux/util-linux/blob/master/misc-utils/fincore.c

#include <sys/syscall.h>
#include <linux/mman.h>
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

int
cachestat(int fd, struct cachestat_range *range, struct cachestat *stats, int flags) {
  return syscall(SYS_cachestat, fd, range, stats, flags);
}

int main(int argc, char **argv) {
  int fd;

  if (argc != 2) {
    printf("usage: %s <file>\n", argv[0]);
    return 2;
  }

  if ((fd = open(argv[1], O_RDONLY)) < 0) {
    perror("couldn't open");
    return 1;
  }

  struct cachestat_range range = { 0, 0 };
  struct cachestat cstats;
  if (cachestat(fd, &range, &cstats, 0) != 0) {
    perror("couldn't cachestat");
    return 1;
  }

  printf("cached: %llu\ndirty: %llu\nwriteback: %llu\nevicted: %llu\nrecently_evicted: %llu\n",
      cstats.nr_cache, cstats.nr_dirty, cstats.nr_writeback, cstats.nr_evicted, cstats.nr_recently_evicted);

  return 0;
}

Copy link
Contributor

@uschindler uschindler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now. Let's just think about the counter and clones.

@jpountz
Copy link
Contributor Author

jpountz commented May 19, 2024

such an application can learn whether the pages it is prefetching into the cache are still there by the time it gets around to using them

This is an interesting idea!

I was discussing this potential problem with @tveasey the other day. With terms and postings, we're currently only looking into loading a few pages in parallel per search thread and we then use them immediately. With GBs of capacity for the page cache, it would be extremely unlikely for these pages to get evicted in the meantime. But if/when we start looking into using prefetch() for bigger regions (e.g. stored fields) and/or possibly longer before needing the data (e.g. starting prefetching data for the next segment while we're scoring the current segment), then this could become a problem indeed. It would be nice if we could learn to disable prefetching when it's not working as intended. This would make this API safer to use.

@jpountz
Copy link
Contributor Author

jpountz commented May 19, 2024

I added "search" concurrency to the benchmark to make it a bit more realistic
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.ThreadInterruptedException;

public class PrefetchBench {

  private static final int CONCURRENCY = 10;
  private static final int NUM_TERMS = 3;
  private static final long FILE_SIZE = 100L * 1024 * 1024 * 1024; // 100GB
  private static final int NUM_BYTES = 16;
  public static int DUMMY;

  public static void main(String[] args) throws Exception {
    Path filePath = Paths.get(args[0]);
    Path dirPath = filePath.getParent();
    String fileName = filePath.getFileName().toString();
    Random r = ThreadLocalRandom.current();

    try (Directory dir = new MMapDirectory(dirPath)) {
      if (Arrays.asList(dir.listAll()).contains(fileName) == false) {
        try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
          byte[] buf = new byte[8196];
          for (long i = 0; i < FILE_SIZE; i += buf.length) {
            r.nextBytes(buf);
            out.writeBytes(buf, buf.length);
          }
        }
      }

      for (boolean dataFitsInCache : new boolean[] { false, true}) {
        try (IndexInput i0 = dir.openInput("file", IOContext.DEFAULT)) {
          final IndexInput input;
          if (dataFitsInCache) {
            // 16MB slice that should easily fit in the page cache
            input = i0.slice("slice", 0, 16 * 1024 * 1024);
          } else {
            input = i0;
          }

          final CountDownLatch latch = new CountDownLatch(1);
          RandomReader[] readers = new RandomReader[CONCURRENCY];
          for (int i = 0; i < readers.length; ++i) {
            IndexInput[] inputs = new IndexInput[NUM_TERMS];
            for (int j = 0; j < inputs.length; ++j) {
              inputs[j] = input.clone();
            }
            readers[i] = new RandomReader(inputs, latch);
            readers[i].start();
          }
          
          latch.countDown();
          List<Long> prefetchLatencies = new ArrayList<>();
          List<Long> noPrefetchLatencies = new ArrayList<>();
          for (RandomReader reader : readers) {
            reader.join();
            prefetchLatencies.addAll(reader.latencies[0]);
            noPrefetchLatencies.addAll(reader.latencies[1]);
          }
          prefetchLatencies.sort(null);
          noPrefetchLatencies.sort(null);

          System.out.println("Data " + (dataFitsInCache ? "fits" : "does not fit") + " in the page cache");
          long prefetchP50 = prefetchLatencies.get(prefetchLatencies.size() / 2);
          long prefetchP90 = prefetchLatencies.get(prefetchLatencies.size() * 9 / 10);
          long prefetchP99 = prefetchLatencies.get(prefetchLatencies.size() * 99 / 100);
          long noPrefetchP50 = noPrefetchLatencies.get(noPrefetchLatencies.size() / 2);
          long noPrefetchP90 = noPrefetchLatencies.get(noPrefetchLatencies.size() * 9 / 10);
          long noPrefetchP99 = noPrefetchLatencies.get(noPrefetchLatencies.size() * 99 / 100);

          System.out.println("  With prefetching:    P50=" + prefetchP50 + "ns P90=" + prefetchP90 + "ns P99=" + prefetchP99 + "ns");
          System.out.println("  Without prefetching: P50=" + noPrefetchP50 + "ns P90=" + noPrefetchP90 + "ns P99=" + noPrefetchP99 + "ns");
        }
      }
    }
  }

  private static class RandomReader extends Thread {

    private final IndexInput[] inputs;
    private final CountDownLatch latch;
    private final byte[][] b = new byte[NUM_TERMS][];
    final List<Long>[] latencies = new List[2];

    RandomReader(IndexInput[] inputs, CountDownLatch latch) {
      this.inputs = inputs;
      this.latch = latch;
      latencies[0] = new ArrayList<>();
      latencies[1] = new ArrayList<>();
      for (int i = 0; i < NUM_TERMS; ++i) {
        b[i] = new byte[NUM_BYTES];
      }
    }

    @Override
    public void run() {
      try {
        latch.await();

        final ThreadLocalRandom r = ThreadLocalRandom.current();
        final long length = inputs[0].length();
        for (int iter = 0; iter < 100_000; ++iter) {
          final boolean prefetch = (iter & 1) == 0;

          final long start = System.nanoTime();

          for (IndexInput ii : inputs) {
            final long offset = r.nextLong(length - NUM_BYTES);
            ii.seek(offset);
            if (prefetch) {
              ii.prefetch(offset, 1);
            }
          }

          for (int i = 0; i < NUM_TERMS; ++i) {
            inputs[i].readBytes(b[i], 0, b[i].length);
          }

          final long end = System.nanoTime();

          // Prevent the JVM from optimizing away the reads
          DUMMY = Arrays.stream(b).mapToInt(Arrays::hashCode).sum();

          latencies[iter & 1].add(end - start);
        }
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      } catch (InterruptedException e) {
        throw new ThreadInterruptedException(e);
      }
    }

  }

}

On the latest version of this PR, it reports:

Data does not fit in the page cache
  With prefetching:    P50=104260ns P90=159710ns P99=228880ns
  Without prefetching: P50=242580ns P90=315821ns P99=405901ns
Data fits in the page cache
  With prefetching:    P50=310ns P90=6700ns P99=12320ns
  Without prefetching: P50=290ns P90=6770ns P99=11610ns

vs. the following on main:

Data does not fit in the page cache
  With prefetching:    P50=97620ns P90=153050ns P99=220510ns
  Without prefetching: P50=226690ns P90=302530ns P99=392770ns
Data fits in the page cache
  With prefetching:    P50=6970ns P90=9380ns P99=12300ns
  Without prefetching: P50=290ns P90=5890ns P99=8560ns

@jpountz jpountz marked this pull request as ready for review May 19, 2024 21:53
Copy link
Member

@rmuir rmuir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, lets make incremental progress!

@jpountz jpountz merged commit aac856a into apache:main May 21, 2024
3 checks passed
@jpountz jpountz deleted the mincore_prefetch branch May 21, 2024 07:12
@jpountz jpountz added this to the 10.0.0 milestone May 21, 2024
shatejas pushed a commit to shatejas/lucene that referenced this pull request Nov 17, 2024
…AM. (apache#13381)

As Robert pointed out and benchmarks confirmed, there is some (small) overhead
to calling `madvise` via the foreign function API, benchmarks suggest it is in
the order of 1-2us. This is not much for a single call, but may become
non-negligible across many calls. Until now, we only looked into using
prefetch() for terms, skip data and postings start pointers which are a single
prefetch() operation per segment per term.

But we may want to start using it in cases that could result into more calls to
`madvise`, e.g. if we start using it for stored fields and a user requests 10k
documents. In apache#13337, Robert wondered if we could take advantage of `mincore()`
to reduce the overhead of `IndexInput#prefetch()`, which is what this PR is
doing via `MemorySegment#isLoaded()`.

`IndexInput#prefetch` tracks consecutive hits on the page cache and calls
`madvise` less and less frequently under the hood as the number of consecutive
cache hits increases.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants