Skip to content

Commit

Permalink
Adding node_count to ML Usage (#33850) (#33863)
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent authored Sep 19, 2018
1 parent 839a677 commit 4767a01
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -24,28 +25,39 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
public static final String DETECTORS = "detectors";
public static final String FORECASTS = "forecasts";
public static final String MODEL_SIZE = "model_size";
public static final String NODE_COUNT = "node_count";

private final Map<String, Object> jobsUsage;
private final Map<String, Object> datafeedsUsage;
private final int nodeCount;

public MachineLearningFeatureSetUsage(boolean available, boolean enabled, Map<String, Object> jobsUsage,
Map<String, Object> datafeedsUsage) {
Map<String, Object> datafeedsUsage, int nodeCount) {
super(XPackField.MACHINE_LEARNING, available, enabled);
this.jobsUsage = Objects.requireNonNull(jobsUsage);
this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
this.nodeCount = nodeCount;
}

public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
super(in);
this.jobsUsage = in.readMap();
this.datafeedsUsage = in.readMap();
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
this.nodeCount = in.readInt();
} else {
this.nodeCount = -1;
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeMap(jobsUsage);
out.writeMap(datafeedsUsage);
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeInt(nodeCount);
}
}

@Override
Expand All @@ -57,6 +69,9 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOEx
if (datafeedsUsage != null) {
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
}
if (nodeCount >= 0) {
builder.field(NODE_COUNT, nodeCount);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.inject.Inject;
Expand Down Expand Up @@ -132,7 +133,22 @@ public Map<String, Object> nativeCodeInfo() {
@Override
public void usage(ActionListener<XPackFeatureSet.Usage> listener) {
ClusterState state = clusterService.state();
new Retriever(client, MlMetadata.getMlMetadata(state), available(), enabled()).execute(listener);
new Retriever(client, MlMetadata.getMlMetadata(state), available(), enabled(), mlNodeCount(state)).execute(listener);
}

private int mlNodeCount(final ClusterState clusterState) {
if (enabled == false) {
return 0;
}

int mlNodeCount = 0;
for (DiscoveryNode node : clusterState.getNodes()) {
String enabled = node.getAttributes().get(MachineLearning.ML_ENABLED_NODE_ATTR);
if (Boolean.parseBoolean(enabled)) {
++mlNodeCount;
}
}
return mlNodeCount;
}

public static class Retriever {
Expand All @@ -143,19 +159,22 @@ public static class Retriever {
private final boolean enabled;
private Map<String, Object> jobsUsage;
private Map<String, Object> datafeedsUsage;
private int nodeCount;

public Retriever(Client client, MlMetadata mlMetadata, boolean available, boolean enabled) {
public Retriever(Client client, MlMetadata mlMetadata, boolean available, boolean enabled, int nodeCount) {
this.client = Objects.requireNonNull(client);
this.mlMetadata = mlMetadata;
this.available = available;
this.enabled = enabled;
this.jobsUsage = new LinkedHashMap<>();
this.datafeedsUsage = new LinkedHashMap<>();
this.nodeCount = nodeCount;
}

public void execute(ActionListener<Usage> listener) {
if (enabled == false) {
listener.onResponse(new MachineLearningFeatureSetUsage(available, enabled, Collections.emptyMap(), Collections.emptyMap()));
listener.onResponse(
new MachineLearningFeatureSetUsage(available, enabled, Collections.emptyMap(), Collections.emptyMap(), 0));
return;
}

Expand All @@ -164,11 +183,9 @@ public void execute(ActionListener<Usage> listener) {
ActionListener.wrap(response -> {
addDatafeedsUsage(response);
listener.onResponse(new MachineLearningFeatureSetUsage(
available, enabled, jobsUsage, datafeedsUsage));
available, enabled, jobsUsage, datafeedsUsage, nodeCount));
},
error -> {
listener.onFailure(error);
}
listener::onFailure
);

// Step 1. Extract usage from jobs stats and then request stats for all datafeeds
Expand All @@ -181,9 +198,7 @@ public void execute(ActionListener<Usage> listener) {
client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest,
datafeedStatsListener);
},
error -> {
listener.onFailure(error);
}
listener::onFailure
);

// Step 0. Kick off the chain of callbacks by requesting jobs stats
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
package org.elasticsearch.xpack.ml;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
Expand Down Expand Up @@ -46,7 +51,11 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
Expand Down Expand Up @@ -223,6 +232,49 @@ public void testUsage() throws Exception {
}
}

public void testNodeCount() throws Exception {
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
int nodeCount = randomIntBetween(1, 3);
givenNodeCount(nodeCount);
Settings.Builder settings = Settings.builder().put(commonSettings);
settings.put("xpack.ml.enabled", true);
MachineLearningFeatureSet featureSet = new MachineLearningFeatureSet(TestEnvironment.newEnvironment(settings.build()),
clusterService, client, licenseState);

PlainActionFuture<Usage> future = new PlainActionFuture<>();
featureSet.usage(future);
XPackFeatureSet.Usage usage = future.get();

assertThat(usage.available(), is(true));
assertThat(usage.enabled(), is(true));

BytesStreamOutput out = new BytesStreamOutput();
usage.writeTo(out);
XPackFeatureSet.Usage serializedUsage = new MachineLearningFeatureSetUsage(out.bytes().streamInput());

XContentSource source;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
serializedUsage.toXContent(builder, ToXContent.EMPTY_PARAMS);
source = new XContentSource(builder);
}
assertThat(source.getValue("node_count"), equalTo(nodeCount));

BytesStreamOutput oldOut = new BytesStreamOutput();
oldOut.setVersion(Version.V_6_0_0);
usage.writeTo(oldOut);
StreamInput oldInput = oldOut.bytes().streamInput();
oldInput.setVersion(Version.V_6_0_0);
XPackFeatureSet.Usage oldSerializedUsage = new MachineLearningFeatureSetUsage(oldInput);

XContentSource oldSource;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
oldSerializedUsage.toXContent(builder, ToXContent.EMPTY_PARAMS);
oldSource = new XContentSource(builder);
}

assertNull(oldSource.getValue("node_count"));
}

