Skip to content

Commit

Permalink
HLRest: Allow caller to set per request options (#30490)
Browse files Browse the repository at this point in the history
This modifies the high level rest client to allow calling code to
customize per request options for the bulk API. You do the actual
customization by passing a `RequestOptions` object to the API call
which is set on the `Request` that is generated by the high level
client. It also makes the `RequestOptions` a thing in the low level
rest client. For now that just means you use it to customize the
headers and the `httpAsyncResponseConsumerFactory` and we'll add
node selectors and per request timeouts in a follow up.

I only implemented this on the bulk API because it is the first one
in the list alphabetically and I wanted to keep the change small
enough to review. I'll convert the remaining APIs in a followup.
  • Loading branch information
nik9000 authored May 31, 2018
1 parent d826cb3 commit b225f5e
Show file tree
Hide file tree
Showing 19 changed files with 578 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ public final TasksClient tasks() {
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
*/
public final BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws IOException {
return performRequestAndParseEntity(bulkRequest, RequestConverters::bulk, options, BulkResponse::fromXContent, emptySet());
}

/**
* Executes a bulk request using the Bulk API
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
* @deprecated Prefer {@link #bulk(BulkRequest, RequestOptions)}
*/
@Deprecated
public final BulkResponse bulk(BulkRequest bulkRequest, Header... headers) throws IOException {
return performRequestAndParseEntity(bulkRequest, RequestConverters::bulk, BulkResponse::fromXContent, emptySet(), headers);
}
Expand All @@ -288,6 +299,17 @@ public final BulkResponse bulk(BulkRequest bulkRequest, Header... headers) throw
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
*/
public final void bulkAsync(BulkRequest bulkRequest, RequestOptions options, ActionListener<BulkResponse> listener) {
performRequestAsyncAndParseEntity(bulkRequest, RequestConverters::bulk, options, BulkResponse::fromXContent, listener, emptySet());
}

/**
* Asynchronously executes a bulk request using the Bulk API
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
* @deprecated Prefer {@link #bulkAsync(BulkRequest, RequestOptions, ActionListener)}
*/
@Deprecated
public final void bulkAsync(BulkRequest bulkRequest, ActionListener<BulkResponse> listener, Header... headers) {
performRequestAsyncAndParseEntity(bulkRequest, RequestConverters::bulk, BulkResponse::fromXContent, listener, emptySet(), headers);
}
Expand Down Expand Up @@ -584,23 +606,42 @@ public final void fieldCapsAsync(FieldCapabilitiesRequest fieldCapabilitiesReque
FieldCapabilitiesResponse::fromXContent, listener, emptySet(), headers);
}

@Deprecated
protected final <Req extends ActionRequest, Resp> Resp performRequestAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
Set<Integer> ignores, Header... headers) throws IOException {
return performRequest(request, requestConverter, (response) -> parseEntity(response.getEntity(), entityParser), ignores, headers);
}

protected final <Req extends ActionRequest, Resp> Resp performRequestAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
Set<Integer> ignores) throws IOException {
return performRequest(request, requestConverter, options,
response -> parseEntity(response.getEntity(), entityParser), ignores);
}

@Deprecated
protected final <Req extends ActionRequest, Resp> Resp performRequest(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<Response, Resp, IOException> responseConverter,
Set<Integer> ignores, Header... headers) throws IOException {
return performRequest(request, requestConverter, optionsForHeaders(headers), responseConverter, ignores);
}

protected final <Req extends ActionRequest, Resp> Resp performRequest(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<Response, Resp, IOException> responseConverter,
Set<Integer> ignores) throws IOException {
ActionRequestValidationException validationException = request.validate();
if (validationException != null) {
throw validationException;
}
Request req = requestConverter.apply(request);
addHeaders(req, headers);
req.setOptions(options);
Response response;
try {
response = client.performRequest(req);
Expand All @@ -626,6 +667,7 @@ protected final <Req extends ActionRequest, Resp> Resp performRequest(Req reques
}
}

@Deprecated
protected final <Req extends ActionRequest, Resp> void performRequestAsyncAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
Expand All @@ -634,10 +676,28 @@ protected final <Req extends ActionRequest, Resp> void performRequestAsyncAndPar
listener, ignores, headers);
}

protected final <Req extends ActionRequest, Resp> void performRequestAsyncAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
ActionListener<Resp> listener, Set<Integer> ignores) {
performRequestAsync(request, requestConverter, options,
response -> parseEntity(response.getEntity(), entityParser), listener, ignores);
}

