From 32da5fad8c53f4717ebe5160763273c8191a2735 Mon Sep 17 00:00:00 2001 From: Kazuki Sakamoto Date: Sun, 21 Feb 2021 21:16:19 -0800 Subject: [PATCH] [PyOutline] Supoprt gpus and gpu_memory Co-authored-by: Lars van der Bijl <285658+larsbijl@users.noreply.github.com> --- pyoutline/outline/backend/cue.py | 12 ++++++++++++ pyoutline/outline/plugins/local.py | 3 ++- pyoutline/tests/specver_test.py | 21 +++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/pyoutline/outline/backend/cue.py b/pyoutline/outline/backend/cue.py index 1d91eb70a..cfb003fad 100644 --- a/pyoutline/outline/backend/cue.py +++ b/pyoutline/outline/backend/cue.py @@ -330,6 +330,18 @@ def _serialize(launcher, use_pycuerun): else: _warning_spec_version(spec_version, "timeout_llu") + if layer.get_arg("gpus"): + if spec_version >= Version("1.12"): + sub_element(spec_layer, "gpus", "%d" % (layer.get_arg("gpus"))) + else: + _warning_spec_version(spec_version, "gpus") + + if layer.get_arg("gpu_memory"): + if spec_version >= Version("1.12"): + sub_element(spec_layer, "gpu_memory", "%s" % (layer.get_arg("gpu_memory"))) + else: + _warning_spec_version(spec_version, "gpu_memory") + if os.environ.get("OL_TAG_OVERRIDE", False): sub_element(spec_layer, "tags", scrub_tags(os.environ["OL_TAG_OVERRIDE"])) diff --git a/pyoutline/outline/plugins/local.py b/pyoutline/outline/plugins/local.py index c0bc1fd17..29b6dbdf0 100644 --- a/pyoutline/outline/plugins/local.py +++ b/pyoutline/outline/plugins/local.py @@ -112,7 +112,8 @@ def setup_local_cores(e): "cores": str(USE_LOCAL_CORES), "memory": get_half_host_memory(), "threads": str(threads), - "gpu": str(0)}) + "gpus": str(0), + "gpu_memory": str(0)}) def get_half_host_memory(): diff --git a/pyoutline/tests/specver_test.py b/pyoutline/tests/specver_test.py index 24d92df07..9889e0ae9 100644 --- a/pyoutline/tests/specver_test.py +++ b/pyoutline/tests/specver_test.py @@ -61,3 +61,24 @@ def test_1_11(self): self.assertEqual(root.find("job/layers/layer/timeout").text, "420") self.assertEqual(root.find("job/layers/layer/timeout_llu").text, "4200") self.assertEqual(root.find("job/priority").text, "42") + + def _makeGpuSpec(self): + ol = outline.Outline(name="spec_version_test") + layer = outline.modules.shell.Shell("test_layer", command=["/bin/ls"]) + layer.set_arg("gpus", 4) + layer.set_arg("gpu_memory", 8 * 1024 * 1024) + ol.add_layer(layer) + l = outline.cuerun.OutlineLauncher(ol) + return Et.fromstring(l.serialize()) + + def test_gpu_1_11(self): + outline.config.set("outline", "spec_version", "1.11") + root = self._makeGpuSpec() + self.assertIsNone(root.find("job/layers/layer/gpus")) + self.assertIsNone(root.find("job/layers/layer/gpus_memory")) + + def test_gpu_1_12(self): + outline.config.set("outline", "spec_version", "1.12") + root = self._makeGpuSpec() + self.assertEqual(root.find("job/layers/layer/gpus").text, "4") + self.assertEqual(root.find("job/layers/layer/gpu_memory").text, "8388608")