Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adapt to periodic persistent task refresh #36633

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,28 +57,24 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom {
public static final String TYPE = "ml";
private static final ParseField JOBS_FIELD = new ParseField("jobs");
private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds");
private static final ParseField LAST_MEMORY_REFRESH_VERSION_FIELD = new ParseField("last_memory_refresh_version");

public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), null);
public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap());
// This parser follows the pattern that metadata is parsed leniently (to allow for enhancements)
public static final ObjectParser<Builder, Void> LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new);

static {
LENIENT_PARSER.declareObjectArray(Builder::putJobs, (p, c) -> Job.LENIENT_PARSER.apply(p, c).build(), JOBS_FIELD);
LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds,
(p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD);
LENIENT_PARSER.declareLong(Builder::setLastMemoryRefreshVersion, LAST_MEMORY_REFRESH_VERSION_FIELD);
}

private final SortedMap<String, Job> jobs;
private final SortedMap<String, DatafeedConfig> datafeeds;
private final Long lastMemoryRefreshVersion;
private final GroupOrJobLookup groupOrJobLookup;

private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds, Long lastMemoryRefreshVersion) {
private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds) {
this.jobs = Collections.unmodifiableSortedMap(jobs);
this.datafeeds = Collections.unmodifiableSortedMap(datafeeds);
this.lastMemoryRefreshVersion = lastMemoryRefreshVersion;
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

Expand Down Expand Up @@ -116,10 +112,6 @@ public Set<String> expandDatafeedIds(String expression, boolean allowNoDatafeeds
.expand(expression, allowNoDatafeeds);
}

public Long getLastMemoryRefreshVersion() {
return lastMemoryRefreshVersion;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.V_6_0_0_alpha1;
Expand Down Expand Up @@ -153,21 +145,13 @@ public MlMetadata(StreamInput in) throws IOException {
datafeeds.put(in.readString(), new DatafeedConfig(in));
}
this.datafeeds = datafeeds;
if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshVersion = in.readOptionalLong();
} else {
lastMemoryRefreshVersion = null;
}
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
writeMap(jobs, out);
writeMap(datafeeds, out);
if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
out.writeOptionalLong(lastMemoryRefreshVersion);
}
}

private static <T extends Writeable> void writeMap(Map<String, T> map, StreamOutput out) throws IOException {
Expand All @@ -184,9 +168,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
new DelegatingMapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"), params);
mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams);
mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams);
if (lastMemoryRefreshVersion != null) {
builder.field(LAST_MEMORY_REFRESH_VERSION_FIELD.getPreferredName(), lastMemoryRefreshVersion);
}
return builder;
}

Expand All @@ -203,24 +184,17 @@ public static class MlMetadataDiff implements NamedDiff<MetaData.Custom> {

final Diff<Map<String, Job>> jobs;
final Diff<Map<String, DatafeedConfig>> datafeeds;
final Long lastMemoryRefreshVersion;

MlMetadataDiff(MlMetadata before, MlMetadata after) {
this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer());
this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer());
this.lastMemoryRefreshVersion = after.lastMemoryRefreshVersion;
}

public MlMetadataDiff(StreamInput in) throws IOException {
this.jobs = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), Job::new,
MlMetadataDiff::readJobDiffFrom);
this.datafeeds = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), DatafeedConfig::new,
MlMetadataDiff::readDatafeedDiffFrom);
if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshVersion = in.readOptionalLong();
} else {
lastMemoryRefreshVersion = null;
}
}