@Deprecated
protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> listener, Set<Integer> ignores, Header... headers) {
performRequestAsync(request, requestConverter, optionsForHeaders(headers), responseConverter, listener, ignores);
}

protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> listener, Set<Integer> ignores) {
ActionRequestValidationException validationException = request.validate();
if (validationException != null) {
listener.onFailure(validationException);
Expand All @@ -650,19 +710,12 @@ protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req r
listener.onFailure(e);
return;
}
addHeaders(req, headers);
req.setOptions(options);

ResponseListener responseListener = wrapResponseListener(responseConverter, listener, ignores);
client.performRequestAsync(req, responseListener);
}

private static void addHeaders(Request request, Header... headers) {
Objects.requireNonNull(headers, "headers cannot be null");
for (Header header : headers) {
request.addHeader(header.getName(), header.getValue());
}
}

final <Resp> ResponseListener wrapResponseListener(CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> actionListener, Set<Integer> ignores) {
return new ResponseListener() {
Expand Down Expand Up @@ -746,6 +799,15 @@ protected final <Resp> Resp parseEntity(final HttpEntity entity,
}
}

private static RequestOptions optionsForHeaders(Header[] headers) {
RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
for (Header header : headers) {
Objects.requireNonNull(header, "header cannot be null");
options.addHeader(header.getName(), header.getValue());
}
return options.build();
}

static boolean convertExistsResponse(Response response) {
return response.getStatusLine().getStatusCode() == 200;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.http.client.methods.HttpGet;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.apache.http.message.BasicRequestLine;
import org.apache.http.message.BasicStatusLine;
import org.apache.lucene.util.BytesRef;
Expand All @@ -48,11 +47,13 @@
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
Expand All @@ -73,12 +74,12 @@ public void initClients() throws IOException {
final RestClient restClient = mock(RestClient.class);
restHighLevelClient = new CustomRestClient(restClient);

doAnswer(inv -> mockPerformRequest(((Request) inv.getArguments()[0]).getHeaders().iterator().next()))
doAnswer(inv -> mockPerformRequest((Request) inv.getArguments()[0]))
.when(restClient)
.performRequest(any(Request.class));

doAnswer(inv -> mockPerformRequestAsync(
((Request) inv.getArguments()[0]).getHeaders().iterator().next(),
((Request) inv.getArguments()[0]),
(ResponseListener) inv.getArguments()[1]))
.when(restClient)
.performRequestAsync(any(Request.class), any(ResponseListener.class));
Expand All @@ -87,26 +88,32 @@ public void initClients() throws IOException {

public void testCustomEndpoint() throws IOException {
final MainRequest request = new MainRequest();
final Header header = new BasicHeader("node_name", randomAlphaOfLengthBetween(1, 10));
String nodeName = randomAlphaOfLengthBetween(1, 10);

MainResponse response = restHighLevelClient.custom(request, header);
assertEquals(header.getValue(), response.getNodeName());
MainResponse response = restHighLevelClient.custom(request, optionsForNodeName(nodeName));
assertEquals(nodeName, response.getNodeName());

response = restHighLevelClient.customAndParse(request, header);
assertEquals(header.getValue(), response.getNodeName());
response = restHighLevelClient.customAndParse(request, optionsForNodeName(nodeName));
assertEquals(nodeName, response.getNodeName());
}

public void testCustomEndpointAsync() throws Exception {
final MainRequest request = new MainRequest();
final Header header = new BasicHeader("node_name", randomAlphaOfLengthBetween(1, 10));
String nodeName = randomAlphaOfLengthBetween(1, 10);

PlainActionFuture<MainResponse> future = PlainActionFuture.newFuture();
restHighLevelClient.customAsync(request, future, header);
assertEquals(header.getValue(), future.get().getNodeName());
restHighLevelClient.customAsync(request, optionsForNodeName(nodeName), future);
assertEquals(nodeName, future.get().getNodeName());

future = PlainActionFuture.newFuture();
restHighLevelClient.customAndParseAsync(request, future, header);
assertEquals(header.getValue(), future.get().getNodeName());
restHighLevelClient.customAndParseAsync(request, optionsForNodeName(nodeName), future);
assertEquals(nodeName, future.get().getNodeName());
}

private static RequestOptions optionsForNodeName(String nodeName) {
RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
options.addHeader("node_name", nodeName);
return options.build();
}

/**
Expand All @@ -115,27 +122,27 @@ public void testCustomEndpointAsync() throws Exception {
*/
@SuppressForbidden(reason = "We're forced to uses Class#getDeclaredMethods() here because this test checks protected methods")
public void testMethodsVisibility() throws ClassNotFoundException {
final String[] methodNames = new String[]{"performRequest",
"performRequestAsync",
final String[] methodNames = new String[]{"parseEntity",
"parseResponseException",
"performRequest",
"performRequestAndParseEntity",
"performRequestAsyncAndParseEntity",
"parseEntity",
"parseResponseException"};
"performRequestAsync",
"performRequestAsyncAndParseEntity"};

final List<String> protectedMethods = Arrays.stream(RestHighLevelClient.class.getDeclaredMethods())
final Set<String> protectedMethods = Arrays.stream(RestHighLevelClient.class.getDeclaredMethods())
.filter(method -> Modifier.isProtected(method.getModifiers()))
.map(Method::getName)
.collect(Collectors.toList());
.collect(Collectors.toCollection(TreeSet::new));

assertThat(protectedMethods, containsInAnyOrder(methodNames));
assertThat(protectedMethods, contains(methodNames));
}

/**
* Mocks the asynchronous request execution by calling the {@link #mockPerformRequest(Header)} method.
* Mocks the asynchronous request execution by calling the {@link #mockPerformRequest(Request)} method.
*/
private Void mockPerformRequestAsync(Header httpHeader, ResponseListener responseListener) {
private Void mockPerformRequestAsync(Request request, ResponseListener responseListener) {
try {
responseListener.onSuccess(mockPerformRequest(httpHeader));
responseListener.onSuccess(mockPerformRequest(request));
} catch (IOException e) {
responseListener.onFailure(e);
}
Expand All @@ -145,7 +152,9 @@ private Void mockPerformRequestAsync(Header httpHeader, ResponseListener respons
/**
* Mocks the synchronous request execution like if it was executed by Elasticsearch.
*/
private Response mockPerformRequest(Header httpHeader) throws IOException {
private Response mockPerformRequest(Request request) throws IOException {
assertThat(request.getOptions().getHeaders(), hasSize(1));
Header httpHeader = request.getOptions().getHeaders().get(0);
final Response mockResponse = mock(Response.class);
when(mockResponse.getHost()).thenReturn(new HttpHost("localhost", 9200));

Expand All @@ -171,20 +180,20 @@ private CustomRestClient(RestClient restClient) {
super(restClient, RestClient::close, Collections.emptyList());
}

MainResponse custom(MainRequest mainRequest, Header... headers) throws IOException {
return performRequest(mainRequest, this::toRequest, this::toResponse, emptySet(), headers);
MainResponse custom(MainRequest mainRequest, RequestOptions options) throws IOException {
return performRequest(mainRequest, this::toRequest, options, this::toResponse, emptySet());
}

MainResponse customAndParse(MainRequest mainRequest, Header... headers) throws IOException {
return performRequestAndParseEntity(mainRequest, this::toRequest, MainResponse::fromXContent, emptySet(), headers);
MainResponse customAndParse(MainRequest mainRequest, RequestOptions options) throws IOException {
return performRequestAndParseEntity(mainRequest, this::toRequest, options, MainResponse::fromXContent, emptySet());
}

void customAsync(MainRequest mainRequest, ActionListener<MainResponse> listener, Header... headers) {
performRequestAsync(mainRequest, this::toRequest, this::toResponse, listener, emptySet(), headers);
void customAsync(MainRequest mainRequest, RequestOptions options, ActionListener<MainResponse> listener) {
performRequestAsync(mainRequest, this::toRequest, options, this::toResponse, listener, emptySet());
}

void customAndParseAsync(MainRequest mainRequest, ActionListener<MainResponse> listener, Header... headers) {
performRequestAsyncAndParseEntity(mainRequest, this::toRequest, MainResponse::fromXContent, listener, emptySet(), headers);
void customAndParseAsync(MainRequest mainRequest, RequestOptions options, ActionListener<MainResponse> listener) {
performRequestAsyncAndParseEntity(mainRequest, this::toRequest, options, MainResponse::fromXContent, listener, emptySet());
}

Request toRequest(MainRequest mainRequest) throws IOException {
Expand Down
Loading

0 comments on commit b225f5e

Please sign in to comment.