Commit f43ea60

Revert "[ML] Use perAllocation and perDeployment memory usage in the model assignment planner"

This reverts commit 31ca2f7.

The functionality of elastic#98874 is being removed from 8.12 because it means that models which were working successfully on 2GB nodes in 8.11 will no longer fit on 2GB nodes. This will be frustrating for trial users.

Before 8.13 we need to do a more thorough assessment of which models will and won't fit on 2GB nodes as a result of the improved memory estimation. It may be possible to tweak the memory usage estimation so that we require more memory than in 8.11, but not so much more that our recommended trial models no longer fit on 2GB nodes.
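To make the sizing concern concrete, here is a minimal sketch of the two estimation styles the message contrasts. The formulas and every number below are illustrative assumptions, not the actual Elasticsearch implementation; only the idea that the reverted change adds per-deployment and per-allocation terms on top of the model size comes from this commit.

    // Hypothetical arithmetic only; constants and formulas are assumptions.
    class MemoryEstimateSketch {
        static final long MB = 1024L * 1024L;
        static final long OVERHEAD = 240 * MB; // assumed fixed native overhead

        // 8.11-style estimate, driven by the model definition size alone.
        static long legacyEstimate(long modelBytes) {
            return 2 * modelBytes + OVERHEAD;
        }

        // 8.12-style estimate from the reverted change, with the new fields added in.
        static long newEstimate(long modelBytes, long perDeployment, long perAllocation, int allocations) {
            return modelBytes + perDeployment + perAllocation * allocations + OVERHEAD;
        }

        public static void main(String[] args) {
            long model = 500 * MB; // hypothetical trial-sized model
            // Legacy: 2 * 500 + 240 = 1240 MB, within a 2GB node's ML allowance.
            System.out.println(legacyEstimate(model) / MB + " MB");
            // With new fields: 500 + 800 + 2 * 200 + 240 = 1940 MB, likely over it.
            System.out.println(newEstimate(model, 800 * MB, 200 * MB, 2) / MB + " MB");
        }
    }

Under assumptions like these, a model that fit comfortably in 8.11 no longer fits the plan in 8.12 even though its real footprint has not changed, which is the regression the revert avoids on 2GB trial nodes.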
droberts195 committed Dec 11, 2023
1 parent 6d5254c commit f43ea60
Showing 20 changed files with 483 additions and 2,088 deletions.
docs/changelog/98874.yaml: 5 changes (0 additions, 5 deletions)

This file was deleted.

@@ -9,7 +9,6 @@

import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Randomness;
@@ -97,10 +96,6 @@ public final class TrainedModelAssignment implements SimpleDiffable<TrainedModel
private final Instant startTime;
private final int maxAssignedAllocations;

public static boolean useNewMemoryFields(TransportVersion minClusterVersion) {
return minClusterVersion.onOrAfter(TransportVersions.V_8_500_064);
}

public static TrainedModelAssignment fromXContent(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}
@@ -46,11 +46,9 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
@@ -299,23 +297,29 @@ private void modelSizeStats(
for (TrainedModelConfig model : models) {
if (model.getModelType() == TrainedModelType.PYTORCH) {
long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
// We ensure that in the mixed cluster state trained model stats uses the same values for memory estimation
// as the rebalancer.
boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(
TransportVersionUtils.getMinTransportVersion(clusterService.state())
);
long estimatedMemoryUsageBytes = totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
numberOfAllocations
)
: 0L;
modelSizeStatsByModelId.put(
model.getModelId(),
new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes)
new TrainedModelSizeStats(
totalDefinitionLength,
totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
numberOfAllocations
)
: 0L
)
);
} else {
modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0));
@@ -47,7 +47,6 @@
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
@@ -77,8 +76,6 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;

private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064;

private final ClusterService clusterService;
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
@@ -647,14 +644,12 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
Map<DiscoveryNode, NodeLoad> nodeLoads = detectNodeLoads(nodes, currentState);
TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState);

boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState));
TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
currentMetadata,
nodeLoads,
nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState),
modelToAdd,
allocatedProcessorsScale,
useNewMemoryFields
allocatedProcessorsScale
);

Set<String> shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds();
@@ -52,22 +52,18 @@ class TrainedModelAssignmentRebalancer {
private final Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd;
private final int allocatedProcessorsScale;

private final boolean useNewMemoryFields;

TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata currentMetadata,
Map<DiscoveryNode, NodeLoad> nodeLoads,
Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone,
Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd,
int allocatedProcessorsScale,
boolean useNewMemoryFields
int allocatedProcessorsScale
) {
this.currentMetadata = Objects.requireNonNull(currentMetadata);
this.nodeLoads = Objects.requireNonNull(nodeLoads);
this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone);
this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd);
this.allocatedProcessorsScale = allocatedProcessorsScale;
this.useNewMemoryFields = useNewMemoryFields;
}

TrainedModelAssignmentMetadata.Builder rebalance() {
@@ -142,11 +138,9 @@ private static void copyAssignments(
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
dest.assignModelToNode(m, originalNode, assignment.getValue());
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
dest.accountMemory(m, originalNode, requiredMemory);
dest.accountMemory(m, originalNode);
}
}
}
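The hunk above captures the accounting change in miniature: the reverted code computed an allocation-dependent requirement and passed it in, while the restored code charges a deployment's flat memory on any node it occupies. A condensed contrast, with identifiers taken from the diff (not compilable on its own):

    // Pre-revert: the cost grows with the allocations kept on the node.
    long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
    dest.accountMemory(m, originalNode, requiredMemory);

    // Post-revert: the same flat charge regardless of allocation count.
    dest.accountMemory(m, originalNode);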
@@ -174,14 +168,11 @@ private AssignmentPlan computePlanForNormalPriorityModels(
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
return new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
currentAssignments,
assignment.getMaxAssignedAllocations(),
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
assignment.getMaxAssignedAllocations()
);
})
.forEach(planDeployments::add);
@@ -190,14 +181,11 @@ private AssignmentPlan computePlanForNormalPriorityModels(
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.getModelBytes(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0,
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0
0
)
);
}
@@ -229,14 +217,12 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
.map(
assignment -> new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory),
assignment.getMaxAssignedAllocations(),
Priority.LOW,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
Priority.LOW
)
)
.forEach(planDeployments::add);
@@ -245,14 +231,12 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.getModelBytes(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0,
Priority.LOW,
(useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0
Priority.LOW
)
);
}
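Taken together, the rebalancer hunks shrink the planner's Deployment back to its 8.11 shape: a single flat memory figure instead of the raw model bytes plus the two new fields. A condensed view of the two call shapes, with argument names inferred from the call sites above (the Deployment class itself is not shown in this excerpt):

    // Pre-revert: the planner receives the raw components of the estimate.
    new AssignmentPlan.Deployment(deploymentId, modelBytes, numberOfAllocations, threadsPerAllocation,
        currentAssignments, maxAssignedAllocations, perDeploymentMemoryBytes, perAllocationMemoryBytes);

    // Post-revert: the planner receives one pre-computed estimate.
    new AssignmentPlan.Deployment(deploymentId, estimateMemoryUsageBytes, numberOfAllocations,
        threadsPerAllocation, currentAssignments, maxAssignedAllocations);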
@@ -35,8 +35,7 @@ private Node modifyNodePreservingAllocations(Node n) {
int coresUsed = 0;
for (Deployment m : deployments) {
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
int allocations = m.currentAllocationsByNodeId().get(n.id());
bytesUsed += m.estimateMemoryUsageBytes(allocations);
bytesUsed += m.memoryBytes();
coresUsed += calculateUsedCores(n, m);
}
}
@@ -59,9 +58,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) {
m.allocations() - calculatePreservedAllocations(m),
m.threadsPerAllocation(),
calculateAllocationsPerNodeToPreserve(m),
m.maxAssignedAllocations(),
m.perDeploymentMemoryBytes(),
m.perAllocationMemoryBytes()
m.maxAssignedAllocations()
);
}

@@ -70,37 +67,28 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
// they will not match the models/nodes members we have in this class.
// Therefore, we build a lookup table based on the ids so we can merge the plan
// with its preserved allocations.
final Map<Tuple<String, String>, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>();
final Map<Tuple<String, String>, Integer> assignmentsByModelNodeIdPair = new HashMap<>();
for (Deployment m : assignmentPlan.models()) {
Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignments.entrySet()) {
plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
}
}

AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments);
for (Node n : nodes) {
// TODO (#101612) Should the first loop happen in the builder constructor?
for (Deployment deploymentAllocationsToPreserve : deployments) {

// if the model m is already allocated on the node n and I want to preserve this allocation
int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve);
if (preservedAllocations > 0) {
long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations);
if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory);
for (Deployment m : deployments) {
for (Node n : nodes) {
int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0);
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) {
allocations += addPreservedAllocations(n, m);
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
mergedPlanBuilder.accountMemory(m, n);
}
}
}
for (Deployment deploymentNewAllocations : deployments) {
int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault(
Tuple.tuple(deploymentNewAllocations.id(), n.id()),
0
);

long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations);
if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations);
if (allocations > 0) {
mergedPlanBuilder.assignModelToNode(m, n, allocations);
}
}
}
(Diffs for the remaining changed files are not shown.)
