Avoid OOM-killing query if large result-level cache population fails for query
Currently, result-level caching that attempts to buffer query results larger than Integer.MAX_VALUE bytes overflows the capacity of the backing ByteArrayOutputStream. ByteArrayOutputStream materializes this case as an OutOfMemoryError, which is not caught and terminates the node. This change limits the buffer allocated for storing query results to the value of `CacheConfig.getResultLevelCacheLimit()`.
jtuglu-netflix committed Jan 22, 2025
1 parent a964220 commit 56dfb12
Showing 3 changed files with 82 additions and 23 deletions.
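In outline, the change wraps the in-memory result buffer in a byte-counting stream, so an oversized cache entry surfaces as a catchable IOException instead of the fatal OutOfMemoryError a ByteArrayOutputStream throws when it tries to grow past Integer.MAX_VALUE bytes. A minimal, self-contained sketch of the pattern (names here are illustrative; the real implementation is the LimitedOutputStream modified below):

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;

class BoundedBufferSketch
{
  /** Wrap a buffer so writes past limitBytes fail with IOException, not OOM. */
  static OutputStream bounded(final ByteArrayOutputStream delegate, final long limitBytes)
  {
    return new OutputStream()
    {
      private long written = 0;

      @Override
      public void write(int b) throws IOException
      {
        if (++written > limitBytes) {
          throw new IOException("cache entry exceeded limit[" + limitBytes + "]");
        }
        delegate.write(b);
      }
    };
  }

  public static void main(String[] args)
  {
    try (OutputStream out = bounded(new ByteArrayOutputStream(), 4)) {
      out.write(new byte[]{1, 2, 3, 4, 5}); // fifth byte trips the limit
    }
    catch (IOException expected) {
      // The query survives; only cache population is abandoned.
      System.out.println("recovered: " + expected.getMessage());
    }
  }
}
```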
LimitedOutputStream.java
@@ -24,18 +24,19 @@

import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;

/**
* An {@link OutputStream} that limits how many bytes can be written. Throws {@link IOException} if the limit
* is exceeded.
* is exceeded. *Not* thread-safe.
*/
public class LimitedOutputStream extends OutputStream
{
private final OutputStream out;
private final long limit;
private final Function<Long, String> exceptionMessageFn;
long written;
AtomicLong written;

/**
* Create a bytes-limited output stream.
@@ -48,6 +49,7 @@ public LimitedOutputStream(OutputStream out, long limit, Function<Long, String>
{
this.out = out;
this.limit = limit;
this.written = new AtomicLong(0);
this.exceptionMessageFn = exceptionMessageFn;

if (limit < 0) {
@@ -88,10 +90,14 @@ public void close() throws IOException
out.close();
}

public OutputStream get()
{
return out;
}

private void plus(final int n) throws IOException
{
written += n;
if (written > limit) {
if (written.addAndGet(n) > limit) {
throw new IOE(exceptionMessageFn.apply(limit));
}
}
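For reference, a quick usage sketch of the revised class, assuming only the constructor, write path, and new get() accessor shown in this diff (the limit and message here are illustrative):

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.apache.druid.io.LimitedOutputStream;

public class LimitedOutputStreamExample
{
  public static void main(String[] args) throws IOException
  {
    // Buffer a small cache entry, refusing to grow past 1 KiB.
    LimitedOutputStream stream = new LimitedOutputStream(
        new ByteArrayOutputStream(),
        1024,
        limit -> "result too large to cache, limit[" + limit + "] bytes"
    );
    stream.write("cached-result".getBytes(StandardCharsets.UTF_8));
    // The new get() accessor exposes the wrapped stream, which is how
    // populateResults() extracts the buffered bytes in the next file.
    byte[] entry = ((ByteArrayOutputStream) stream.get()).toByteArray();
    System.out.println(entry.length + " bytes buffered");
  }
}
```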
ResultLevelCachingQueryRunner.java
@@ -30,6 +30,7 @@
import org.apache.druid.client.cache.Cache;
import org.apache.druid.client.cache.Cache.NamedKey;
import org.apache.druid.client.cache.CacheConfig;
import org.apache.druid.io.LimitedOutputStream;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequence;
@@ -152,6 +153,8 @@ public void after(boolean isDone, Throwable thrown)
// The resultset identifier and its length is cached along with the resultset
resultLevelCachePopulator.populateResults();
log.debug("Cache population complete for query %s", query.getId());
} else { // thrown == null && !resultLevelCachePopulator.isShouldPopulate()
log.error("Failed (and recovered) to populate result level cache for query %s", query.getId());
}
resultLevelCachePopulator.stopPopulating();
}
@@ -233,8 +236,8 @@ private ResultLevelCachePopulator createResultLevelCachePopulator(
try {
// Save the resultSetId and its length
resultLevelCachePopulator.cacheObjectStream.write(ByteBuffer.allocate(Integer.BYTES)
.putInt(resultSetId.length())
.array());
.putInt(resultSetId.length())
.array());
resultLevelCachePopulator.cacheObjectStream.write(StringUtils.toUtf8(resultSetId));
}
catch (IOException ioe) {
@@ -255,7 +258,7 @@ private class ResultLevelCachePopulator
private final Cache.NamedKey key;
private final CacheConfig cacheConfig;
@Nullable
private ByteArrayOutputStream cacheObjectStream;
private LimitedOutputStream cacheObjectStream;

private ResultLevelCachePopulator(
Cache cache,
@@ -270,7 +273,14 @@ private ResultLevelCachePopulator(
this.serialiers = mapper.getSerializerProviderInstance();
this.key = key;
this.cacheConfig = cacheConfig;
this.cacheObjectStream = shouldPopulate ? new ByteArrayOutputStream() : null;
this.cacheObjectStream = shouldPopulate ? new LimitedOutputStream(
new ByteArrayOutputStream(),
cacheConfig.getResultLevelCacheLimit(), limit -> StringUtils.format(
"resultLevelCacheLimit[%,d] exceeded. "
+ "Max ResultLevelCacheLimit for cache exceeded. Result caching failed.",
limit
)
) : null;
}

boolean isShouldPopulate()
@@ -289,12 +299,8 @@ private void cacheResultEntry(
)
{
Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream");
int cacheLimit = cacheConfig.getResultLevelCacheLimit();
try (JsonGenerator gen = mapper.getFactory().createGenerator(cacheObjectStream)) {
JacksonUtils.writeObjectUsingSerializerProvider(gen, serialiers, cacheFn.apply(resultEntry));
if (cacheLimit > 0 && cacheObjectStream.size() > cacheLimit) {
stopPopulating();
}
}
catch (IOException ex) {
log.error(ex, "Failed to retrieve entry to be cached. Result Level caching will not be performed!");
@@ -304,7 +310,8 @@

public void populateResults()
{
byte[] cachedResults = Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream").toByteArray();
byte[] cachedResults = ((ByteArrayOutputStream) Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream")
.get()).toByteArray();
cache.put(key, cachedResults);
}
}
ResultLevelCachingQueryRunnerTest.java
@@ -39,6 +39,7 @@
public class ResultLevelCachingQueryRunnerTest extends QueryRunnerBasedOnClusteredClientTestBase
{
private Cache cache;
private static final int DEFAULT_CACHE_ENTRY_MAX_SIZE = Integer.MAX_VALUE;

@Before
public void setup()
@@ -58,7 +59,7 @@ public void testNotPopulateAndNotUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
newCacheConfig(false, false),
newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -72,7 +73,7 @@ public void testNotPopulateAndNotUse()
Assert.assertEquals(0, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
newCacheConfig(false, false),
newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -93,7 +94,7 @@ public void testPopulateAndNotUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
newCacheConfig(true, false),
newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -107,7 +108,7 @@ public void testPopulateAndNotUse()
Assert.assertEquals(0, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
newCacheConfig(true, false),
newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -128,7 +129,7 @@ public void testNotPopulateAndUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
newCacheConfig(false, false),
newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -142,7 +143,7 @@ public void testNotPopulateAndUse()
Assert.assertEquals(0, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
newCacheConfig(false, true),
newCacheConfig(false, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -163,7 +164,7 @@ public void testPopulateAndUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
newCacheConfig(true, true),
newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -177,7 +178,7 @@ public void testPopulateAndUse()
Assert.assertEquals(1, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
newCacheConfig(true, true),
newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -192,6 +193,41 @@ public void testPopulateAndUse()
Assert.assertEquals(1, cache.getStats().getNumMisses());
}

@Test
public void testNoPopulateIfEntrySizeExceedsMaximum()
{
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
newCacheConfig(true, true, 128),
query
);

final Sequence<Result<TimeseriesResultValue>> sequence1 = queryRunner1.run(
QueryPlus.wrap(query),
responseContext()
);
final List<Result<TimeseriesResultValue>> results1 = sequence1.toList();
Assert.assertEquals(0, cache.getStats().getNumHits());
Assert.assertEquals(0, cache.getStats().getNumEntries());
Assert.assertEquals(1, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

final Sequence<Result<TimeseriesResultValue>> sequence2 = queryRunner2.run(
QueryPlus.wrap(query),
responseContext()
);
final List<Result<TimeseriesResultValue>> results2 = sequence2.toList();
Assert.assertEquals(results1, results2);
Assert.assertEquals(0, cache.getStats().getNumHits());
Assert.assertEquals(1, cache.getStats().getNumEntries());
Assert.assertEquals(2, cache.getStats().getNumMisses());
}

@Test
public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache()
{
@@ -206,7 +242,7 @@ public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache()

final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner = createQueryRunner(
newCacheConfig(true, false),
newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -249,7 +285,11 @@ private <T> ResultLevelCachingQueryRunner<T> createQueryRunner(
);
}

private CacheConfig newCacheConfig(boolean populateResultLevelCache, boolean useResultLevelCache)
private CacheConfig newCacheConfig(
boolean populateResultLevelCache,
boolean useResultLevelCache,
int resultLevelCacheLimit
)
{
return new CacheConfig()
{
@@ -264,6 +304,12 @@ public boolean isUseResultLevelCache()
{
return useResultLevelCache;
}

@Override
public int getResultLevelCacheLimit()
{
return resultLevelCacheLimit;
}
};
}
}