/**
Expand All @@ -232,17 +206,13 @@ public MlMetadataDiff(StreamInput in) throws IOException {
public MetaData.Custom apply(MetaData.Custom part) {
TreeMap<String, Job> newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs));
TreeMap<String, DatafeedConfig> newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds));
// lastMemoryRefreshVersion always comes from the diff - no need to merge with the old value
return new MlMetadata(newJobs, newDatafeeds, lastMemoryRefreshVersion);
return new MlMetadata(newJobs, newDatafeeds);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
jobs.writeTo(out);
datafeeds.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
out.writeOptionalLong(lastMemoryRefreshVersion);
}
}

@Override
Expand All @@ -267,8 +237,7 @@ public boolean equals(Object o) {
return false;
MlMetadata that = (MlMetadata) o;
return Objects.equals(jobs, that.jobs) &&
Objects.equals(datafeeds, that.datafeeds) &&
Objects.equals(lastMemoryRefreshVersion, that.lastMemoryRefreshVersion);
Objects.equals(datafeeds, that.datafeeds);
}

@Override
Expand All @@ -278,14 +247,13 @@ public final String toString() {

@Override
public int hashCode() {
return Objects.hash(jobs, datafeeds, lastMemoryRefreshVersion);
return Objects.hash(jobs, datafeeds);
}

public static class Builder {

private TreeMap<String, Job> jobs;
private TreeMap<String, DatafeedConfig> datafeeds;
private Long lastMemoryRefreshVersion;

public Builder() {
jobs = new TreeMap<>();
Expand All @@ -299,7 +267,6 @@ public Builder(@Nullable MlMetadata previous) {
} else {
jobs = new TreeMap<>(previous.jobs);
datafeeds = new TreeMap<>(previous.datafeeds);
lastMemoryRefreshVersion = previous.lastMemoryRefreshVersion;
}
}

Expand Down Expand Up @@ -419,13 +386,8 @@ public Builder putDatafeeds(Collection<DatafeedConfig> datafeeds) {
return this;
}

public Builder setLastMemoryRefreshVersion(Long lastMemoryRefreshVersion) {
this.lastMemoryRefreshVersion = lastMemoryRefreshVersion;
return this;
}

public MlMetadata build() {
return new MlMetadata(jobs, datafeeds, lastMemoryRefreshVersion);
return new MlMetadata(jobs, datafeeds);
}

public void markJobAsDeleting(String jobId, PersistentTasksCustomMetaData tasks, boolean allowDeleteOpenJob) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
this.datafeedManager.set(datafeedManager);
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
autodetectProcessManager);
MlMemoryTracker memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider);
MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider);
this.memoryTracker.set(memoryTracker);

// This object's constructor attaches to the license state, so there's no need to retain another reference to it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ static RemovalResult removeJobsAndDatafeeds(List<String> jobsToRemove, List<Stri
}

MlMetadata.Builder builder = new MlMetadata.Builder();
builder.setLastMemoryRefreshVersion(mlMetadata.getLastMemoryRefreshVersion())
.putJobs(currentJobs.values())
builder.putJobs(currentJobs.values())
.putDatafeeds(currentDatafeeds.values());

return new RemovalResult(builder.build(), removedJobIds, removedDatafeedIds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,7 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j

if (memoryTracker.isRecentlyRefreshed() == false) {

boolean scheduledRefresh = memoryTracker.asyncRefresh(ActionListener.wrap(
acknowledged -> {
if (acknowledged) {
logger.trace("Job memory requirement refresh request completed successfully");
} else {
logger.warn("Job memory requirement refresh request completed but did not set time in cluster state");
}
},
e -> logger.error("Failed to refresh job memory requirements", e)
));
boolean scheduledRefresh = memoryTracker.asyncRefresh();
if (scheduledRefresh) {
String reason = "Not opening job [" + jobId + "] because job memory requirements are stale - refresh requested";
logger.debug(reason);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,15 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.LocalNodeMasterListener;
import org.elasticsearch.cluster.ack.AckedRequest;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.persistent.PersistentTasksClusterService;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.job.config.Job;
Expand All @@ -44,22 +40,10 @@
* 1. For all open ML jobs (via {@link #asyncRefresh})
* 2. For all open ML jobs, plus one named ML job that is not open (via {@link #refreshJobMemoryAndAllOthers})
* 3. For one named ML job (via {@link #refreshJobMemory})
* In all cases a listener informs the caller when the requested updates are complete.
* In cases 2 and 3 a listener informs the caller when the requested updates are complete.
*/
public class MlMemoryTracker implements LocalNodeMasterListener {

/**
 * Shared {@link AckedRequest} handed to the {@code AckedClusterStateUpdateTask} created in
 * {@code recordUpdateTimeInClusterState}. Both timeouts are fixed to
 * {@code AcknowledgedRequest.DEFAULT_ACK_TIMEOUT}.
 */
private static final AckedRequest ACKED_REQUEST = new AckedRequest() {
@Override
public TimeValue ackTimeout() {
return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT;
}

// NOTE(review): this reuses the ack timeout constant for the master node timeout as well;
// confirm that is intentional rather than a separate master-node default being intended.
@Override
public TimeValue masterNodeTimeout() {
return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT;
}
};

private static final Duration RECENT_UPDATE_THRESHOLD = Duration.ofMinutes(1);

private final Logger logger = LogManager.getLogger(MlMemoryTracker.class);
Expand All @@ -72,14 +56,22 @@ public TimeValue masterNodeTimeout() {
private final JobResultsProvider jobResultsProvider;
private volatile boolean isMaster;
private volatile Instant lastUpdateTime;
private volatile Duration reassignmentRecheckInterval;

public MlMemoryTracker(ClusterService clusterService, ThreadPool threadPool, JobManager jobManager,
public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadPool threadPool, JobManager jobManager,
JobResultsProvider jobResultsProvider) {
this.threadPool = threadPool;
this.clusterService = clusterService;
this.jobManager = jobManager;
this.jobResultsProvider = jobResultsProvider;
setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings));
clusterService.addLocalNodeMasterListener(this);
clusterService.getClusterSettings().addSettingsUpdateConsumer(
PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING, this::setReassignmentRecheckInterval);
}

