diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java index eaef459814170..d2c076fbec6cf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java @@ -8,6 +8,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.client.Client; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.protocol.xpack.XPackInfoRequest; @@ -187,6 +188,8 @@ public void onFailure(final Exception e) { private void remoteClusterLicense(final String clusterAlias, final ActionListener listener) { final ThreadContext threadContext = client.threadPool().getThreadContext(); + final ContextPreservingActionListener contextPreservingActionListener = + new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { // we stash any context here since this is an internal execution and should not leak any existing context information threadContext.markAsSystemContext(); @@ -194,9 +197,9 @@ private void remoteClusterLicense(final String clusterAlias, final ActionListene final XPackInfoRequest request = new XPackInfoRequest(); request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE)); try { - client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, listener); + client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, contextPreservingActionListener); } catch (final Exception e) { - listener.onFailure(e); + contextPreservingActionListener.onFailure(e); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java index 8abfeac149289..5677fb51ee907 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.protocol.xpack.XPackInfoResponse; import org.elasticsearch.protocol.xpack.license.LicenseStatus; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.action.XPackInfoAction; @@ -21,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -92,7 +94,8 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() { final AtomicInteger index = new AtomicInteger(); final List responses = new ArrayList<>(); - final Client client = createMockClient(); + final ThreadPool threadPool = createMockThreadPool(); + final Client client = createMockClient(threadPool); doAnswer(invocationMock -> { @SuppressWarnings("unchecked") ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; @@ -100,7 +103,6 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() { return null; }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); - final List remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3"); responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); @@ -110,8 +112,9 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() { new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); final AtomicReference licenseCheck = new AtomicReference<>(); - licenseChecker.checkRemoteClusterLicenses(remoteClusterAliases, - new ActionListener() { + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { @Override public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { @@ -123,7 +126,7 @@ public void onFailure(final Exception e) { fail(e.getMessage()); } - }); + })); verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any()); assertNotNull(licenseCheck.get()); @@ -138,7 +141,8 @@ public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() { responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null)); responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); - final Client client = createMockClient(); + final ThreadPool threadPool = createMockThreadPool(); + final Client client = createMockClient(threadPool); doAnswer(invocationMock -> { @SuppressWarnings("unchecked") ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; @@ -152,7 +156,7 @@ public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() { licenseChecker.checkRemoteClusterLicenses( remoteClusterAliases, - new ActionListener() { + doubleInvocationProtectingListener(new ActionListener() { @Override public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { @@ -164,7 +168,7 @@ public void onFailure(final Exception e) { fail(e.getMessage()); } - }); + })); verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any()); assertNotNull(licenseCheck.get()); @@ -179,7 +183,8 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() { final List remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3"); final String failingClusterAlias = randomFrom(remoteClusterAliases); - final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(failingClusterAlias); + final ThreadPool threadPool = createMockThreadPool(); + final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, failingClusterAlias); doAnswer(invocationMock -> { @SuppressWarnings("unchecked") ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; @@ -187,7 +192,6 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() { return null; }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); - responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); @@ -196,8 +200,9 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() { new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); final AtomicReference exception = new AtomicReference<>(); - licenseChecker.checkRemoteClusterLicenses(remoteClusterAliases, - new ActionListener() { + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { @Override public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { @@ -209,7 +214,7 @@ public void onFailure(final Exception e) { exception.set(e); } - }); + })); assertNotNull(exception.get()); assertThat(exception.get(), instanceOf(ElasticsearchException.class)); @@ -218,6 +223,69 @@ public void onFailure(final Exception e) { assertThat(exception.get().getCause(), instanceOf(IllegalArgumentException.class)); } + public void testListenerIsExecutedWithCallingContext() throws InterruptedException { + final AtomicInteger index = new AtomicInteger(); + final List responses = new ArrayList<>(); + + final ThreadPool threadPool = new TestThreadPool(getTestName()); + + try { + final List remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3"); + final Client client; + final boolean failure = randomBoolean(); + if (failure) { + client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, randomFrom(remoteClusterAliases)); + } else { + client = createMockClient(threadPool); + } + doAnswer(invocationMock -> { + @SuppressWarnings("unchecked") ActionListener listener = + (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + final RemoteClusterLicenseChecker licenseChecker = + new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); + + final AtomicBoolean listenerInvoked = new AtomicBoolean(); + threadPool.getThreadContext().putHeader("key", "value"); + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + if (failure) { + fail(); + } + assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value")); + assertFalse(threadPool.getThreadContext().isSystemContext()); + listenerInvoked.set(true); + } + + @Override + public void onFailure(final Exception e) { + if (failure == false) { + fail(); + } + assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value")); + assertFalse(threadPool.getThreadContext().isSystemContext()); + listenerInvoked.set(true); + } + + })); + + assertTrue(listenerInvoked.get()); + } finally { + terminate(threadPool); + } + } + public void testBuildErrorMessageForActiveCompatibleLicense() { final XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse(); final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info = @@ -246,22 +314,52 @@ public void testBuildErrorMessageForInactiveLicense() { equalTo("the license on cluster [expired-cluster] is not active")); } - private Client createMockClient() { - return createMockClient(client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client)); + private ActionListener doubleInvocationProtectingListener( + final ActionListener listener) { + final AtomicBoolean listenerInvoked = new AtomicBoolean(); + return new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + if (listenerInvoked.compareAndSet(false, true) == false) { + fail("listener invoked twice"); + } + listener.onResponse(response); + } + + @Override + public void onFailure(final Exception e) { + if (listenerInvoked.compareAndSet(false, true) == false) { + fail("listener invoked twice"); + } + listener.onFailure(e); + } + + }; + } + + private ThreadPool createMockThreadPool() { + final ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + return threadPool; } - private Client createMockClientThatThrowsOnGetRemoteClusterClient(final String clusterAlias) { - return createMockClient(client -> { - when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException()); - when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client); - }); + private Client createMockClient(final ThreadPool threadPool) { + return createMockClient(threadPool, client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client)); } - private Client createMockClient(final Consumer finish) { + private Client createMockClientThatThrowsOnGetRemoteClusterClient(final ThreadPool threadPool, final String clusterAlias) { + return createMockClient( + threadPool, + client -> { + when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException()); + when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client); + }); + } + + private Client createMockClient(final ThreadPool threadPool, final Consumer finish) { final Client client = mock(Client.class); - final ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); finish.accept(client); return client; }