diff --git a/README.md b/README.md index 2e1801d..9aa27d9 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,41 @@ spec: claimName: api-exp-data ``` +#### Profiles +This is a Spring Boot application, that can be run with profiles. The "default" profile is used if no configuration is set. The "dev" profile can be enabled by setting the JVM System Parameter + + -Dspring.profiles.active=dev +or Environment Variable + + export spring_profiles_active=dev +or via the corresponding setting in your development environment or within the pod definition. + +Example: + + $ SCHEDULER_NAME=workflow-scheduler java -Dspring.profiles.active=dev -jar cws-k8s-scheduler-1.2-SNAPSHOT.jar + +The "dev" profile is useful for debugging and reporting problems because it increases the log-level. + +--- +#### Memory Prediction and Task Scaling +- Supported if used together with [nf-cws](https://github.com/CommonWorkflowScheduler/nf-cws) version 1.0.4 or newer. +- Kubernetes Feature InPlacePodVerticalScaling must be enabled. This is available starting from Kubernetes v1.27. See [KEP 1287](https://github.com/kubernetes/enhancements/issues/1287) for the current status. +- It is required to enable traces in Nextflow via `trace.enabled = true` in the config file, or the commandline option `-with-trace`. + +The memory predictor that shall be used for task scaling is set via the nf-cws configuration. If not set, task scaling is disabled. + +| cws.memoryPredictor | Behaviour | +|---------------------|------------------------------------------------------------------------------------------------------------------------------------| +| "" | Disabled if empty or not set. | +| none | NonePredictor, will never make any predictions and consequently no task scaling will occur. Used for testing and benchmarking only.| +| constant | ConstantPredictor, will try to predict a constant memory usage pattern. 
| +| linear | LinearPredictor, will try to predict a memory usage that is linear to the task input size. | +| combi | CombiPredictor, combines predictions from ConstantPredictor and LinearPredictor. | +| wary | WaryPredictor, behaves like LinearPredictor but is more cautious about its predictions. | +| default | Query the environment variable "MEMORY_PREDICTOR_DEFAULT" and use the value that is set there. | + +If a memory predictor is selected (i.e. setting is not disabled), the implementation will locally record statistics and print out the result after the workflow has finished. This can be disabled via the environment variable "DISABLE_STATISTICS". If this is set to any string, the implementation will not collect and print out the results. + --- If you use this software or artifacts in a publication, please cite it as: diff --git a/pom.xml b/pom.xml index 612112e..97c1a69 100644 --- a/pom.xml +++ b/pom.xml @@ -134,6 +134,12 @@ jackson-annotations + + org.apache.commons + commons-math3 + 3.6.1 + + diff --git a/src/main/java/cws/k8s/scheduler/client/KubernetesClient.java b/src/main/java/cws/k8s/scheduler/client/KubernetesClient.java index 082a74a..1c6734e 100644 --- a/src/main/java/cws/k8s/scheduler/client/KubernetesClient.java +++ b/src/main/java/cws/k8s/scheduler/client/KubernetesClient.java @@ -2,6 +2,7 @@ import cws.k8s.scheduler.model.NodeWithAlloc; import cws.k8s.scheduler.model.PodWithAge; +import cws.k8s.scheduler.model.Task; import io.fabric8.kubernetes.api.model.*; import io.fabric8.kubernetes.client.DefaultKubernetesClient; import io.fabric8.kubernetes.client.KubernetesClientException; @@ -218,4 +219,55 @@ public void onClose(WatcherException cause) { } + /** + * After some testing, this was found to be the only reliable way to patch a pod + * using the Kubernetes client. + * + * It will create a patch for the memory limits and request values and submit it + * to the cluster. 
+     *
+     * @param t     the task to be patched; namespace and pod name are read from its pod
+     * @param value the memory quantity that is set as both limit and request
+     * @return false if patching failed because of InPlacePodVerticalScaling,
+     *         true otherwise -- including when the patch failed for any other
+     *         reason, which is only logged
+     */
+    public boolean patchTaskMemory(Task t, String value) {
+        final String namespace = t.getPod().getMetadata().getNamespace();
+        final String podname = t.getPod().getName();
+        log.debug("namespace: {}, podname: {}", namespace, podname);
+        // The container entry is addressed by the pod name, i.e. the patch only
+        // matches if the container is named like its pod.
+        // The nesting below is significant: resources/limits/requests must sit
+        // underneath the container list item, otherwise the patch is not a
+        // valid pod spec.
+        // @formatter:off
+        String patch = "kind: Pod\n"
+                + "apiVersion: v1\n"
+                + "metadata:\n"
+                + "  name: PODNAME\n"
+                + "  namespace: NAMESPACE\n"
+                + "spec:\n"
+                + "  containers:\n"
+                + "  - name: PODNAME\n"
+                + "    resources:\n"
+                + "      limits:\n"
+                + "        memory: LIMIT\n"
+                + "      requests:\n"
+                + "        memory: REQUEST\n"
+                + "\n";
+        // @formatter:on
+        patch = patch.replace("NAMESPACE", namespace);
+        patch = patch.replace("PODNAME", podname);
+        patch = patch.replace("LIMIT", value);
+        patch = patch.replace("REQUEST", value);
+        log.debug(patch);
+
+        try {
+            this.pods().inNamespace(namespace).withName(podname).patch(patch);
+        } catch (KubernetesClientException e) {
+            // this typically happens when the feature gate InPlacePodVerticalScaling was not enabled
+            if (e.toString().contains("Forbidden: pod updates may not change fields other than")) {
+                log.error("Could not patch task. Please make sure that the feature gate 'InPlacePodVerticalScaling' is enabled in Kubernetes. See https://github.com/kubernetes/enhancements/issues/1287 for details. Task scaling will now be disabled for the rest of this workflow execution.");
+                return false;
+            }
+            // The exception is passed as the trailing argument so SLF4J logs it
+            // with its stack trace; the placeholders are filled by the pod
+            // coordinates instead of being left as a literal "{}".
+            log.error("Could not patch task {}/{}", namespace, podname, e);
+        }
+        return true;
+    }
+
 }
diff --git a/src/main/java/cws/k8s/scheduler/memory/CombiPredictor.java b/src/main/java/cws/k8s/scheduler/memory/CombiPredictor.java
new file mode 100644
index 0000000..94f0cc4
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/CombiPredictor.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see <https://www.gnu.org/licenses/>. + */ + +package cws.k8s.scheduler.memory; + +import java.math.BigDecimal; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +//@formatter:off +/** +* CombiPredictor will combine predictions made by ConstantPredictor and +* LinearPredictor. +* +* LinearPredictor fails if there are no inputSize differences to tasks, +* ConstantPredictor can handle this case. So CombiPredictor will run both and +* decide dynamically which predictions to apply. 
+*
+* @author Florian Friederici
+*
+*/
+//@formatter:on
+@Slf4j
+public class CombiPredictor implements MemoryPredictor {
+
+    /** Delegate whose predictions do not depend on the task input size. */
+    ConstantPredictor constantPredictor;
+    /** Delegate that is preferred whenever it can provide a prediction. */
+    LinearPredictor linearPredictor;
+
+    /** Creates the combined predictor with one fresh instance of each delegate. */
+    public CombiPredictor() {
+        this.constantPredictor = new ConstantPredictor();
+        this.linearPredictor = new LinearPredictor();
+    }
+
+    /**
+     * Forwards the observation to both delegates, constant predictor first.
+     *
+     * @param o the observation that was made
+     */
+    @Override
+    public void addObservation(Observation o) {
+        log.debug("CombiPredictor.addObservation({})", o);
+        constantPredictor.addObservation(o);
+        linearPredictor.addObservation(o);
+    }
+
+    /**
+     * Queries both delegates and returns the linear prediction if there is
+     * one, otherwise the constant prediction.
+     *
+     * @param task the task to get a suggestion for
+     * @return the chosen prediction, or null if neither delegate has one
+     */
+    @Override
+    public BigDecimal queryPrediction(Task task) {
+        final String name = task.getConfig().getTask();
+        log.debug("CombiPredictor.queryPrediction({},{})", name, task.getInputSize());
+
+        // both delegates are always asked, in this order
+        final BigDecimal fromConstant = constantPredictor.queryPrediction(task);
+        final BigDecimal fromLinear = linearPredictor.queryPrediction(task);
+
+        if (fromLinear == null) {
+            // either only the constant prediction exists, or none at all (null)
+            return fromConstant;
+        }
+
+        if (fromConstant != null) {
+            // both are available: report how far they are apart
+            log.debug("constantPrediction={}, linearPrediction={}, difference={}", fromConstant, fromLinear, fromConstant.subtract(fromLinear));
+        }
+
+        // the linear prediction wins whenever it is available
+        return fromLinear;
+    }
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/ConstantPredictor.java b/src/main/java/cws/k8s/scheduler/memory/ConstantPredictor.java
new file mode 100644
index 0000000..891f03c
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/ConstantPredictor.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see <https://www.gnu.org/licenses/>. + */ + +package cws.k8s.scheduler.memory; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.HashMap; +import java.util.Map; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +// @formatter:off +/** + * ConstantPredictor will use the following strategy: + * + * - In case the task was successful: + * - let the next prediction be 10% higher than the observed peakRss + * + * - In case the task has failed: + * - reset to the initial ramRequest value; if there is no prediction yet, + * use twice the ramRequest instead + * + * I.e. the suggestions from ConstantPredictor are not dependent on the input + * size of the tasks. 
+ *
+ * @author Florian Friederici
+ *
+ */
+// @formatter:on
+@Slf4j
+class ConstantPredictor implements MemoryPredictor {
+
+    /** Factor applied to the peakRss of a successful task (10% headroom). */
+    private static final BigDecimal HEADROOM_FACTOR = new BigDecimal("1.1");
+
+    /** Factor applied to the ramRequest when a task fails before any prediction exists. */
+    private static final BigDecimal FIRST_FAILURE_FACTOR = new BigDecimal(2);
+
+    /** Current prediction per task, keyed by {@code Observation.task}. */
+    Map<String, BigDecimal> model;
+    /** ramRequest of the first accepted observation per task. */
+    Map<String, BigDecimal> initialValue;
+
+    public ConstantPredictor() {
+        model = new HashMap<>();
+        initialValue = new HashMap<>();
+    }
+
+    /**
+     * Updates the prediction of the observed task.
+     *
+     * A successful observation sets the prediction to peakRss + 10%, rounded
+     * up to a whole number. A failed observation resets an existing
+     * prediction to the first ramRequest seen for the task; if there is no
+     * prediction yet, twice the ramRequest of this observation is used.
+     * Observations rejected by TaskScaler.checkObservationSanity are ignored.
+     *
+     * @param o the observation that was made
+     */
+    @Override
+    public void addObservation(Observation o) {
+        log.debug("ConstantPredictor.addObservation({})", o);
+        if (!TaskScaler.checkObservationSanity(o)) {
+            log.warn("dismiss observation {}", o);
+            return;
+        }
+
+        // store initial ramRequest value per task
+        if (!initialValue.containsKey(o.task)) {
+            initialValue.put(o.task, o.getRamRequest());
+        }
+
+        final BigDecimal next;
+        if (Boolean.TRUE.equals(o.success)) {
+            // set model to peakRss + 10%
+            next = o.peakRss.multiply(HEADROOM_FACTOR).setScale(0, RoundingMode.CEILING);
+        } else if (model.containsKey(o.task)) {
+            // reset to initialValue
+            next = this.initialValue.get(o.task);
+        } else {
+            // failure before any prediction was made: double the request
+            next = o.ramRequest.multiply(FIRST_FAILURE_FACTOR).setScale(0, RoundingMode.CEILING);
+        }
+        // put() covers both the insert and the replace case
+        model.put(o.task, next);
+    }
+
+    /**
+     * Returns the stored prediction for the task; the input size is not used.
+     *
+     * @param task the task to get a suggestion for
+     * @return the current prediction, or null if the task is unknown
+     */
+    @Override
+    public BigDecimal queryPrediction(Task task) {
+        String taskName = task.getConfig().getTask();
+        log.debug("ConstantPredictor.queryPrediction({})", taskName);
+
+        // get() yields null for unknown tasks, which is the "no prediction" result
+        return model.get(taskName);
+    }
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/LinearPredictor.java b/src/main/java/cws/k8s/scheduler/memory/LinearPredictor.java
new file mode 100644
index 0000000..e94f594
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/LinearPredictor.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see <https://www.gnu.org/licenses/>. + */ + +package cws.k8s.scheduler.memory; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.math3.stat.regression.SimpleRegression; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +//@formatter:off +/** + * LinearPredictor will use the following strategy: + * + * If there are fewer than 2 observations, give no prediction, else: + * Calculate linear regression model and provide predictions. + * + * Predictions start with 10% over-provisioning. Every failed task increases + * the over-provisioning of that task by a further 5 percentage points. 
+ *
+ * @author Florian Friederici
+ *
+ */
+//@formatter:on
+@Slf4j
+public class LinearPredictor implements MemoryPredictor {
+
+    /** Over-provisioning factor every task starts with (10% on top of the regression value). */
+    private static final double INITIAL_OVERPROVISIONING = 1.1;
+
+    /** Amount added to a task's over-provisioning factor for each failed observation. */
+    private static final double FAILURE_INCREMENT = 0.05;
+
+    /** Regression of peakRss over inputSize per task, fed by successful observations only. */
+    Map<String, SimpleRegression> model;
+    /** Over-provisioning factor per task. */
+    Map<String, Double> overprovisioning;
+
+    public LinearPredictor() {
+        model = new HashMap<>();
+        overprovisioning = new HashMap<>();
+    }
+
+    /**
+     * Feeds a successful observation into the task's regression (x = input
+     * size, y = peakRss); a failed observation raises the task's
+     * over-provisioning factor instead. Observations rejected by
+     * TaskScaler.checkObservationSanity are ignored.
+     *
+     * @param o the observation that was made
+     */
+    @Override
+    public void addObservation(Observation o) {
+        log.debug("LinearPredictor.addObservation({})", o);
+        if (!TaskScaler.checkObservationSanity(o)) {
+            log.warn("dismiss observation {}", o);
+            return;
+        }
+
+        // registered for every accepted observation, so a task that has a
+        // regression always has a factor as well
+        overprovisioning.putIfAbsent(o.task, INITIAL_OVERPROVISIONING);
+
+        if (Boolean.TRUE.equals(o.success)) {
+            model.computeIfAbsent(o.task, key -> new SimpleRegression())
+                    .addData(o.getInputSize(), o.getPeakRss().doubleValue());
+        } else {
+            log.debug("overprovisioning value will increase due to task failure");
+            overprovisioning.merge(o.task, FAILURE_INCREMENT, Double::sum);
+        }
+    }
+
+    /**
+     * Evaluates the task's regression at the task's input size and applies
+     * the over-provisioning factor, rounded up to a whole number.
+     *
+     * @param task the task to get a suggestion for
+     * @return the prediction, or null if the task has no regression yet, the
+     *         regression yields NaN, or the regression value is negative
+     */
+    @Override
+    public BigDecimal queryPrediction(Task task) {
+        String taskName = task.getConfig().getTask();
+        log.debug("LinearPredictor.queryPrediction({},{})", taskName, task.getInputSize());
+
+        SimpleRegression simpleRegression = model.get(taskName);
+        if (simpleRegression == null) {
+            log.debug("LinearPredictor has no model for {}", taskName);
+            return null;
+        }
+
+        double prediction = simpleRegression.predict(task.getInputSize());
+
+        if (Double.isNaN(prediction)) {
+            log.debug("No prediction possible for {}", taskName);
+            return null;
+        }
+
+        if (prediction < 0) {
+            log.warn("prediction would be negative: {}", prediction);
+            return null;
+        }
+
+        return BigDecimal.valueOf(prediction).multiply(BigDecimal.valueOf(overprovisioning.get(taskName))).setScale(0, RoundingMode.CEILING);
+    }
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/MemoryPredictor.java b/src/main/java/cws/k8s/scheduler/memory/MemoryPredictor.java
new file mode 100644
index 0000000..2536617
--- /dev/null
+++ 
b/src/main/java/cws/k8s/scheduler/memory/MemoryPredictor.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . + */ + +package cws.k8s.scheduler.memory; + +import java.math.BigDecimal; + +import cws.k8s.scheduler.model.Task; + +// @formatter:off +/** + * The MemoryPredictor has two important interfaces: + * + * 1) addObservation() + * - "add a new observation" after a workflow task is finished, the + * observation result will be collected in the MemoryPredictor + * + * 2) queryPrediction() + * - "ask for a suggestion" at any time, the MemoryPredictor can be asked + * what its guess is on the resource requirement of a task + * + * Different strategies can be tried and exchanged easily, they just have to + * implement those two interfaces. See ConstantPredictor and LinearPredictor + * for concrete strategies. + * + * @author Florian Friederici + * + */ +// @formatter:on +interface MemoryPredictor { + + /** + * input observation into the MemoryPredictor, to be used to learn memory usage + * of tasks to create suggestions + * + * @param o the observation that was made + */ + void addObservation(Observation o); + + /** + * ask the MemoryPredictor for a suggestion on how much memory should be + * assigned to the task. 
+     *
+     * @param task the task to get a suggestion form
+     * @return null, if no suggestion possible, otherwise the value to be used
+     */
+    BigDecimal queryPrediction(Task task);
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/NfTrace.java b/src/main/java/cws/k8s/scheduler/memory/NfTrace.java
new file mode 100644
index 0000000..a7da014
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/NfTrace.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see .
+ */
+
+package cws.k8s.scheduler.memory;
+
+import java.math.BigDecimal;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+
+import cws.k8s.scheduler.model.Task;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * Nextflow writes a trace file, when run with "-with-trace" on command line, or
+ * "trace.enabled = true" in the configuration file.
+ *
+ * This class contains methods to extract values from these traces after the
+ * Tasks have finished. The values are read from the file ".command.trace" in
+ * the working directory of the task.
+ *
+ * @author Florian Friederici
+ *
+ */
+@Slf4j
+public class NfTrace {
+
+    /** Factor to convert the KiB values of the trace file into byte. */
+    private static final BigDecimal BYTES_PER_KIB = BigDecimal.valueOf(1024L);
+
+    /** Marker value returned by the memory getters if the extraction failed. */
+    private static final BigDecimal EXTRACTION_FAILED = BigDecimal.valueOf(-1);
+
+    private NfTrace() {
+        throw new IllegalStateException("Utility class");
+    }
+
+    /**
+     * This method will get the peak vmem from the Nextflow
+     * trace, and return it in BigDecimal format. According to the Nextflow
+     * documentation this is:
+     *
+     * "Peak of virtual memory. This data is read from field VmPeak in /proc/$pid/status file."
+     *
+     * https://www.nextflow.io/docs/latest/tracing.html#trace-report
+     *
+     * The Linux kernel provides the value in KiB, we multiply with 1024 to have
+     * it in byte, like the other values we use.
+     *
+     * @param task the task whose trace file is read
+     * @return The peak VMEM value that this task has used (in byte), -1 if
+     *         extraction failed or the value was not a number
+     */
+    static BigDecimal getNfPeakVmem(Task task) {
+        return extractKibAsBytes(task, "peak_vmem");
+    }
+
+    /**
+     * This method will get the peak resident set size (RSS) from the Nextflow
+     * trace, and return it in BigDecimal format. According to the Nextflow
+     * documentation this is:
+     *
+     * "Peak of real memory. This data is read from field VmHWM in /proc/$pid/status file."
+     *
+     * https://www.nextflow.io/docs/latest/tracing.html#trace-report
+     *
+     * The Linux kernel provides the value in KiB, we multiply with 1024 to have
+     * it in byte, like the other values we use.
+     *
+     * If the task failed, this can be 0.
+     *
+     * @param task the task whose trace file is read
+     * @return The peak RSS value that this task has used (in byte), -1 if
+     *         extraction failed or the value was not a number
+     */
+    static BigDecimal getNfPeakRss(Task task) {
+        return extractKibAsBytes(task, "peak_rss");
+    }
+
+    /**
+     * This method will get the realtime value from the Nextflow trace, and
+     * return it as long. According to the Nextflow documentation this is:
+     *
+     * "Task execution time i.e. delta between completion and start timestamp."
+     *
+     * https://www.nextflow.io/docs/latest/tracing.html#trace-report
+     *
+     * @param task the task whose trace file is read
+     * @return task execution time (in ms), -1 if extraction failed or the
+     *         value was not a number
+     */
+    static long getNfRealTime(Task task) {
+        final String value = extractTraceFile(task, "realtime");
+        if (value == null) {
+            return -1;
+        }
+        try {
+            return Long.parseLong(value.trim());
+        } catch (NumberFormatException e) {
+            // keep the documented "-1 on failure" contract instead of propagating
+            log.warn("Cannot parse value '{}' of key 'realtime' in nf .command.trace file", value);
+            return -1;
+        }
+    }
+
+    /** Reads the KiB value stored under key and converts it to byte; -1 on any failure. */
+    private static BigDecimal extractKibAsBytes(Task task, String key) {
+        final String value = extractTraceFile(task, key);
+        if (value == null) {
+            // extraction failed, return -1
+            return EXTRACTION_FAILED;
+        }
+        try {
+            return new BigDecimal(value.trim()).multiply(BYTES_PER_KIB);
+        } catch (NumberFormatException e) {
+            // keep the documented "-1 on failure" contract instead of propagating
+            log.warn("Cannot parse value '{}' of key '{}' in nf .command.trace file", value, key);
+            return EXTRACTION_FAILED;
+        }
+    }
+
+    /**
+     * Returns the text that follows key and one separator character on the
+     * first matching line of the task's ".command.trace" file, or null if the
+     * file cannot be read or no line matches.
+     */
+    private static String extractTraceFile(Task task, String key) {
+        final String nfTracePath = task.getWorkingDir() + '/' + ".command.trace";
+        try {
+            Path path = Paths.get(nfTracePath);
+            List<String> allLines = Files.readAllLines(path);
+            for (String line : allLines) {
+                if (isEntryFor(line, key)) {
+                    return line.substring(key.length() + 1);
+                }
+            }
+        } catch (Exception e) {
+            // best effort: any problem with the file is reported as "no value"
+            log.warn("Cannot read nf .command.trace file in " + nfTracePath, e);
+        }
+        return null;
+    }
+
+    /**
+     * Checks that line starts with key and that key is followed by a separator,
+     * so that a longer key sharing the same prefix is not matched by accident.
+     */
+    private static boolean isEntryFor(String line, String key) {
+        if (line.length() <= key.length() || !line.startsWith(key)) {
+            return false;
+        }
+        final char next = line.charAt(key.length());
+        return !Character.isLetterOrDigit(next) && next != '_';
+    }
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/NonePredictor.java b/src/main/java/cws/k8s/scheduler/memory/NonePredictor.java
new file mode 100644
index 0000000..20202fe
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/NonePredictor.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see .
+ */
+
+package cws.k8s.scheduler.memory;
+
+import java.math.BigDecimal;
+
+import cws.k8s.scheduler.model.Task;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * NonePredictor will not provide predictions at all. 
Which results in no
+ * changes to tasks in consequence. This is useful as baseline.
+ *
+ * @author Florian Friederici
+ *
+ */
+@Slf4j
+public class NonePredictor implements MemoryPredictor {
+
+    /**
+     * Learns nothing; the observation is only logged, together with a
+     * warning if it does not pass the sanity check.
+     *
+     * @param o the observation that was made
+     */
+    @Override
+    public void addObservation(Observation o) {
+        log.debug("NonePredictor.addObservation({})", o);
+        final boolean sane = TaskScaler.checkObservationSanity(o);
+        if (sane) {
+            // nothing is stored for valid observations
+            return;
+        }
+        log.warn("dismiss observation {}", o);
+    }
+
+    /**
+     * Never suggests a value.
+     *
+     * @param task the task to get a suggestion for
+     * @return always null
+     */
+    @Override
+    public BigDecimal queryPrediction(Task task) {
+        log.debug("NonePredictor.queryPrediction({})", task);
+        return null;
+    }
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/Observation.java b/src/main/java/cws/k8s/scheduler/memory/Observation.java
new file mode 100644
index 0000000..062a57b
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/Observation.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see .
+ */
+
+package cws.k8s.scheduler.memory;
+
+import java.math.BigDecimal;
+
+import lombok.Builder;
+import lombok.Data;
+
+/**
+ * This class holds the observations that can be made after the execution of a
+ * task in the workflow. Depending on those observations, either a single one,
+ * or multiple, tasks resource needs can be adopted by algorithms. 
+ *
+ * Note: Would have been a Java record if the target was Java 14+
+ *
+ * @author Florian Friederici
+ *
+ */
+@Data
+@Builder
+public class Observation {
+
+    /** Key under which the predictors and the statistics group observations. */
+    final String task;
+    /** Name of the concrete instance; the statistics count distinct names per task. */
+    final String taskName;
+    /** Only Boolean.TRUE is treated as success by the consumers; false and null count as failure. */
+    final Boolean success;
+    /** Input size of the task; LinearPredictor uses it as the x value of its regression. */
+    final long inputSize;
+    /** Memory that was requested for the task; ConstantPredictor falls back to it after failures. */
+    final BigDecimal ramRequest;
+    /** Peak virtual memory of the task (presumably in byte, as delivered by NfTrace -- TODO confirm). */
+    final BigDecimal peakVmem;
+    /** Peak resident set size of the task (presumably in byte, as delivered by NfTrace -- TODO confirm). */
+    final BigDecimal peakRss;
+    /** Execution time of the task (presumably in ms, as delivered by NfTrace -- TODO confirm). */
+    final long realtime;
+    /** Node value that is exported into the "node" column of the statistics csv. */
+    final String node;
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/Statistics.java b/src/main/java/cws/k8s/scheduler/memory/Statistics.java
new file mode 100644
index 0000000..59d58ef
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/Statistics.java
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see . 
+ */
+
+package cws.k8s.scheduler.memory;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.DoubleSummaryStatistics;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.LongSummaryStatistics;
+import java.util.Map;
+import java.util.Set;
+
+import cws.k8s.scheduler.scheduler.Scheduler;
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * This class collects all observations and provides a statistics summary
+ *
+ * Statistics can be disabled via environment variable DISABLE_STATISTICS
+ *
+ * @author Florian Friederici
+ *
+ */
+@Slf4j
+public class Statistics {
+
+    /**
+     * Prefix of the output files; the file name is appended without a
+     * separator, so this presumably has to end with one -- TODO confirm
+     * against the code that assigns it. Nothing is written while it is null.
+     */
+    String baseDir;
+    final Scheduler scheduler;
+    final MemoryPredictor memoryPredictor;
+    /** Creation time of this instance (ms since epoch). */
+    long start;
+    /** End time used by the summary; it is not assigned inside this class. */
+    long end;
+
+    /** False if the environment variable DISABLE_STATISTICS is set to any value. */
+    boolean active = true;
+    List<Observation> observations = new ArrayList<>();
+
+    public Statistics(Scheduler scheduler, MemoryPredictor memoryPredictor) {
+        this.scheduler = scheduler;
+        this.memoryPredictor = memoryPredictor;
+        String disableStatistics = System.getenv("DISABLE_STATISTICS");
+        if (disableStatistics != null) {
+            active = false;
+        }
+        this.start = System.currentTimeMillis();
+    }
+
+    /**
+     * Collect Observations for statistics
+     *
+     * @param o the observation that was made
+     */
+    void addObservation(Observation o) {
+        if (active) {
+            observations.add(o);
+        }
+    }
+
+    /**
+     * Save all Observations into csv file in baseDir
+     *
+     * @param timestamp time stamp for the file name
+     * @return the csv as string for logging, empty if statistics are disabled
+     */
+    String exportCsv(long timestamp) {
+        if (!active) {
+            log.info("Statistics disabled by environment variable");
+            return "";
+        }
+        StringBuilder sb = new StringBuilder();
+        sb.append("task,taskName,success,inputSize,ramRequest,peakVmem,peakRss,realtime,node\n");
+        for (Observation o : observations) {
+            sb.append(o.getTask()).append(",");
+            sb.append(o.getTaskName()).append(",");
+            sb.append(o.getSuccess()).append(",");
+            sb.append(o.getInputSize()).append(",");
+            sb.append(o.getRamRequest().toPlainString()).append(",");
+            sb.append(o.getPeakVmem().toPlainString()).append(",");
+            sb.append(o.getPeakRss().toPlainString()).append(",");
+            sb.append(o.getRealtime()).append(",");
+            sb.append(o.getNode()).append("\n");
+        }
+        String csv = sb.toString();
+        saveToBaseDir("csv", ".csv", timestamp, csv);
+        return csv;
+    }
+
+    /**
+     * Save summary to file in baseDir
+     *
+     * @param timestamp time stamp for the filename
+     * @return the summary as string for logging, empty if statistics are disabled
+     */
+    String summary(long timestamp) {
+        if (!active) {
+            log.info("Statistics disabled by environment variable");
+            return "";
+        }
+        StringBuilder sb = new StringBuilder();
+        sb.append("~~~ Statistics ~~~\n");
+        sb.append(" execution: ").append(this.scheduler.getExecution()).append("\n");
+        sb.append(" memory predictor: ").append(this.memoryPredictor.getClass()).append("\n");
+        sb.append(" start: ").append(String.valueOf(new Date(this.start)));
+        sb.append(" | end: ").append(String.valueOf(new Date(this.end)));
+        sb.append("\n makespan: ").append(this.end - this.start).append(" ms\n");
+        sb.append(" total observations collected: ").append(observations.size()).append("\n");
+
+        Set<String> tasks = new HashSet<>();
+        Map<String, Set<String>> taskMap = new HashMap<>();
+        Map<String, TaskSummary> taskSummaryMap = new HashMap<>();
+        for (Observation o : observations) {
+            tasks.add(o.task);
+            taskMap.computeIfAbsent(o.task, key -> new HashSet<>()).add(o.taskName);
+            TaskSummary ts = taskSummaryMap.computeIfAbsent(o.task, key -> new TaskSummary(key));
+            if (Boolean.TRUE.equals(o.success)) {
+                // value statistics are built from successful tasks only
+                ts.successCount++;
+                ts.inputSizeStatistics.accept(o.inputSize);
+                // Note: There might be a loss of precision after 15 digits here
+                ts.ramRequestStatitistics.accept(o.ramRequest.doubleValue());
+                ts.peakVmemStatistics.accept(o.peakVmem.doubleValue());
+                ts.peakRssStatistics.accept(o.peakRss.doubleValue());
+                ts.realtimeStatistics.accept(o.realtime);
+            } else {
+                ts.failCount++;
+            }
+        }
+
+        sb.append(" different tasks: ").append(tasks.size()).append("\n");
+
+        for (String task : tasks) {
+            TaskSummary ts = taskSummaryMap.get(task);
+            sb.append(" -- task: '").append(task).append("' --\n");
+            sb.append(" named instances of '").append(task).append("' seen: ");
+            sb.append(taskMap.get(task).size()).append("\n");
+            sb.append(" success count: ").append(ts.successCount).append("\n");
+            sb.append(" failure count: ").append(ts.failCount).append("\n");
+            // @formatter:off
+            sb.append(String.format(Locale.US, "inputSize : cnt %d, avr %.1f, min %d, max %d%n",
+                    ts.inputSizeStatistics.getCount(),
+                    ts.inputSizeStatistics.getAverage(),
+                    ts.inputSizeStatistics.getMin(),
+                    ts.inputSizeStatistics.getMax()) );
+            sb.append(String.format(Locale.US, "ramRequest : cnt %d, avr %.3e, min %.3e, max %.3e%n",
+                    ts.ramRequestStatitistics.getCount(),
+                    ts.ramRequestStatitistics.getAverage(),
+                    ts.ramRequestStatitistics.getMin(),
+                    ts.ramRequestStatitistics.getMax()) );
+            sb.append(String.format(Locale.US, "peakVmem : cnt %d, avr %.3e, min %.3e, max %.3e%n",
+                    ts.peakVmemStatistics.getCount(),
+                    ts.peakVmemStatistics.getAverage(),
+                    ts.peakVmemStatistics.getMin(),
+                    ts.peakVmemStatistics.getMax()) );
+            sb.append(String.format(Locale.US, "peakRss : cnt %d, avr %.3e, min %.3e, max %.3e%n",
+                    ts.peakRssStatistics.getCount(),
+                    ts.peakRssStatistics.getAverage(),
+                    ts.peakRssStatistics.getMin(),
+                    ts.peakRssStatistics.getMax()) );
+            sb.append(String.format(Locale.US, "realtime : cnt %d, avr %.1f, min %d, max %d%n",
+                    ts.realtimeStatistics.getCount(),
+                    ts.realtimeStatistics.getAverage(),
+                    ts.realtimeStatistics.getMin(),
+                    ts.realtimeStatistics.getMax()) );
+            // @formatter:on
+        }
+
+        String summary = sb.toString();
+        saveToBaseDir("summary", ".txt", timestamp, summary);
+        return summary;
+    }
+
+    /**
+     * Writes content to the file baseDir + "TaskScaler_" + timestamp + extension.
+     * Does nothing but logging if baseDir is not set; a write failure is logged
+     * and not propagated. The file is always encoded as UTF-8, independent of
+     * the platform default charset.
+     */
+    private void saveToBaseDir(String kind, String extension, long timestamp, String content) {
+        if (baseDir == null) {
+            log.debug("baseDir was not set, could not save {} file", kind);
+            return;
+        }
+        Path path = Paths.get(baseDir + "TaskScaler_" + timestamp + extension);
+        log.debug("save {} to: {}", kind, path);
+        try {
+            Files.write(path, content.getBytes(StandardCharsets.UTF_8));
+        } catch (IOException e) {
+            // trailing exception argument keeps the cause in the log
+            log.warn("could not save statistics {} to {}", kind, path, e);
+        }
+    }
+
+    /** Counters and value statistics of one task, filled while building the summary. */
+    @Data
+    class TaskSummary {
+        final String task;
+        int successCount = 0;
+        int failCount = 0;
+        LongSummaryStatistics inputSizeStatistics = new LongSummaryStatistics();
+        DoubleSummaryStatistics ramRequestStatitistics = new DoubleSummaryStatistics();
+        DoubleSummaryStatistics peakVmemStatistics = new DoubleSummaryStatistics();
+        DoubleSummaryStatistics peakRssStatistics = new DoubleSummaryStatistics();
+        LongSummaryStatistics realtimeStatistics = new LongSummaryStatistics();
+    }
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/TaskScaler.java b/src/main/java/cws/k8s/scheduler/memory/TaskScaler.java
new file mode 100644
index 0000000..a8ca250
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/TaskScaler.java
@@ -0,0 +1,297 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version. 
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package cws.k8s.scheduler.memory;
+
+import java.math.BigDecimal;
+import java.text.NumberFormat;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import cws.k8s.scheduler.client.KubernetesClient;
+import cws.k8s.scheduler.model.NodeWithAlloc;
+import cws.k8s.scheduler.model.Requirements;
+import cws.k8s.scheduler.model.SchedulerConfig;
+import cws.k8s.scheduler.model.Task;
+import cws.k8s.scheduler.scheduler.Scheduler;
+import io.fabric8.kubernetes.api.model.Container;
+import io.fabric8.kubernetes.api.model.Quantity;
+import io.fabric8.kubernetes.api.model.ResourceRequirements;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * The TaskScaler offers the interfaces that are used by the Scheduler.
+ *
+ * It collects the resource usage of finished tasks and asks the configured
+ * MemoryPredictor for new memory requests of tasks that are not yet scheduled.
+ *
+ * @author Florian Friederici
+ */
+@Slf4j
+public class TaskScaler {
+
+    /** lowest request the TaskScaler will set (256 MiB), the pod might not start with less */
+    private static final long LOWEST_MEMORY_REQUEST = 256L * 1024 * 1024;
+    /** margin (1 MiB) that is kept below maxRequest, when an oversized request is lowered */
+    private static final long MAX_REQUEST_MARGIN = 1L * 1024 * 1024;
+    final KubernetesClient client;
+    final Scheduler scheduler;
+    final MemoryPredictor memoryPredictor;
+    final Statistics statistics;
+    /** available memory of the biggest node, null if no node was found */
+    BigDecimal maxRequest = null;
+    /** names of failed tasks, those will not be resized again */
+    List<String> blacklist;
+    private boolean active = true;
+
+    /**
+     * Create a new TaskScaler instance. The memory predictor to be used is
+     * determined as follows:
+     *
+     * 1) use the value memoryPredictor provided in SchedulerConfig config
+     *
+     * 2) if (1) is set to "default", use the environment variable
+     * MEMORY_PREDICTOR_DEFAULT
+     *
+     * 3) if (2) is not set, or unrecognized, use the NonePredictor
+     *
+     * @param scheduler the Scheduler that has started this TaskScaler
+     * @param config    the SchedulerConfig for the execution
+     * @param client    the associated KubernetesClient
+     */
+    public TaskScaler(Scheduler scheduler, SchedulerConfig config, KubernetesClient client) {
+        this.client = client;
+        this.scheduler = scheduler;
+        String predictor = config.memoryPredictor;
+        if ("default".equalsIgnoreCase(predictor)) {
+            predictor = System.getenv("MEMORY_PREDICTOR_DEFAULT");
+        }
+        if (predictor == null) {
+            predictor = "none";
+        }
+        // Locale.ROOT: the matching must not depend on the default locale of the JVM
+        switch (predictor.toLowerCase(Locale.ROOT)) {
+        case "constant":
+            log.debug("using ConstantPredictor");
+            this.memoryPredictor = new ConstantPredictor();
+            break;
+
+        case "linear":
+            log.debug("using LinearPredictor");
+            this.memoryPredictor = new LinearPredictor();
+            break;
+
+        case "combi":
+            log.debug("using CombiPredictor");
+            this.memoryPredictor = new CombiPredictor();
+            break;
+
+        case "wary":
+            log.debug("using WaryPredictor");
+            this.memoryPredictor = new WaryPredictor();
+            break;
+
+        case "none":
+        default:
+            log.debug("using NonePredictor");
+            this.memoryPredictor = new NonePredictor();
+        }
+        this.statistics = new Statistics(scheduler,memoryPredictor);
+
+        // blacklist for failed tasks
+        this.blacklist = new ArrayList<>();
+
+        // remember the biggest node, as upper bound for memory requests
+        final NumberFormat numberFormat = NumberFormat.getNumberInstance(Locale.US);
+        List<NodeWithAlloc> allNodes = client.getAllNodes();
+        for (NodeWithAlloc n : allNodes) {
+            Requirements maxRes = n.getMaxResources();
+            Requirements availRes = n.getAvailableResources();
+            log.debug("node = {}, ram = {}, available = {}", n.getName(),
+                    numberFormat.format(maxRes.getRam()),
+                    numberFormat.format(availRes.getRam()));
+
+            if (maxRequest==null || availRes.getRam().compareTo(maxRequest) > 0) {
+                maxRequest = availRes.getRam();
+            }
+        }
+        if (maxRequest == null) {
+            // NumberFormat.format(null) would throw, and there is no upper bound to check against
+            log.warn("no nodes found, tasks will not be resized by TaskScaler");
+        } else {
+            log.info("biggest node has maxRequest = {}", numberFormat.format(maxRequest));
+        }
+    }
+
+    /**
+     * After a task was finished, this method shall be called to collect the tasks
+     * resource usage
+     *
+     * The observation is handed to the memory predictor and to the statistics.
+     * Failed tasks are put on the blacklist.
+     *
+     * @param task the finished task
+     */
+    public void afterTaskFinished(Task task) {
+        if (!active) {
+            return;
+        }
+        BigDecimal peakRss;
+        BigDecimal peakVmem;
+        long realtime;
+        // there is no nextflow trace, when the task failed
+        if (task.wasSuccessfullyExecuted()) {
+            peakRss = NfTrace.getNfPeakRss(task);
+            peakVmem = NfTrace.getNfPeakVmem(task);
+            realtime = NfTrace.getNfRealTime(task);
+        } else {
+            // NOTE(review): checkObservationSanity() rejects observations with
+            // realtime == 0, so predictors that call it dismiss these failure
+            // observations - confirm that this is intended
+            peakRss = BigDecimal.ZERO;
+            peakVmem = BigDecimal.ZERO;
+            realtime = 0;
+            // when a task has failed, we put it on the blacklist, so we will not tamper it again
+            this.blacklist.add(task.getConfig().getName());
+        }
+        // @formatter:off
+        Observation o = Observation.builder()
+                .task( task.getConfig().getTask() )
+                .taskName( task.getConfig().getName() )
+                .success( task.wasSuccessfullyExecuted() )
+                .inputSize( task.getInputSize() )
+                .ramRequest( task.getPod().getRequest().getRam() )
+                .peakVmem( peakVmem )
+                .peakRss( peakRss )
+                .realtime( realtime )
+                .node( task.getNode().getName() )
+                .build();
+        // @formatter:on
+        log.info("taskWasFinished, observation={}", o);
+        memoryPredictor.addObservation(o);
+        statistics.addObservation(o);
+
+        // Note: this is a workaround, because the SchedulerConfig does not contain the baseDir
+        if (statistics.baseDir == null) {
+            final String workingDir = task.getWorkingDir();
+            final int workIndex = workingDir == null ? -1 : workingDir.lastIndexOf("work");
+            if (workIndex >= 0) {
+                statistics.baseDir = workingDir.substring(0, workIndex);
+            } else {
+                // substring(0, -1) would throw, keep baseDir unset and try again with the next task
+                log.debug("cannot derive baseDir from working dir {}", workingDir);
+            }
+        }
+    }
+
+    /**
+     * Before tasks get scheduled, this method shall be called to change their
+     * memory requests and limits.
+     *
+     * For every task the new value is patched in Kubernetes first, and only if
+     * that was successful also in the pod object held by the scheduler. If the
+     * patch fails, the TaskScaler deactivates itself.
+     *
+     * @param unscheduledTasks the tasks that are about to be scheduled
+     */
+    public synchronized void beforeTasksScheduled(final List<Task> unscheduledTasks) {
+        if (!active) {
+            return;
+        }
+        if (maxRequest == null) {
+            log.debug("no upper bound for memory requests known, tasks are not changed");
+            return;
+        }
+        log.debug("--- unscheduledTasks BEGIN ---");
+        for (Task t : unscheduledTasks) {
+            log.debug("1 unscheduledTask: {} {} {}", t.getConfig().getTask(), t.getConfig().getName(),
+                    t.getPod().getRequest());
+
+            // if task is already blacklisted, don't touch it again
+            if (this.blacklist.contains(t.getConfig().getName())) {
+                continue;
+            }
+
+            // if task had no memory request set, it cannot be changed
+            BigDecimal taskRequest = t.getPod().getRequest().getRam();
+            if (taskRequest.compareTo(BigDecimal.ZERO) == 0) {
+                log.info("cannot change task {}, because it had no prior requirements", t);
+                continue;
+            }
+
+            BigDecimal newRequestValue = null;
+
+            // sanity check for Nextflow provided value
+            if (taskRequest.compareTo(this.maxRequest) > 0) {
+                // this would never get scheduled and CWS will get stuck, so we take the liberty to lower the value
+                newRequestValue = this.maxRequest.subtract(BigDecimal.valueOf(MAX_REQUEST_MARGIN));
+                log.warn("nextflow request exceeds maximal cluster allocatable capacity, request was reduced by TaskScaler");
+            }
+
+            // query suggestion
+            BigDecimal prediction = memoryPredictor.queryPrediction(t);
+
+            // sanity check for our prediction
+            if (prediction != null && prediction.compareTo(maxRequest) < 0) {
+                // we have a prediction and it fits into the cluster
+                newRequestValue = prediction;
+                log.debug("predictor proposes {} for task {}", prediction, t.getConfig().getName());
+
+                // if our prediction is a very low value, the pod might not start. Make sure it has at least 256MiB
+                BigDecimal lowestRequest = BigDecimal.valueOf(LOWEST_MEMORY_REQUEST);
+                if (newRequestValue.compareTo(lowestRequest) < 0) {
+                    log.debug("Prediction of {} is lower than {}. Automatically increased.", newRequestValue, lowestRequest);
+                    newRequestValue = lowestRequest;
+                }
+            }
+
+            // this method runs in every scheduling round, don't patch again if nothing would change
+            if (newRequestValue == null || newRequestValue.compareTo(taskRequest) == 0) {
+                continue;
+            }
+
+            final String newValue = newRequestValue.toPlainString();
+            log.info("resizing {} to {} bytes", t.getConfig().getName(), newValue);
+            // 1. patch Kubernetes value
+            this.active = client.patchTaskMemory(t, newValue);
+            if (!this.active) {
+                // the CWS value must stay in sync with the cluster, so it is not changed either
+                log.warn("could not patch {}, TaskScaler is deactivated", t.getConfig().getName());
+                break;
+            }
+
+            // 2. patch CWS value
+            List<Container> containers = t.getPod().getSpec().getContainers();
+            for (Container c : containers) {
+                ResourceRequirements req = c.getResources();
+                if (req == null) {
+                    continue;
+                }
+                replaceMemory(req.getLimits(), newValue);
+                replaceMemory(req.getRequests(), newValue);
+                log.debug("container: {}", req);
+            }
+
+            log.debug("2 unscheduledTask: {} {} {}", t.getConfig().getTask(), t.getConfig().getName(),
+                    t.getPod().getRequest());
+        }
+        log.debug("--- unscheduledTasks END ---");
+    }
+
+    /**
+     * Replace the "memory" entry of a resource map, if the map exists and has
+     * such an entry.
+     */
+    private static void replaceMemory(Map<String, Quantity> resources, String value) {
+        if (resources != null) {
+            resources.replace("memory", new Quantity(value));
+        }
+    }
+
+    /**
+     * After the workflow is completed, this method shall be called to log the
+     * statistics summary and the csv export. Nothing happens, if the TaskScaler
+     * was deactivated.
+     */
+    public void afterWorkflow() {
+        if (!active) {
+            return;
+        }
+        log.debug("afterWorkflow");
+        long timestamp = System.currentTimeMillis();
+        statistics.end = timestamp;
+        log.info(statistics.summary(timestamp));
+        log.debug(statistics.exportCsv(timestamp));
+    }
+
+    /**
+     * This helper checks observations for sanity.
+     *
+     * peakVmem is optional: it is only checked, if it is set.
+     *
+     * @param o the observation to check
+     * @return true if the Observation looks sane, false otherwise
+     */
+    public static boolean checkObservationSanity(Observation o) {
+        if (o.task == null || o.taskName == null || o.success == null || o.ramRequest == null || o.peakRss == null) {
+            log.error("unexpected null value in observation");
+            return false;
+        }
+        if (o.inputSize < 0) {
+            log.error("{}: inputSize may not be negative", o.taskName);
+            return false;
+        }
+        if (o.ramRequest.compareTo(BigDecimal.ZERO) < 0) {
+            log.error("{}: ramRequest may not be negative", o.taskName);
+            return false;
+        }
+
+        // we don't trust the observation if the realtime was that low
+        if (o.realtime == 0) {
+            log.warn("{}: realtime was zero, suspicious observation", o.taskName);
+            return false;
+        }
+
+        // those are indicators that the .command.trace read has failed
+        if (o.peakRss.compareTo(BigDecimal.ZERO) < 0) {
+            log.warn("{}: peakRss may not be negative (has the .command.trace read failed?)", o.taskName);
+            return false;
+        }
+        // this check was a second copy of the peakRss check before
+        if (o.peakVmem != null && o.peakVmem.compareTo(BigDecimal.ZERO) < 0) {
+            log.warn("{}: peakVmem may not be negative (has the .command.trace read failed?)", o.taskName);
+            return false;
+        }
+        if (o.realtime < 0) {
+            log.warn("{}: realtime may not be negative (has the .command.trace read failed?)", o.taskName);
+            return false;
+        }
+        return true;
+    }
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/memory/WaryPredictor.java b/src/main/java/cws/k8s/scheduler/memory/WaryPredictor.java
new file mode 100644
index 0000000..d746e9d
--- /dev/null
+++ b/src/main/java/cws/k8s/scheduler/memory/WaryPredictor.java
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package cws.k8s.scheduler.memory;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.math3.stat.regression.SimpleRegression;
+
+import cws.k8s.scheduler.model.Task;
+import lombok.extern.slf4j.Slf4j;
+
+//@formatter:off
+/**
+* WaryPredictor will use the following strategy:
+*
+* If there are fewer than 4 observations, give no prediction, else:
+* Calculate linear regression model and test if all observations would fit into
+* the model. If all past observations fit into the model, give a prediction.
+* If the model does not fit the past observations, give no prediction and increase the over-provisioning.
+*
+* Predictions start with 10% over-provisioning. If tasks fail, this will
+* increase automatically.
+*
+* WaryPredictor will never exceed the initial value.
+*
+* @author Florian Friederici
+*
+*/
+//@formatter:on
+@Slf4j
+public class WaryPredictor implements MemoryPredictor {
+
+    // all maps are keyed by the task (process) name
+    Map<String, SimpleRegression> model;
+    Map<String, Double> overprovisioning;
+    Map<String, List<Pair<Double, Double>>> observations;
+    Map<String, Integer> errorCounter;
+    Map<String, BigDecimal> initialValue;
+    Map<String, List<String>> ignoreList;
+    Map<String, BigDecimal> lowestSuccess;
+
+    public WaryPredictor() {
+        model = new HashMap<>();
+        overprovisioning = new HashMap<>();
+        observations = new HashMap<>();
+        errorCounter = new HashMap<>();
+        initialValue = new HashMap<>();
+        ignoreList = new HashMap<>();
+        lowestSuccess = new HashMap<>();
+    }
+
+    /**
+     * Successful observations are added to the regression model of the task,
+     * failed observations increase the error counter and the over-provisioning.
+     * Failed task instances are put on the ignore list, even if the observation
+     * itself is dismissed by the sanity check.
+     */
+    @Override
+    public void addObservation(Observation o) {
+        log.debug("WaryPredictor.addObservation({})", o);
+
+        final boolean success = Boolean.TRUE.equals(o.success);
+
+        // nextflow will only retry once, so if the task failed, we will add
+        // it to our ignore list, so that it wont fail twice
+        if (!success) {
+            ignoreList.computeIfAbsent(o.task, k -> new ArrayList<>()).add(o.taskName);
+        }
+
+        if (!TaskScaler.checkObservationSanity(o)) {
+            log.warn("dismiss observation {}", o);
+            return;
+        }
+
+        // store initial ramRequest value per task
+        initialValue.putIfAbsent(o.task, o.getRamRequest());
+        overprovisioning.putIfAbsent(o.task, 1.1);
+        errorCounter.putIfAbsent(o.task, 0);
+
+        if (success) {
+            // the sanity check does not require peakVmem, but the model is built on it
+            if (o.getPeakVmem() == null) {
+                log.warn("dismiss observation {}", o);
+                return;
+            }
+
+            double x = o.getInputSize();
+            double y = o.getPeakVmem().doubleValue();
+
+            // keep the minimum, not the most recent value
+            lowestSuccess.merge(o.task, BigDecimal.valueOf(y), BigDecimal::min);
+            observations.computeIfAbsent(o.task, k -> new ArrayList<>()).add(Pair.of(x, y));
+            model.computeIfAbsent(o.task, k -> new SimpleRegression()).addData(x,y);
+        } else {
+            int errors = errorCounter.merge(o.task, 1, Integer::sum);
+            log.debug("overprovisioning value will increase due to task failure, errors: {}", errors);
+            overprovisioning.merge(o.task, 0.05, Double::sum);
+        }
+    }
+
+    /**
+     * @return the predicted memory request in bytes, or null if no prediction is
+     *         made. The returned value never exceeds the initial value of the task.
+     */
+    @Override
+    public BigDecimal queryPrediction(Task task) {
+        String taskName = task.getConfig().getTask();
+        log.debug("WaryPredictor.queryPrediction({},{})", taskName, task.getInputSize());
+
+        // check ignore list first
+        if (ignoreList.containsKey(taskName) && (ignoreList.get(taskName).contains(task.getConfig().getName()))) {
+            log.debug("{} is on the ignore list", task.getConfig().getName());
+            return null;
+        }
+
+        if (!model.containsKey(taskName)) {
+            log.debug("WaryPredictor has no model for {}", taskName);
+            return null;
+        }
+
+        final BigDecimal initial = initialValue.get(taskName);
+
+        if (2 < errorCounter.get(taskName)) {
+            log.warn("to many errors for {}, providing initial value", taskName);
+            return initial;
+        }
+
+        SimpleRegression simpleRegression = model.get(taskName);
+
+        if (simpleRegression.getN() < 4) {
+            log.debug("Not enough observations for {}", taskName);
+            return null;
+        }
+
+        final double op = overprovisioning.get(taskName);
+
+        // would the model match the past successful observations?
+        List<Pair<Double, Double>> observationList = observations.get(taskName);
+        for (Pair<Double, Double> o : observationList) {
+            double p = simpleRegression.predict(o.getLeft());
+            if ( (p*op) < o.getRight() ) {
+                // The model predicted value would have been smaller then the
+                // observed value. Our model is not (yet) appropriate.
+                // Increase overprovisioning
+                log.debug("overprovisioning value will increase due to model mismatch");
+                overprovisioning.put(taskName, op+0.05);
+                // Don't make a prediction this time
+                return null;
+            }
+        }
+
+        double prediction = simpleRegression.predict(task.getInputSize());
+
+        if (Double.isNaN(prediction)) {
+            log.debug("No prediction possible for {}", taskName);
+            return null;
+        }
+
+        if (prediction < 0) {
+            log.warn("prediction would be negative: {}", prediction);
+            return null;
+        }
+
+        if (prediction > initial.doubleValue()) {
+            log.warn("prediction would exceed initial value");
+            return initial;
+        }
+
+        // this catches if the model underestimates the behavior
+        if (prediction < lowestSuccess.get(taskName).doubleValue()) {
+            log.info("prediction would be lower than the lowest known successful value");
+            return null;
+        }
+
+        BigDecimal result = BigDecimal.valueOf(prediction).multiply(BigDecimal.valueOf(op)).setScale(0, RoundingMode.CEILING);
+        // the over-provisioning is applied after the check above, so cap the result as well
+        return result.min(initial);
+    }
+
+
+}
diff --git a/src/main/java/cws/k8s/scheduler/model/SchedulerConfig.java b/src/main/java/cws/k8s/scheduler/model/SchedulerConfig.java
index c085074..c174767 100644
--- a/src/main/java/cws/k8s/scheduler/model/SchedulerConfig.java
+++ b/src/main/java/cws/k8s/scheduler/model/SchedulerConfig.java
@@ -20,6 +20,7 @@ public class SchedulerConfig {
     public final String costFunction;
     public final String strategy;
     public final Map additional;
+    public final String memoryPredictor;
 
     @ToString
     @NoArgsConstructor(access = AccessLevel.PRIVATE,force = true)
diff --git a/src/main/java/cws/k8s/scheduler/scheduler/Scheduler.java b/src/main/java/cws/k8s/scheduler/scheduler/Scheduler.java
index f792d92..c7d5e68 100644
--- a/src/main/java/cws/k8s/scheduler/scheduler/Scheduler.java
+++ b/src/main/java/cws/k8s/scheduler/scheduler/Scheduler.java
@@ -1,6 +1,7 @@
 package cws.k8s.scheduler.scheduler;
 
 import cws.k8s.scheduler.dag.DAG;
+import cws.k8s.scheduler.memory.TaskScaler; import cws.k8s.scheduler.model.*; import cws.k8s.scheduler.util.Batch; import cws.k8s.scheduler.client.Informable; @@ -19,6 +20,8 @@ import java.io.IOException; import java.util.*; +import org.springframework.util.StringUtils; + @Slf4j public abstract class Scheduler implements Informable { @@ -52,6 +55,9 @@ public abstract class Scheduler implements Informable { final boolean traceEnabled; + // TaskScaler will observe tasks and modify their memory assignments (null, if no memoryPredictor is configured) + final TaskScaler taskScaler; + Scheduler(String execution, KubernetesClient client, String namespace, SchedulerConfig config){ this.execution = execution; this.name = System.getenv( "SCHEDULER_NAME" ) + "-" + execution; @@ -73,6 +79,13 @@ public abstract class Scheduler implements Informable { log.info("Start watching"); watcher = client.pods().inNamespace( this.namespace ).watch(podWatcher); log.info("Watching"); + + if ( StringUtils.hasText(config.memoryPredictor) ) { + // create a new TaskScaler for each Scheduler instance, without a memoryPredictor setting task scaling stays disabled + taskScaler = new TaskScaler(this, config, client); + } else { + taskScaler = null; + } } /* Abstract methods */ @@ -85,6 +98,12 @@ public int schedule( final List unscheduledTasks ) { if( traceEnabled ) { unscheduledTasks.forEach( x -> x.getTraceRecord().tryToSchedule( startSchedule ) ); } + + if (taskScaler!=null) { + // change memory resource requests and limits here, before the tasks are aligned to nodes + taskScaler.beforeTasksScheduled(unscheduledTasks); + } + final ScheduleObject scheduleObject = getTaskNodeAlignment(unscheduledTasks, getAvailableByNode()); final List taskNodeAlignment = scheduleObject.getTaskAlignments(); @@ -185,6 +204,11 @@ void taskWasFinished( Task task ){ unfinishedTasks.remove( task ); } task.getState().setState(task.wasSuccessfullyExecuted() ? 
State.FINISHED : State.FINISHED_WITH_ERROR); + + if (taskScaler!=null) { + // this will collect the result of the task execution for future scaling + taskScaler.afterTaskFinished(task); + } } public void schedulePod(PodWithAge pod ) { @@ -481,6 +505,11 @@ public void close(){ schedulingThread.interrupt(); finishThread.interrupt(); this.close = true; + + if (taskScaler!=null) { + // log and save the statistics after the workflow is completed + taskScaler.afterWorkflow(); + } } static class PodWatcher implements Watcher { diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 45e54b6..681fda6 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -2,3 +2,18 @@ spring: mvc: pathmatch: matching-strategy: ant-path-matcher +logging: + file: + path: "." + +--- + +spring: + config: + activate: + on-profile: "dev" +logging: + pattern: + console: "%clr(%d{HH:mm:ss.SSS}){faint} %clr(%-5level) %clr(%-30.30logger{29}){cyan} %msg%n" + level: + "[cws.k8s.scheduler]": TRACE diff --git a/src/test/java/cws/k8s/scheduler/memory/CombiPredictorTest.java b/src/test/java/cws/k8s/scheduler/memory/CombiPredictorTest.java new file mode 100644 index 0000000..0267a48 --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/CombiPredictorTest.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. 
If not, see <https://www.gnu.org/licenses/>. + */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.*; + +import org.junit.Test; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 4 Tests for the CombiPredictor + * + * @author Florian Friederici + * + */ +@Slf4j +public class CombiPredictorTest { + + /** + * If there are no observations, we cannot get a prediction: queryPrediction is expected to return null + * + */ + @Test + public void testNoObservationsYet() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + CombiPredictor combiPredictor = new CombiPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 0l); + assertNull(combiPredictor.queryPrediction(task)); + } + +} diff --git a/src/test/java/cws/k8s/scheduler/memory/ConstantPredictorTest.java b/src/test/java/cws/k8s/scheduler/memory/ConstantPredictorTest.java new file mode 100644 index 0000000..1670a25 --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/ConstantPredictorTest.java @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see <https://www.gnu.org/licenses/>. 
+ */
+
+package cws.k8s.scheduler.memory;
+
+import static org.junit.Assert.*;
+
+import java.math.BigDecimal;
+
+import org.junit.Test;
+
+import cws.k8s.scheduler.model.Task;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * JUnit 4 Tests for the ConstantPredictor
+ *
+ * @author Florian Friederici
+ *
+ */
+@Slf4j
+public class ConstantPredictorTest {
+
+    /**
+     * If there are no observations, we cannot get a prediction
+     *
+     */
+    @Test
+    public void testNoObservationsYet() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+        Task task = MemoryPredictorTest.createTask("taskName", 0L);
+        assertNull(constantPredictor.queryPrediction(task));
+    }
+
+    /**
+     * If there is one observation, we get a prediction
+     *
+     */
+    @Test
+    public void testOneObservation() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+        Task task = MemoryPredictorTest.createTask("taskName", 0L);
+        // @formatter:off
+        Observation observation = Observation.builder()
+                .task("taskName")
+                .taskName("taskName (1)")
+                .success(true)
+                .inputSize(0)
+                .ramRequest(BigDecimal.valueOf(0))
+                .peakRss(BigDecimal.valueOf(1))
+                .realtime(1000)
+                .build();
+        // @formatter:on
+        constantPredictor.addObservation(observation);
+        assertNotNull(constantPredictor.queryPrediction(task));
+    }
+
+    /**
+     * If there are two observations, we will also get a prediction
+     */
+    @Test
+    public void testTwoObservations() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+        Task task = MemoryPredictorTest.createTask("taskName", 0L);
+        // @formatter:off
+        Observation observation1 = Observation.builder()
+                .task("taskName")
+                .taskName("taskName (1)")
+                .success(true)
+                .inputSize(0)
+                .ramRequest(BigDecimal.valueOf(0))
+                .peakRss(BigDecimal.valueOf(1))
+                .realtime(1000)
+                .build();
+        Observation observation2 = Observation.builder()
+                .task("taskName")
+                .taskName("taskName (1)")
+                .success(true)
+                .inputSize(0)
+                .ramRequest(BigDecimal.valueOf(0))
+                .peakRss(BigDecimal.valueOf(1))
+                .realtime(1000)
+                .build();
+        // @formatter:on
+        constantPredictor.addObservation(observation1);
+        constantPredictor.addObservation(observation2);
+        assertNotNull(constantPredictor.queryPrediction(task));
+    }
+
+    /**
+     * The prediction decreases right after one observation
+     *
+     */
+    @Test
+    public void testDecreasePredictionAfterOneObservation() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+        Task task = MemoryPredictorTest.createTask("taskName", 0L);
+
+        BigDecimal reserved = BigDecimal.valueOf(4L * 1024 * 1024 * 1024);
+        BigDecimal used = BigDecimal.valueOf(2L * 1024 * 1024 * 1024);
+
+        // @formatter:off
+        Observation observation = Observation.builder()
+                .task("taskName")
+                .taskName("taskName (1)")
+                .success(true)
+                .inputSize(0)
+                .ramRequest(reserved)
+                .peakRss(used)
+                .realtime(1000)
+                .build();
+        // @formatter:on
+        constantPredictor.addObservation(observation);
+        BigDecimal suggestion = constantPredictor.queryPrediction(task);
+        log.debug("suggestion is: {}", suggestion);
+        // 1. There is a suggestion at all
+        assertNotNull(suggestion);
+        // 2. The suggestion is lower than the reserved value was
+        assertTrue(suggestion.compareTo(reserved) < 0);
+        // 3. The suggestion is higher than the used value was
+        assertTrue(suggestion.compareTo(used) > 0);
+    }
+
+    /**
+     * The prediction decreases further, if another successful observation is made
+     *
+     */
+    @Test
+    public void testDecreasePredictionAfterMultipleObservations() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+
+        BigDecimal suggestion1 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor,
+                BigDecimal.valueOf(4L * 1024 * 1024 * 1024), BigDecimal.valueOf(2L * 1024 * 1024 * 1024));
+        BigDecimal suggestion2 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor,
+                BigDecimal.valueOf(4L * 1024 * 1024 * 1024), BigDecimal.valueOf(2L * 1024 * 1024 * 1024));
+        assertTrue(suggestion1.compareTo(suggestion2) >= 0);
+
+        BigDecimal suggestion3 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor,
+                BigDecimal.valueOf(4L * 1024 * 1024 * 1024), BigDecimal.valueOf(2L * 1024 * 1024 * 1024));
+        assertTrue(suggestion2.compareTo(suggestion3) >= 0);
+
+        BigDecimal suggestion4 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor,
+                BigDecimal.valueOf(4L * 1024 * 1024 * 1024), BigDecimal.valueOf(2L * 1024 * 1024 * 1024));
+        assertTrue(suggestion3.compareTo(suggestion4) >= 0);
+    }
+
+    /**
+     * When the Task failed, increase the prediction already after one observation
+     *
+     */
+    @Test
+    public void testIncreasePredictionAfterFailure() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+
+        BigDecimal reserved = BigDecimal.valueOf(4L * 1024 * 1024 * 1024);
+        BigDecimal used = reserved.add(BigDecimal.ONE);
+
+        BigDecimal suggestion = MemoryPredictorTest.createTaskObservationFailurePrediction(constantPredictor, reserved, used);
+        // the test had no assertion, a missing suggestion only showed up as a NullPointerException below
+        assertNotNull(suggestion);
+        log.info("reserved : {}", reserved);
+        log.info("used : {}", used);
+        log.info("suggestion is: {}", suggestion.toPlainString());
+    }
+
+    /**
+     * When the Task failed after some successful observations, increase
+     *
+     */
+    @Test
+    public void testIncreasePredictionAfterSuccessAndFailure() {
+        log.info(Thread.currentThread().getStackTrace()[1].getMethodName());
+        ConstantPredictor constantPredictor = new ConstantPredictor();
+
+        BigDecimal reserved = BigDecimal.valueOf(4L * 1024 * 1024 * 1024);
+        BigDecimal usedSucc = BigDecimal.valueOf(2L * 1024 * 1024 * 1024);
+        BigDecimal usedFail = reserved;
+
+        BigDecimal suggestion1 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor, reserved, usedSucc);
+        log.info("reserved : {}", reserved);
+        log.info("usedSucc : {}", usedSucc);
+        log.info("suggestion1 is: {}", suggestion1);
+        assertTrue(suggestion1.compareTo(reserved) < 0);
+
+        BigDecimal suggestion2 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor, suggestion1, usedSucc);
+        log.info("reserved : {}", suggestion1);
+        log.info("usedSucc : {}", usedSucc);
+        log.info("suggestion2 is: {}", suggestion2);
+        assertTrue(suggestion2.compareTo(suggestion1) <= 0);
+
+        BigDecimal suggestion3 = MemoryPredictorTest.createTaskObservationFailurePrediction(constantPredictor, suggestion2, usedFail);
+        log.info("reserved : {}", suggestion2);
+        log.info("usedFail : {}", usedFail);
+        log.info("suggestion3 is: {}", suggestion3);
+        assertTrue(suggestion3.compareTo(suggestion2) > 0);
+
+        BigDecimal suggestion4 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor, suggestion3, usedSucc);
+        log.info("reserved : {}", suggestion3);
+        log.info("usedSucc : {}", usedSucc);
+        log.info("suggestion4 is: {}", suggestion4);
+        assertTrue(suggestion4.compareTo(suggestion3) < 0);
+
+        BigDecimal suggestion5 = MemoryPredictorTest.createTaskObservationSuccessPrediction(constantPredictor, suggestion4, usedSucc);
+        log.info("reserved : {}", suggestion4);
+        log.info("usedSucc : {}", usedSucc);
+        log.info("suggestion5 is: {}", suggestion5);
+        assertTrue(suggestion5.compareTo(suggestion4) <= 0);
+    }
+
+}
diff --git a/src/test/java/cws/k8s/scheduler/memory/LinearPredictorTest.java b/src/test/java/cws/k8s/scheduler/memory/LinearPredictorTest.java
new file mode 100644
index 0000000..66975c8
--- /dev/null
+++ b/src/test/java/cws/k8s/scheduler/memory/LinearPredictorTest.java
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2023, Florian Friederici. All rights reserved.
+ *
+ * This code is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this work. If not, see <https://www.gnu.org/licenses/>.
+ */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.*; + +import java.math.BigDecimal; + +import org.junit.Test; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 4 Tests for the LinearPredictor + * + * @author Florian Friederici + * + */ +@Slf4j +public class LinearPredictorTest { + + /** + * Without any observations, queryPrediction must return null + */ + @Test + public void testNoObservationsYet() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + LinearPredictor linearPredictor = new LinearPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + assertNull(linearPredictor.queryPrediction(task)); + } + + /** + * With only one successful observation, queryPrediction must still return null + */ + @Test + public void testOneObservation() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + LinearPredictor linearPredictor = new LinearPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(0) + .ramRequest(BigDecimal.valueOf(0)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + // @formatter:on + linearPredictor.addObservation(observation); + assertNull(linearPredictor.queryPrediction(task)); + } + + /** + * With two successful observations of different input size, suggestions are available and must grow strictly with the input size of the queried task (512, 1024, 1536, 2048, 4096) + */ + @Test + public void testTwoObservations() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + LinearPredictor linearPredictor = new LinearPredictor(); + // @formatter:off + Observation observation1 = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(1024l) + .ramRequest(BigDecimal.valueOf(4l*1024*1024*1024)) + .peakRss(BigDecimal.valueOf(1l*1024*1024*1024)) + .realtime(1000) + .build(); + Observation observation2 = 
Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(2048) + .ramRequest(BigDecimal.valueOf(4l*1024*1024*1024)) + .peakRss(BigDecimal.valueOf(2l*1024*1024*1024)) + .realtime(1000) + .build(); + // @formatter:on + linearPredictor.addObservation(observation1); + linearPredictor.addObservation(observation2); + + Task task1 = MemoryPredictorTest.createTask("taskName", 512l); + BigDecimal suggestion1 = linearPredictor.queryPrediction(task1); + assertNotNull(suggestion1); + log.info("suggestion 1 is: {}", suggestion1); + + Task task2 = MemoryPredictorTest.createTask("taskName", 1024l); + BigDecimal suggestion2 = linearPredictor.queryPrediction(task2); + assertNotNull(suggestion2); + log.info("suggestion 2 is: {}", suggestion2); + assertTrue(suggestion2.compareTo(suggestion1) > 0); + + Task task3 = MemoryPredictorTest.createTask("taskName", 1536l); + BigDecimal suggestion3 = linearPredictor.queryPrediction(task3); + assertNotNull(suggestion3); + log.info("suggestion 3 is: {}", suggestion3); + assertTrue(suggestion3.compareTo(suggestion2) > 0); + + Task task4 = MemoryPredictorTest.createTask("taskName", 2048l); + BigDecimal suggestion4 = linearPredictor.queryPrediction(task4); + assertNotNull(suggestion4); + log.info("suggestion 4 is: {}", suggestion4); + assertTrue(suggestion4.compareTo(suggestion3) > 0); + + Task task5 = MemoryPredictorTest.createTask("taskName", 4096l); + BigDecimal suggestion5 = linearPredictor.queryPrediction(task5); + assertNotNull(suggestion5); + log.info("suggestion 5 is: {}", suggestion5); + assertTrue(suggestion5.compareTo(suggestion4) > 0); + } + + /** + * Test that predictions cannot get negative: after observing (inputSize 3, peakRss 3) and (inputSize 2, peakRss 1), the query for input size 1 must return null + */ + @Test + public void testNoNegativePredicitons() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + LinearPredictor linearPredictor = new LinearPredictor(); + // @formatter:off + Observation observation1 = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + 
.success(true) + .inputSize(3) + .ramRequest(BigDecimal.valueOf(3)) + .peakRss(BigDecimal.valueOf(3)) + .realtime(1000) + .build(); + Observation observation2 = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(2) + .ramRequest(BigDecimal.valueOf(1)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + // @formatter:on + linearPredictor.addObservation(observation1); + linearPredictor.addObservation(observation2); + + Task task1 = MemoryPredictorTest.createTask("taskName", 1); + BigDecimal suggestion1 = linearPredictor.queryPrediction(task1); + assertNull(suggestion1); + } + + /** + * Observations with success = false must not enable predictions: queryPrediction stays null after the same failed observation was added once, twice and three times + */ + @Test + public void testObservationFailed() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + LinearPredictor linearPredictor = new LinearPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(false) + .inputSize(0) + .ramRequest(BigDecimal.valueOf(0)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + // @formatter:on + linearPredictor.addObservation(observation); + assertNull(linearPredictor.queryPrediction(task)); + linearPredictor.addObservation(observation); + assertNull(linearPredictor.queryPrediction(task)); + linearPredictor.addObservation(observation); + assertNull(linearPredictor.queryPrediction(task)); + } + + +} diff --git a/src/test/java/cws/k8s/scheduler/memory/MemoryPredictorTest.java b/src/test/java/cws/k8s/scheduler/memory/MemoryPredictorTest.java new file mode 100644 index 0000000..e45f054 --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/MemoryPredictorTest.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. 
+ * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . + */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; +import org.springframework.test.util.ReflectionTestUtils; + +import cws.k8s.scheduler.dag.DAG; +import cws.k8s.scheduler.dag.Process; +import cws.k8s.scheduler.dag.Vertex; +import cws.k8s.scheduler.model.Task; +import cws.k8s.scheduler.model.TaskConfig; +import lombok.extern.slf4j.Slf4j; + +/** + * Common methods for all MemoryPredictor Tests + * + * @author Florian Friederici + * + */ +@Slf4j +public class MemoryPredictorTest { + + /** + * Helper that creates tasks for the tests, always using the instance number 1 + * + * Note: There are two fields that contain the name of the task within the + * taskConfig. The first one, taskConfig.task, contains the process name from + * Nextflow. The second one, taskConfig.name, has a number added. 
+ * + * @return the newly created Task + */ + static Task createTask(String name, long inputSize) { + TaskConfig taskConfig = new TaskConfig(name); + ReflectionTestUtils.setField(taskConfig, "name", name + " (1)"); + DAG dag = new DAG(); + List processes = Arrays.asList(new Process(name, 0)); + dag.registerVertices(processes); + Task task = new Task(taskConfig, dag); + ReflectionTestUtils.setField(task, "inputSize", inputSize); + return task; + } + + static Task createTask(String name, long inputSize, int number) { + TaskConfig taskConfig = new TaskConfig(name); + ReflectionTestUtils.setField(taskConfig, "name", name + " ("+number+")"); + DAG dag = new DAG(); + List processes = Arrays.asList(new Process(name, 0)); + dag.registerVertices(processes); + Task task = new Task(taskConfig, dag); + ReflectionTestUtils.setField(task, "inputSize", inputSize); + return task; + } + + /** + * Execute observationSanityCheck on the NonePredictor, the ConstantPredictor and the LinearPredictor + * + */ + @Test + public void testSanityChecksOnAllPredictors() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + NonePredictor nonePredictor = new NonePredictor(); + observationSanityCheck(nonePredictor); + + ConstantPredictor constantPredictor = new ConstantPredictor(); + observationSanityCheck(constantPredictor); + + LinearPredictor linearPredictor = new LinearPredictor(); + observationSanityCheck(linearPredictor); + } + + /** + * An observation with suspicious values (inputSize -1, ramRequest 0, peakRss 0) must not lead to a suggestion. + * Only the null result of queryPrediction is checked; how the predictor rejects the observation internally is not visible here. 
+ */ + void observationSanityCheck(MemoryPredictor memoryPredictor) { + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + // @formatter:off + Observation observation1 = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(-1) + .ramRequest(BigDecimal.valueOf(0)) + .peakRss(BigDecimal.valueOf(0)) + .realtime(1000) + .build(); + // @formatter:on + + memoryPredictor.addObservation(observation1); + assertNull( memoryPredictor.queryPrediction(task) ); + } + + /** + * Helper to create Task and Observation and insert them into memoryPredictor + * for successful tasks + * + * @return the suggestion from queryPrediction, asserted to be non-null, not above reserved and strictly above used + */ + static BigDecimal createTaskObservationSuccessPrediction(MemoryPredictor memoryPredictor, BigDecimal reserved, BigDecimal used) { + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(0) + .ramRequest(reserved) + .peakRss(used) + .realtime(1000) + .build(); + // @formatter:on + memoryPredictor.addObservation(observation); + BigDecimal suggestion = memoryPredictor.queryPrediction(task); + log.debug("suggestion is: {}", suggestion); + // 1. There is a suggestion at all + assertNotNull(suggestion); + // 2. The suggestion is less than or equal to the reserved value + assertTrue(suggestion.compareTo(reserved) <= 0); + // 3. 
The suggestion is higher than the used value was + assertTrue(suggestion.compareTo(used) > 0); + return suggestion; + } + + /** + * Helper to create Task and Observation and insert them into memoryPredictor + * for unsuccessful tasks + * + * @return the suggestion from queryPrediction, asserted to be non-null, strictly above reserved and not below used + */ + static BigDecimal createTaskObservationFailurePrediction(MemoryPredictor memoryPredictor, BigDecimal reserved, BigDecimal used) { + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(false) + .inputSize(0) + .ramRequest(reserved) + .peakRss(used) + .realtime(1000) + .build(); + // @formatter:on + memoryPredictor.addObservation(observation); + BigDecimal suggestion = memoryPredictor.queryPrediction(task); + log.info("suggestion is: {}", suggestion); + // 1. There is a suggestion at all + assertNotNull(suggestion); + // 2. The suggestion is higher than the reserved value was + assertTrue(suggestion.compareTo(reserved) > 0); + // 3. The suggestion is at least as high as the used value was + log.info("assert {} >= {}", suggestion, used); + assertTrue(suggestion.compareTo(used) >= 0); + return suggestion; + } +} diff --git a/src/test/java/cws/k8s/scheduler/memory/NfTraceTest.java b/src/test/java/cws/k8s/scheduler/memory/NfTraceTest.java new file mode 100644 index 0000000..7829d92 --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/NfTraceTest.java @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. 
+ * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . + */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.*; + +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.Test; +import org.springframework.test.util.ReflectionTestUtils; + +import cws.k8s.scheduler.model.Task; +import cws.k8s.scheduler.model.TaskConfig; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 4 Tests for the NfTrace utility class + * + * @author Florian Friederici + * + */ +@Slf4j +public class NfTraceTest { + + String exampleTrace = "nextflow.trace/v2\n" + + "realtime=30090\n" + + "%cpu=812\n" + + "cpu_model=Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz\n" + + "rchar=79901\n" + + "wchar=427\n" + + "syscr=276\n" + + "syscw=22\n" + + "read_bytes=380928\n" + + "write_bytes=0\n" + + "%mem=71\n" + + "vmem=728216\n" + + "rss=581672\n" + + "peak_vmem=728216\n" + + "peak_rss=593192\n" + + "vol_ctxt=1906\n" + + "inv_ctxt=11615"; + + /** + * Creates a Task whose workDir is a fresh temporary directory that holds the trace string as .command.trace file + * + * @param trace the contents for the trace file, if null no file will be created + * @return the Task with its workDir set to the temporary directory + * @throws IOException if file i/o goes wrong + */ + private Task mockTask(String trace) throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + Task task = MemoryPredictorTest.createTask("taskName", 0); + TaskConfig taskConfig = (TaskConfig)ReflectionTestUtils.getField(task, "config"); + Path tmpdir = Files.createTempDirectory("unittest."); + tmpdir.toFile().deleteOnExit(); + 
ReflectionTestUtils.setField(taskConfig, "workDir", tmpdir.toFile().getAbsolutePath()); + if (trace != null) { + String filename = ".command.trace"; + Path path = Paths.get(tmpdir.toFile().getAbsolutePath() + File.separator + filename); + Files.write(path, trace.getBytes()); + } + return task; + } + + /** + * Positive test case for NfTrace.getNfPeakRss: peak_rss=593192 from the trace is expected as 607428608 (593192 * 1024) + * + * @throws IOException if the temporary trace file cannot be created + */ + @Test + public void testGetNfPeakRss() throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + Task task = mockTask(exampleTrace); + log.info("workdir: {}", task.getWorkingDir()); + + BigDecimal peakRss = NfTrace.getNfPeakRss(task); + log.info("" + peakRss); + assertEquals(0,peakRss.compareTo(new BigDecimal("607428608"))); + } + + /** + * Negative test case for NfTrace.getNfPeakRss + * When the trace file is missing, we expect -1 as peakRss + * + * @throws IOException if the temporary work directory cannot be created + */ + @Test + public void testGetNfPeakRssMissingFile() throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + Task task = mockTask(null); + log.info("workdir: {}", task.getWorkingDir()); + + BigDecimal peakRss = NfTrace.getNfPeakRss(task); + log.info("" + peakRss); + assertEquals(0,peakRss.compareTo(BigDecimal.valueOf(-1))); + } + + /** + * Positive test case for NfTrace.getNfRealTime: realtime=30090 from the trace is expected unchanged + * + * @throws IOException if the temporary trace file cannot be created + */ + @Test + public void testGetNfRealTime() throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + Task task = mockTask(exampleTrace); + log.info("workdir: {}", task.getWorkingDir()); + + long realtime = NfTrace.getNfRealTime(task); + log.info("" + realtime); + assertEquals(30090,realtime); + } + + /** + * Negative test case for NfTrace.getNfRealTime: when the trace file is missing, we expect -1 as realtime + * + * @throws IOException if the temporary work directory cannot be created + */ + @Test + public void testGetNfRealTimeMissingFile() throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + Task task = mockTask(null); + log.info("workdir: 
{}", task.getWorkingDir()); + + long realtime = NfTrace.getNfRealTime(task); + log.info("" + realtime); + assertEquals(-1,realtime); + } + + /** + * Positive test case for NfTrace.getNfPeakVmem: peak_vmem=728216 from the trace is expected as 745693184 (728216 * 1024) + * + * @throws IOException if the temporary trace file cannot be created + */ + @Test + public void testGetNfPeakVmem() throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + Task task = mockTask(exampleTrace); + log.info("workdir: {}", task.getWorkingDir()); + + BigDecimal peakVmem = NfTrace.getNfPeakVmem(task); + log.info("" + peakVmem); + assertEquals(0,peakVmem.compareTo(BigDecimal.valueOf(745693184))); + } + + /** + * Negative test case for NfTrace.getNfPeakVmem: when the trace file is missing, we expect -1 as peakVmem + * + * @throws IOException if the temporary work directory cannot be created + */ + @Test + public void testGetNfPeakVmemMissingFile() throws IOException { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + Task task = mockTask(null); + log.info("workdir: {}", task.getWorkingDir()); + + BigDecimal peakVmem = NfTrace.getNfPeakVmem(task); + log.info("" + peakVmem); + assertEquals(0,peakVmem.compareTo(BigDecimal.valueOf(-1))); + } + +} diff --git a/src/test/java/cws/k8s/scheduler/memory/NonePredictorTest.java b/src/test/java/cws/k8s/scheduler/memory/NonePredictorTest.java new file mode 100644 index 0000000..abfaf35 --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/NonePredictorTest.java @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. 
+ * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . + */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.*; + +import java.math.BigDecimal; + +import org.junit.Test; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 4 Tests for the NonePredictor: queryPrediction must return null no matter how many observations were added + * + * @author Florian Friederici + */ +@Slf4j +public class NonePredictorTest { + + /** + * NonePredictor shall never give a prediction. Test 1: queryPrediction returns null when no observations + * were inserted + */ + @Test + public void testNoObservationsYet() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + NonePredictor nonePredictor = new NonePredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + assertNull(nonePredictor.queryPrediction(task)); + } + + /** + * NonePredictor shall never give a prediction. Test 2: queryPrediction returns null when one successful observation was + * inserted + */ + @Test + public void testOneObservation() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + NonePredictor nonePredictor = new NonePredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(0) + .ramRequest(BigDecimal.valueOf(0)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + // @formatter:on + nonePredictor.addObservation(observation); + assertNull(nonePredictor.queryPrediction(task)); + } + + /** + * NonePredictor shall never give a prediction. 
Test 3: when two observations + * were inserted + */ + @Test + public void testTwoObservations() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + NonePredictor nonePredictor = new NonePredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + // @formatter:off + Observation observation1 = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(0) + .ramRequest(BigDecimal.valueOf(0)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + Observation observation2 = Observation.builder() + .task("taskName") + .taskName("taskName (1)") + .success(true) + .inputSize(0) + .ramRequest(BigDecimal.valueOf(0)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + // @formatter:on + nonePredictor.addObservation(observation1); + nonePredictor.addObservation(observation2); + assertNull(nonePredictor.queryPrediction(task)); + } + +} diff --git a/src/test/java/cws/k8s/scheduler/memory/ObservationTest.java b/src/test/java/cws/k8s/scheduler/memory/ObservationTest.java new file mode 100644 index 0000000..2634632 --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/ObservationTest.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . 
+ */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.*; + +import org.junit.Test; + +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 4 Tests for the Observation data class: an instance from the all-args constructor with null and 0 arguments must equal the result of an empty builder + * + * @author Florian Friederici + * + */ +@Slf4j +public class ObservationTest { + + @Test + public void testConstructor() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + Observation o1 = new Observation(null, null, null, 0, null, null, null, 0, null); + Observation o2 = Observation.builder() + .build(); + assertEquals(o1, o2); + } + +} diff --git a/src/test/java/cws/k8s/scheduler/memory/StatisticsTest.java b/src/test/java/cws/k8s/scheduler/memory/StatisticsTest.java new file mode 100644 index 0000000..cc938bc --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/StatisticsTest.java @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . 
+ */ + +package cws.k8s.scheduler.memory; + +import static org.junit.Assert.*; + +import java.math.BigDecimal; + +import org.junit.Test; +import org.mockito.Mockito; + +import cws.k8s.scheduler.scheduler.Scheduler; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 4 Tests for the Statistics + * + * @author Florian Friederici + * + */ +@Slf4j +public class StatisticsTest { + + private Statistics mockStatistics() { + Scheduler scheduler = Mockito.mock(Scheduler.class); + MemoryPredictor memoryPredictor = new NonePredictor(); + Statistics statistics = new Statistics(scheduler, memoryPredictor); + return statistics; + } + + /** + * If no observations are inserted, the csv must contain only the header line + */ + @Test + public void testNoObservations() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + Statistics statistics = mockStatistics(); + + String csv = statistics.exportCsv(0); + log.info(csv); + assertEquals("task,taskName,success,inputSize,ramRequest,peakVmem,peakRss,realtime,node\n", csv); + + String summary = statistics.summary(0); + log.info(summary); + } + + /** + * Test if the observation values are inserted correctly: a single observation must show up as one csv line after the header + * + */ + @Test + public void testSingleObservations() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + Statistics statistics = mockStatistics(); + + Observation o = Observation.builder() + .task("task") + .taskName("taskName") + .success(true) + .inputSize(123l) + .ramRequest(BigDecimal.valueOf(234l)) + .peakVmem(BigDecimal.valueOf(345l)) + .peakRss(BigDecimal.valueOf(456l)) + .realtime(678l) + .node("testnode") + .build(); + statistics.addObservation(o); + + String csv = statistics.exportCsv(0); + log.info(csv); + assertEquals("task,taskName,success,inputSize,ramRequest,peakVmem,peakRss,realtime,node\n" + +"task,taskName,true,123,234,345,456,678,testnode\n", csv); + + String summary = statistics.summary(0); + log.info(summary); + } + + /** + * Test if the csv export and the end of the summary report match the reference for three observations + * + */ + @Test + public void 
testSummary() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + Statistics statistics = mockStatistics(); + + Observation o1 = Observation.builder() + .task("task") + .taskName("taskName") + .success(true) + .inputSize(1l) + .ramRequest(BigDecimal.valueOf(2l)) + .peakVmem(BigDecimal.valueOf(3l)) + .peakRss(BigDecimal.valueOf(4l)) + .realtime(6l) + .node("testnode") + .build(); + statistics.addObservation(o1); + + Observation o2 = Observation.builder() + .task("task") + .taskName("taskName") + .success(true) + .inputSize(2l) + .ramRequest(BigDecimal.valueOf(3l)) + .peakVmem(BigDecimal.valueOf(4l)) + .peakRss(BigDecimal.valueOf(5l)) + .realtime(7l) + .node("testnode") + .build(); + statistics.addObservation(o2); + + Observation o3 = Observation.builder() + .task("task") + .taskName("taskName") + .success(true) + .inputSize(3l) + .ramRequest(BigDecimal.valueOf(4l)) + .peakVmem(BigDecimal.valueOf(5l)) + .peakRss(BigDecimal.valueOf(6l)) + .realtime(8l) + .node("testnode") + .build(); + statistics.addObservation(o3); + + String csv = statistics.exportCsv(0); + log.info(csv); + assertEquals("task,taskName,success,inputSize,ramRequest,peakVmem,peakRss,realtime,node\n" + + "task,taskName,true,1,2,3,4,6,testnode\n" + + "task,taskName,true,2,3,4,5,7,testnode\n" + + "task,taskName,true,3,4,5,6,8,testnode\n", csv); + + String summary = statistics.summary(0); + log.info(summary); + + String reference = " total observations collected: 3\n" + + " different tasks: 1\n" + + " -- task: 'task' --\n" + + " named instances of 'task' seen: 1\n" + + " success count: 3\n" + + " failure count: 0\n" + + "inputSize : cnt 3, avr 2.0, min 1, max 3\n" + + "ramRequest : cnt 3, avr 3.000e+00, min 2.000e+00, max 4.000e+00\n" + + "peakVmem : cnt 3, avr 4.000e+00, min 3.000e+00, max 5.000e+00\n" + + "peakRss : cnt 3, avr 5.000e+00, min 4.000e+00, max 6.000e+00\n" + + "realtime : cnt 3, avr 7.0, min 6, max 8\n"; + assertTrue(summary.endsWith(reference)); + } + +} diff --git 
a/src/test/java/cws/k8s/scheduler/memory/TaskScalerTest.java b/src/test/java/cws/k8s/scheduler/memory/TaskScalerTest.java new file mode 100644 index 0000000..ca046de --- /dev/null +++ b/src/test/java/cws/k8s/scheduler/memory/TaskScalerTest.java @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2023, Florian Friederici. All rights reserved. + * + * This code is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This code is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * this work. If not, see . + */ + +package cws.k8s.scheduler.memory; + +import static org.mockito.Mockito.when; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.test.util.ReflectionTestUtils; + +import cws.k8s.scheduler.client.KubernetesClient; +import cws.k8s.scheduler.model.NodeWithAlloc; +import cws.k8s.scheduler.model.PodWithAge; +import cws.k8s.scheduler.model.Requirements; +import cws.k8s.scheduler.model.SchedulerConfig; +import cws.k8s.scheduler.model.Task; +import cws.k8s.scheduler.model.TaskConfig; +import cws.k8s.scheduler.scheduler.Scheduler; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 5 Tests for the TaskScaler class + * + * @author Florian Friederici + * + */ +@Slf4j +class TaskScalerTest { + + private Task mockTask() { + Requirements r = Mockito.mock(Requirements.class); + when(r.getRam()).thenReturn(BigDecimal.ONE); + + PodWithAge p = Mockito.mock(PodWithAge.class); + 
when(p.getRequest()).thenReturn(r); + + NodeWithAlloc n = Mockito.mock(NodeWithAlloc.class); + when(n.getName()).thenReturn("nodename"); + + TaskConfig tc = Mockito.mock(TaskConfig.class); + when(tc.getName()).thenReturn("Unittest"); + when(tc.getTask()).thenReturn("task"); + when(tc.getName()).thenReturn("task (1)"); + + Task t = Mockito.mock(Task.class); + when(t.getConfig()).thenReturn(tc); + when(t.wasSuccessfullyExecuted()).thenReturn(true); + when(t.getInputSize()).thenReturn(1l); + when(t.getPod()).thenReturn(p); + when(t.getNode()).thenReturn(n); + + // TODO provide .command.trace file + when(t.getWorkingDir()).thenReturn("work"); + + return t; + } + + private TaskScaler mockTaskScaler() { + SchedulerConfig schedulerConfig = Mockito.mock(SchedulerConfig.class); + ReflectionTestUtils.setField(schedulerConfig, "memoryPredictor", "none"); + Scheduler scheduler = Mockito.mock(Scheduler.class); + KubernetesClient kubernetesClient = Mockito.mock(KubernetesClient.class); + + BigDecimal maxCpu = BigDecimal.ONE; + BigDecimal maxRam = BigDecimal.TEN; + Requirements requirements = new Requirements( maxCpu, maxRam); + NodeWithAlloc nwa = Mockito.mock(NodeWithAlloc.class); + ReflectionTestUtils.setField(nwa, "maxResources", requirements); + ReflectionTestUtils.setField(nwa, "assignedPods", new HashMap<>()); + when(nwa.getMaxResources()).thenReturn(requirements); + when(nwa.getAvailableResources()).thenReturn(requirements); + + List allNodes = new ArrayList(); + allNodes.add(nwa); + when(kubernetesClient.getAllNodes()).thenReturn(allNodes); + + TaskScaler ts = new TaskScaler(scheduler, schedulerConfig, kubernetesClient); + return ts; + } + + /** + * Test NonePredictor overhead + */ + @Test + void testAfterTaskFinished() { + TaskScaler ts = mockTaskScaler(); + Task t = mockTask(); + + long repetitions = 1000; + long startTime = System.currentTimeMillis(); + for (int i=0; i unscheduled = new ArrayList<>(); + unscheduled.add(t); + + long repetitions = 1000; + long startTime 
= System.currentTimeMillis(); + for (int i=0; i. + */ + +package cws.k8s.scheduler.memory; + +import static org.junit.jupiter.api.Assertions.*; + +import java.math.BigDecimal; +import java.math.RoundingMode; + +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import cws.k8s.scheduler.model.Task; +import lombok.extern.slf4j.Slf4j; + +/** + * JUnit 5 tests for the WaryPredictor. + * + * The tests create a fresh WaryPredictor, add Observation objects via addObservation and check the result of queryPrediction for a Task built with MemoryPredictorTest.createTask. + * + * @author Florian Friederici + * + */ +@Slf4j +class WaryPredictorTest { + + /** + * If there are < 4 observations, we cannot get a prediction. + * Parameterized over the values 0, 1, 2 and 3. + * + * @param number the current value taken from the ValueSource list + */ + @ParameterizedTest + @ValueSource(ints = { 0, 1, 2, 3 }) + void testNoObservationsYet(int number) { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + log.info("param: {}", number); + WaryPredictor waryPredictor = new WaryPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + + for (int i=0; i 3 observations, we can get a prediction + */ + @ParameterizedTest + @ValueSource(ints = { 4, 5, 6 }) + void testSomeObservations(int number) { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + log.info("param: {}", number); + WaryPredictor waryPredictor = new WaryPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + + for (int i=0; i 3 observations, but with same inputSize, we cannot get a prediction + */ + @ParameterizedTest + @ValueSource(ints = { 3, 4, 5 }) + void testNoDifferentObservations(int number) { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + log.info("param: {}", number); + WaryPredictor waryPredictor = new WaryPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1024l); + + for (int i=0; i 2 errors, warePredictor will quit predicting + */ + @Test + void testAfter3Errors() { + 
log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + WaryPredictor waryPredictor = new WaryPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1000l); + + /* 4 GiB: used as ramRequest of the successful observations and as the expected final prediction */ long initialValue = 4l*1024*1024*1024; + + /* Four successful observations with input sizes 1001 to 1004 */ for (int i=0; i<4; i++) { + log.info("insert successful observation {}", i+2); + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName ("+(i+2)+")") + .success(true) + .inputSize(1001l+i) + .ramRequest(BigDecimal.valueOf(initialValue)) + .peakVmem(BigDecimal.valueOf(3l*1024*1024*1024)) + .peakRss(BigDecimal.valueOf(2l*1024*1024*1024)) + .realtime(1000) + .build(); + // @formatter:on + waryPredictor.addObservation(observation); + } + + /* With the four successful observations in place a prediction must be available */ BigDecimal prediction = waryPredictor.queryPrediction(task); + log.info("{}", prediction); + assertNotNull(prediction); + + /* Three failed observations; the prediction is queried and must be non-null before each one is added */ for (int i=0; i<3; i++) { + log.info("insert failed observation {}", i+6); + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName ("+(i+6)+")") + .success(false) + .inputSize(0) + .ramRequest(BigDecimal.valueOf(initialValue/2)) + .peakRss(BigDecimal.valueOf(0)) + .peakVmem(BigDecimal.valueOf(0)) + .realtime(1000) + .build(); + // @formatter:on + BigDecimal p2 = waryPredictor.queryPrediction(task); + log.info("{}", p2); + assertNotNull(p2); + waryPredictor.addObservation(observation); + } + /* After the third failed observation the prediction is expected to equal initialValue */ assertEquals(initialValue, Long.parseLong( waryPredictor.queryPrediction(task).toPlainString() )); + } + + /** + * Test a specific situation. + * + * Currently not executed: the Test annotation below is commented out. + * Fits a SimpleRegression to three (input size, peak vmem) pairs, predicts the value for a fourth input size and adds the same three pairs to the WaryPredictor as successful observations. + * No prediction is expected after the first and after the second observation; after the third one the regression value times 1.1, rounded up, must not exceed the WaryPredictor result. + * + * NOTE(review): the documentation of testNoObservationsYet states that fewer than 4 observations give no prediction, which may be why this test is disabled - confirm. + * NOTE(review): all three observations use the taskName "taskName (1)" - confirm this is intended. + */ + //@Test + void testSpecific() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + + long is1 = 4652181; + long v1 = 5082320896l; + + long is2 = 4647849; + long v2 = 5082214400l; + + long is3 = 4589825; + long v3 = 4948094976l; + + long is4 = 2464690; + + SimpleRegression sr = new SimpleRegression(); + sr.addData(is1, v1); + sr.addData(is2, v2); + sr.addData(is3, v3); + long prediction = (long) sr.predict(is4); + 
log.info("expected value = {}", prediction); + + WaryPredictor waryPredictor = new WaryPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", is4); + + // @formatter:off + Observation o1 = Observation.builder().task("taskName").taskName("taskName (1)") + .success(true) + .inputSize(is1) + .ramRequest(BigDecimal.valueOf(53687091200l)) + .peakVmem(BigDecimal.valueOf(v1)) + .peakRss(BigDecimal.valueOf(853852160l)) + .realtime(73000) + .build(); + // @formatter:on + waryPredictor.addObservation(o1); + + /* One observation: no prediction expected */ assertNull( waryPredictor.queryPrediction(task) ); + + // @formatter:off + Observation o2 = Observation.builder().task("taskName").taskName("taskName (1)") + .success(true) + .inputSize(is2) + .ramRequest(BigDecimal.valueOf(53687091200l)) + .peakVmem(BigDecimal.valueOf(v2)) + .peakRss(BigDecimal.valueOf(858411008l)) + .realtime(71000) + .build(); + // @formatter:on + waryPredictor.addObservation(o2); + + /* Two observations: still no prediction expected */ assertNull( waryPredictor.queryPrediction(task) ); + + // @formatter:off + Observation o3 = Observation.builder().task("taskName").taskName("taskName (1)") + .success(true) + .inputSize(is3) + .ramRequest(BigDecimal.valueOf(53687091200l)) + .peakVmem(BigDecimal.valueOf(v3)) + .peakRss(BigDecimal.valueOf(854892544l)) + .realtime(82000) + .build(); + // @formatter:on + waryPredictor.addObservation(o3); + + /* compareTo(...) < 1 holds when the regression value times 1.1, rounded up, is less than or equal to the queried prediction */ assertTrue(BigDecimal.valueOf(prediction).multiply(BigDecimal.valueOf(1.1)).setScale(0, RoundingMode.CEILING).compareTo( waryPredictor.queryPrediction(task) ) < 1 ); + + } + + /** + * Test ignore list. + * + * Two failed observations with input size 1000 are added first; afterwards no prediction is expected for task or for task3. + * After four additional successful observations a prediction is expected for task3, but still none for task. + * + * NOTE(review): task3 differs from task only by the third createTask argument (3); that overload lives in MemoryPredictorTest and its meaning is not visible here. + */ + @Test + void testIgnoreList() { + log.info(Thread.currentThread().getStackTrace()[1].getMethodName()); + WaryPredictor waryPredictor = new WaryPredictor(); + Task task = MemoryPredictorTest.createTask("taskName", 1000); + Task task3 = MemoryPredictorTest.createTask("taskName", 1000, 3); + + long initialValue = 2000l; + + // @formatter:off + Observation o2 = Observation.builder().task("taskName").taskName("taskName (2)") + .success(false) + 
.inputSize(1000) + .ramRequest(BigDecimal.valueOf(initialValue)) + .peakVmem(BigDecimal.valueOf(853852160l)) + .peakRss(BigDecimal.valueOf(853852160l)) + .realtime(73000) + .build(); + // @formatter:on + waryPredictor.addObservation(o2); + + // @formatter:off + Observation o1 = Observation.builder().task("taskName").taskName("taskName (1)") + .success(false) + .inputSize(1000) + .ramRequest(BigDecimal.valueOf(53687091200l)) + .peakVmem(BigDecimal.valueOf(853852160l)) + .peakRss(BigDecimal.valueOf(853852160l)) + .realtime(73000) + .build(); + // @formatter:on + waryPredictor.addObservation(o1); + + /* After the two failed observations neither task is expected to get a prediction */ log.info("{}", waryPredictor.queryPrediction(task)); + assertNull(waryPredictor.queryPrediction(task)); + log.info("{}", waryPredictor.queryPrediction(task3)); + assertNull(waryPredictor.queryPrediction(task3)); + + /* Four successful observations with input sizes 123 to 126 */ for (int i=0; i<4; i++) { + log.info("insert successful observation {}", i); + // @formatter:off + Observation observation = Observation.builder() + .task("taskName") + .taskName("taskName ("+(4+i)+")") + .success(true) + .inputSize(123+i) + .ramRequest(BigDecimal.valueOf(123)) + .peakVmem(BigDecimal.valueOf(1)) + .peakRss(BigDecimal.valueOf(1)) + .realtime(1000) + .build(); + // @formatter:on + waryPredictor.addObservation(observation); + } + + /* task3 is now expected to get a prediction ... */ log.info("{}", waryPredictor.queryPrediction(task3)); + assertNotNull(waryPredictor.queryPrediction(task3)); + + /* ... while task, created without the third argument, is still expected to get none */ log.info("{}", waryPredictor.queryPrediction(task)); + assertNull(waryPredictor.queryPrediction(task)); + + } + + +}