public void testUsageGivenMlMetadataNotInstalled() throws Exception {
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
Settings.Builder settings = Settings.builder().put(commonSettings);
Expand Down Expand Up @@ -286,6 +338,37 @@ private void givenJobs(List<Job> jobs, List<GetJobsStatsAction.Response.JobStats
}).when(client).execute(same(GetJobsStatsAction.INSTANCE), any(), any());
}

private void givenNodeCount(int nodeCount) {
DiscoveryNodes.Builder nodesBuilder = DiscoveryNodes.builder();
for (int i = 0; i < nodeCount; i++) {
Map<String, String> attrs = new HashMap<>();
attrs.put(MachineLearning.ML_ENABLED_NODE_ATTR, Boolean.toString(true));
Set<DiscoveryNode.Role> roles = new HashSet<>();
roles.add(DiscoveryNode.Role.DATA);
roles.add(DiscoveryNode.Role.MASTER);
roles.add(DiscoveryNode.Role.INGEST);
nodesBuilder.add(new DiscoveryNode(randomAlphaOfLength(i+1),
new TransportAddress(TransportAddress.META_ADDRESS, 9100 + i),
attrs,
roles,
Version.CURRENT));
}
for (int i = 0; i < randomIntBetween(1, 3); i++) {
Map<String, String> attrs = new HashMap<>();
Set<DiscoveryNode.Role> roles = new HashSet<>();
roles.add(DiscoveryNode.Role.DATA);
roles.add(DiscoveryNode.Role.MASTER);
roles.add(DiscoveryNode.Role.INGEST);
nodesBuilder.add(new DiscoveryNode(randomAlphaOfLength(i+1),
new TransportAddress(TransportAddress.META_ADDRESS, 9300 + i),
attrs,
roles,
Version.CURRENT));
}
ClusterState clusterState = new ClusterState.Builder(ClusterState.EMPTY_STATE).nodes(nodesBuilder.build()).build();
when(clusterService.state()).thenReturn(clusterState);
}

private void givenDatafeeds(List<GetDatafeedsStatsAction.Response.DatafeedStats> datafeedStats) {
doAnswer(invocationOnMock -> {
ActionListener<GetDatafeedsStatsAction.Response> listener =
Expand Down

0 comments on commit 4767a01

Please sign in to comment.