/**
 * Stores the persistent task reassignment recheck interval as a {@link Duration}.
 * Invoked from the constructor with the value read from the node settings, and registered
 * there as the update consumer for
 * {@code PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING}.
 *
 * @param recheckInterval the interval to store; converted using its nanosecond value
 */
private void setReassignmentRecheckInterval(TimeValue recheckInterval) {
// The target field is declared volatile and is read by isRecentlyRefreshed(), where it
// is added to RECENT_UPDATE_THRESHOLD when deciding whether the tracker is up to date.
reassignmentRecheckInterval = Duration.ofNanos(recheckInterval.getNanos());
}

@Override
Expand All @@ -103,11 +95,12 @@ public String executorName() {

/**
* Is the information in this object sufficiently up to date
* for valid allocation decisions to be made using it?
* for valid task assignment decisions to be made using it?
*/
public boolean isRecentlyRefreshed() {
Instant localLastUpdateTime = lastUpdateTime;
return localLastUpdateTime != null && localLastUpdateTime.plus(RECENT_UPDATE_THRESHOLD).isAfter(Instant.now());
return localLastUpdateTime != null &&
localLastUpdateTime.plus(RECENT_UPDATE_THRESHOLD).plus(reassignmentRecheckInterval).isAfter(Instant.now());
}

/**
Expand Down Expand Up @@ -143,24 +136,19 @@ public void removeJob(String jobId) {
/**
* Uses a separate thread to refresh the memory requirement for every ML job that has
* a corresponding persistent task. This method only works on the master node.
* @param listener Will be called when the async refresh completes or fails. The
* boolean value indicates whether the cluster state was updated
* with the refresh completion time. (If it was then this will in turn
* cause the persistent tasks framework to check if any persistent
* tasks are awaiting allocation.)
* @return <code>true</code> if the async refresh is scheduled, and <code>false</code>
* if this is not possible for some reason.
*/
public boolean asyncRefresh(ActionListener<Boolean> listener) {
public boolean asyncRefresh() {

if (isMaster) {
try {
ActionListener<Void> mlMetaUpdateListener = ActionListener.wrap(
aVoid -> recordUpdateTimeInClusterState(listener),
listener::onFailure
ActionListener<Void> listener = ActionListener.wrap(
aVoid -> logger.trace("Job memory requirement refresh request completed successfully"),
e -> logger.error("Failed to refresh job memory requirements", e)
);
threadPool.executor(executorName()).execute(
() -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), mlMetaUpdateListener));
() -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), listener));
return true;
} catch (EsRejectedExecutionException e) {
logger.debug("Couldn't schedule ML memory update - node might be shutting down", e);
Expand Down Expand Up @@ -227,33 +215,6 @@ void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener<Void>
}
}

