diff --git a/pyoutline/outline/backend/cue.py b/pyoutline/outline/backend/cue.py index 1d91eb70a..cfb003fad 100644 --- a/pyoutline/outline/backend/cue.py +++ b/pyoutline/outline/backend/cue.py @@ -330,6 +330,18 @@ def _serialize(launcher, use_pycuerun): else: _warning_spec_version(spec_version, "timeout_llu") + if layer.get_arg("gpus"): + if spec_version >= Version("1.12"): + sub_element(spec_layer, "gpus", "%d" % (layer.get_arg("gpus"))) + else: + _warning_spec_version(spec_version, "gpus") + + if layer.get_arg("gpu_memory"): + if spec_version >= Version("1.12"): + sub_element(spec_layer, "gpu_memory", "%s" % (layer.get_arg("gpu_memory"))) + else: + _warning_spec_version(spec_version, "gpu_memory") + if os.environ.get("OL_TAG_OVERRIDE", False): sub_element(spec_layer, "tags", scrub_tags(os.environ["OL_TAG_OVERRIDE"])) diff --git a/pyoutline/outline/plugins/local.py b/pyoutline/outline/plugins/local.py index c0bc1fd17..29b6dbdf0 100644 --- a/pyoutline/outline/plugins/local.py +++ b/pyoutline/outline/plugins/local.py @@ -112,7 +112,8 @@ def setup_local_cores(e): "cores": str(USE_LOCAL_CORES), "memory": get_half_host_memory(), "threads": str(threads), - "gpu": str(0)}) + "gpus": str(0), + "gpu_memory": str(0)}) def get_half_host_memory(): diff --git a/pyoutline/tests/specver_test.py b/pyoutline/tests/specver_test.py index 24d92df07..9889e0ae9 100644 --- a/pyoutline/tests/specver_test.py +++ b/pyoutline/tests/specver_test.py @@ -61,3 +61,24 @@ def test_1_11(self): self.assertEqual(root.find("job/layers/layer/timeout").text, "420") self.assertEqual(root.find("job/layers/layer/timeout_llu").text, "4200") self.assertEqual(root.find("job/priority").text, "42") + + def _makeGpuSpec(self): + ol = outline.Outline(name="spec_version_test") + layer = outline.modules.shell.Shell("test_layer", command=["/bin/ls"]) + layer.set_arg("gpus", 4) + layer.set_arg("gpu_memory", 8 * 1024 * 1024) + ol.add_layer(layer) + l = outline.cuerun.OutlineLauncher(ol) + return Et.fromstring(l.serialize()) + + def test_gpu_1_11(self): + outline.config.set("outline", "spec_version", "1.11") + root = self._makeGpuSpec() + self.assertIsNone(root.find("job/layers/layer/gpus")) + self.assertIsNone(root.find("job/layers/layer/gpus_memory")) + + def test_gpu_1_12(self): + outline.config.set("outline", "spec_version", "1.12") + root = self._makeGpuSpec() + self.assertEqual(root.find("job/layers/layer/gpus").text, "4") + self.assertEqual(root.find("job/layers/layer/gpu_memory").text, "8388608")