Skip to content

Commit

Permalink
Preserve context after remote license check
Browse files Browse the repository at this point in the history
  • Loading branch information
jasontedor committed Aug 19, 2018
1 parent 901acdc commit 222dea9
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.protocol.xpack.XPackInfoRequest;
Expand Down Expand Up @@ -187,16 +188,18 @@ public void onFailure(final Exception e) {

private void remoteClusterLicense(final String clusterAlias, final ActionListener<XPackInfoResponse> listener) {
final ThreadContext threadContext = client.threadPool().getThreadContext();
final ContextPreservingActionListener<XPackInfoResponse> contextPreservingActionListener =
new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();

final XPackInfoRequest request = new XPackInfoRequest();
request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE));
try {
client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, listener);
client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, contextPreservingActionListener);
} catch (final Exception e) {
listener.onFailure(e);
contextPreservingActionListener.onFailure(e);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import org.elasticsearch.protocol.xpack.XPackInfoResponse;
import org.elasticsearch.protocol.xpack.license.LicenseStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.action.XPackInfoAction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
Expand Down Expand Up @@ -92,15 +94,15 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() {
final AtomicInteger index = new AtomicInteger();
final List<XPackInfoResponse> responses = new ArrayList<>();

final Client client = createMockClient();
final ThreadPool threadPool = createMockThreadPool();
final Client client = createMockClient(threadPool);
doAnswer(invocationMock -> {
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
listener.onResponse(responses.get(index.getAndIncrement()));
return null;
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());


final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
Expand All @@ -110,8 +112,9 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() {
new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
final AtomicReference<RemoteClusterLicenseChecker.LicenseCheck> licenseCheck = new AtomicReference<>();

licenseChecker.checkRemoteClusterLicenses(remoteClusterAliases,
new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
licenseChecker.checkRemoteClusterLicenses(
remoteClusterAliases,
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {

@Override
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
Expand All @@ -123,7 +126,7 @@ public void onFailure(final Exception e) {
fail(e.getMessage());
}

});
}));

verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any());
assertNotNull(licenseCheck.get());
Expand All @@ -138,7 +141,8 @@ public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() {
responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null));
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));

final Client client = createMockClient();
final ThreadPool threadPool = createMockThreadPool();
final Client client = createMockClient(threadPool);
doAnswer(invocationMock -> {
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
Expand All @@ -152,7 +156,7 @@ public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() {

licenseChecker.checkRemoteClusterLicenses(
remoteClusterAliases,
new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {

@Override
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
Expand All @@ -164,7 +168,7 @@ public void onFailure(final Exception e) {
fail(e.getMessage());
}

});
}));

verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any());
assertNotNull(licenseCheck.get());
Expand All @@ -179,15 +183,15 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() {

final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
final String failingClusterAlias = randomFrom(remoteClusterAliases);
final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(failingClusterAlias);
final ThreadPool threadPool = createMockThreadPool();
final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, failingClusterAlias);
doAnswer(invocationMock -> {
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
listener.onResponse(responses.get(index.getAndIncrement()));
return null;
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());


responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
Expand All @@ -196,8 +200,9 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() {
new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
final AtomicReference<Exception> exception = new AtomicReference<>();

licenseChecker.checkRemoteClusterLicenses(remoteClusterAliases,
new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
licenseChecker.checkRemoteClusterLicenses(
remoteClusterAliases,
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {

@Override
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
Expand All @@ -209,7 +214,7 @@ public void onFailure(final Exception e) {
exception.set(e);
}

});
}));

assertNotNull(exception.get());
assertThat(exception.get(), instanceOf(ElasticsearchException.class));
Expand All @@ -218,6 +223,69 @@ public void onFailure(final Exception e) {
assertThat(exception.get().getCause(), instanceOf(IllegalArgumentException.class));
}

public void testListenerIsExecutedWithCallingContext() throws InterruptedException {
final AtomicInteger index = new AtomicInteger();
final List<XPackInfoResponse> responses = new ArrayList<>();

final ThreadPool threadPool = new TestThreadPool(getTestName());

try {
final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
final Client client;
final boolean failure = randomBoolean();
if (failure) {
client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, randomFrom(remoteClusterAliases));
} else {
client = createMockClient(threadPool);
}
doAnswer(invocationMock -> {
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
listener.onResponse(responses.get(index.getAndIncrement()));
return null;
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());

responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));

final RemoteClusterLicenseChecker licenseChecker =
new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);

final AtomicBoolean listenerInvoked = new AtomicBoolean();
threadPool.getThreadContext().putHeader("key", "value");
licenseChecker.checkRemoteClusterLicenses(
remoteClusterAliases,
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {

@Override
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
if (failure) {
fail();
}
assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value"));
assertFalse(threadPool.getThreadContext().isSystemContext());
listenerInvoked.set(true);
}

@Override
public void onFailure(final Exception e) {
if (failure == false) {
fail();
}
assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value"));
assertFalse(threadPool.getThreadContext().isSystemContext());
listenerInvoked.set(true);
}

}));

assertTrue(listenerInvoked.get());
} finally {
terminate(threadPool);
}
}

public void testBuildErrorMessageForActiveCompatibleLicense() {
final XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse();
final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info =
Expand Down Expand Up @@ -246,22 +314,52 @@ public void testBuildErrorMessageForInactiveLicense() {
equalTo("the license on cluster [expired-cluster] is not active"));
}

private Client createMockClient() {
return createMockClient(client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client));
private ActionListener<RemoteClusterLicenseChecker.LicenseCheck> doubleInvocationProtectingListener(
final ActionListener<RemoteClusterLicenseChecker.LicenseCheck> listener) {
final AtomicBoolean listenerInvoked = new AtomicBoolean();
return new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {

@Override
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
if (listenerInvoked.compareAndSet(false, true) == false) {
fail("listener invoked twice");
}
listener.onResponse(response);
}

@Override
public void onFailure(final Exception e) {
if (listenerInvoked.compareAndSet(false, true) == false) {
fail("listener invoked twice");
}
listener.onFailure(e);
}

};
}

private ThreadPool createMockThreadPool() {
final ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
return threadPool;
}

private Client createMockClientThatThrowsOnGetRemoteClusterClient(final String clusterAlias) {
return createMockClient(client -> {
when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException());
when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client);
});
private Client createMockClient(final ThreadPool threadPool) {
return createMockClient(threadPool, client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client));
}

private Client createMockClient(final Consumer<Client> finish) {
private Client createMockClientThatThrowsOnGetRemoteClusterClient(final ThreadPool threadPool, final String clusterAlias) {
return createMockClient(
threadPool,
client -> {
when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException());
when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client);
});
}

private Client createMockClient(final ThreadPool threadPool, final Consumer<Client> finish) {
final Client client = mock(Client.class);
final ThreadPool threadPool = mock(ThreadPool.class);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
finish.accept(client);
return client;
}
Expand Down

0 comments on commit 222dea9

Please sign in to comment.