diff --git a/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java b/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java index a9301dc..3fb62cf 100644 --- a/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java +++ b/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java @@ -19,6 +19,8 @@ import io.github.nstdio.http.ext.Cache.CacheEntry; import io.github.nstdio.http.ext.Cache.CacheStats; +import java.net.URI; +import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; @@ -27,7 +29,6 @@ import java.util.List; import java.util.Optional; import java.util.function.Consumer; -import java.util.stream.Stream; import static io.github.nstdio.http.ext.Headers.HEADER_IF_MODIFIED_SINCE; import static io.github.nstdio.http.ext.Headers.HEADER_IF_NONE_MATCH; @@ -36,9 +37,10 @@ import static io.github.nstdio.http.ext.Responses.isSafeRequest; import static io.github.nstdio.http.ext.Responses.isSuccessful; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.stream.Collectors.toList; class CachingInterceptor implements Interceptor { + private final static List INVALIDATION_HEADERS = List.of("Location", "Content-Location"); + private final Cache cache; private final Clock clock; @@ -170,14 +172,18 @@ private BodyHandler cacheAware(RequestContext ctx) { } private HttpResponse invalidate(HttpResponse response) { - List toEvict = Stream.of("Location", "Content-Location") - .flatMap(s -> Headers.effectiveUri(response.headers(), s, response.uri()).stream()) - .filter(uri -> response.uri().getHost().equals(uri.getHost())) - .map(uri -> HttpRequest.newBuilder(uri).build()) - .collect(toList()); - toEvict.add(response.request()); - - toEvict.forEach(cache::evictAll); + HttpHeaders headers = response.headers(); + URI uri = response.uri(); + String host = uri.getHost(); + + for (String headerName : INVALIDATION_HEADERS) { + Headers.effectiveUri(headers, headerName, uri) + .filter(u -> host.equals(u.getHost())) + .map(u -> HttpRequest.newBuilder(u).build()) + .forEach(cache::evictAll); + } + + cache.evictAll(response.request()); return response; } diff --git a/src/main/java/io/github/nstdio/http/ext/Headers.java b/src/main/java/io/github/nstdio/http/ext/Headers.java index ad76047..26efd90 100644 --- a/src/main/java/io/github/nstdio/http/ext/Headers.java +++ b/src/main/java/io/github/nstdio/http/ext/Headers.java @@ -33,9 +33,9 @@ import java.util.Objects; import java.util.Optional; import java.util.function.BiPredicate; +import java.util.stream.Stream; import static java.util.function.Predicate.not; -import static java.util.stream.Collectors.toList; class Headers { static final String HEADER_VARY = "Vary"; @@ -176,12 +176,11 @@ static Optional parseInstant(HttpHeaders headers, String headerName) { .map(Headers::parseInstant); } - static List effectiveUri(HttpHeaders headers, String headerName, URI responseUri) { + static Stream effectiveUri(HttpHeaders headers, String headerName, URI responseUri) { return headers.allValues(headerName) .stream() .map(s -> effectiveUri(s, responseUri)) - .filter(Objects::nonNull) - .collect(toList()); + .filter(Objects::nonNull); } static URI effectiveUri(String s, URI responseUri) { diff --git a/src/test/kotlin/io/github/nstdio/http/ext/HeadersTest.kt b/src/test/kotlin/io/github/nstdio/http/ext/HeadersTest.kt index 1fbe982..e01b82b 100644 --- a/src/test/kotlin/io/github/nstdio/http/ext/HeadersTest.kt +++ b/src/test/kotlin/io/github/nstdio/http/ext/HeadersTest.kt @@ -25,6 +25,7 @@ import java.net.http.HttpHeaders import java.time.Instant import java.util.Map import java.util.stream.Stream +import kotlin.streams.asSequence internal class HeadersTest { @ParameterizedTest @@ -61,7 +62,7 @@ internal class HeadersTest { @MethodSource("effectiveUriHeadersData") fun effectiveUriHeaders(headers: HttpHeaders?, headerName: String?, responseUri: URI?, expected: List?) { //when - val uris = Headers.effectiveUri(headers, headerName, responseUri) + val uris = Headers.effectiveUri(headers, headerName, responseUri).asSequence().toList() //then assertThat(uris).isEqualTo(expected)