/**
 * Submits a cluster state update task (source {@code "ml-memory-last-update-time"}) that
 * sets the last memory refresh version in the {@link MlMetadata} held in the cluster state.
 *
 * @param listener passed to the {@code AckedClusterStateUpdateTask}; the {@code Boolean}
 *                 response type is the {@code acknowledged} flag returned by
 *                 {@code newResponse}
 */
private void recordUpdateTimeInClusterState(ActionListener<Boolean> listener) {

clusterService.submitStateUpdateTask("ml-memory-last-update-time",
new AckedClusterStateUpdateTask<Boolean>(ACKED_REQUEST, listener) {
@Override
protected Boolean newResponse(boolean acknowledged) {
// The response is simply the acknowledged flag
return acknowledged;
}

@Override
public ClusterState execute(ClusterState currentState) {
MlMetadata currentMlMetadata = MlMetadata.getMlMetadata(currentState);
MlMetadata.Builder builder = new MlMetadata.Builder(currentMlMetadata);
// The recorded value is the current cluster state version plus one
// (presumably the version this update will produce - TODO confirm)
builder.setLastMemoryRefreshVersion(currentState.getVersion() + 1);
MlMetadata newMlMetadata = builder.build();
if (newMlMetadata.equals(currentMlMetadata)) {
// Return same reference if nothing has changed
return currentState;
} else {
// Otherwise build a new cluster state whose metadata carries the updated ML custom
ClusterState.Builder newState = ClusterState.builder(currentState);
newState.metaData(MetaData.builder(currentState.getMetaData()).putCustom(MlMetadata.TYPE, newMlMetadata).build());
return newState.build();
}
}
});
}

private void iterateMlJobTasks(Iterator<PersistentTasksCustomMetaData.PersistentTask<?>> iterator,
ActionListener<Void> refreshComplete) {
if (iterator.hasNext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ protected MlMetadata createTestInstance() {
builder.putJob(job, false);
}
}
if (randomBoolean()) {
builder.setLastMemoryRefreshVersion(randomNonNegativeLong());
}
return builder.build();
}

Expand Down Expand Up @@ -441,9 +438,8 @@ protected MlMetadata mutateInstance(MlMetadata instance) {
for (Map.Entry<String, DatafeedConfig> entry : datafeeds.entrySet()) {
metadataBuilder.putDatafeed(entry.getValue(), Collections.emptyMap());
}
metadataBuilder.setLastMemoryRefreshVersion(instance.getLastMemoryRefreshVersion());

switch (between(0, 2)) {
switch (between(0, 1)) {
case 0:
metadataBuilder.putJob(JobTests.createRandomizedJob(), true);
break;
Expand All @@ -463,13 +459,6 @@ protected MlMetadata mutateInstance(MlMetadata instance) {
metadataBuilder.putJob(randomJob, false);
metadataBuilder.putDatafeed(datafeedConfig, Collections.emptyMap());
break;
case 2:
if (instance.getLastMemoryRefreshVersion() == null) {
metadataBuilder.setLastMemoryRefreshVersion(randomNonNegativeLong());
} else {
metadataBuilder.setLastMemoryRefreshVersion(null);
}
break;
default:
throw new AssertionError("Illegal randomisation branch");
}
Expand Down
Loading