Skip to content

Commit

Permalink
[PyOutline] Supoprt gpus and gpu_memory
Browse files Browse the repository at this point in the history
Co-authored-by: Lars van der Bijl <[email protected]>
  • Loading branch information
splhack and larsbijl committed Feb 22, 2021
1 parent 1009619 commit bae92f4
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
12 changes: 12 additions & 0 deletions pyoutline/outline/backend/cue.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,18 @@ def _serialize(launcher, use_pycuerun):
else:
_warning_spec_version(spec_version, "timeout_llu")

if layer.get_arg("gpus"):
if spec_version >= Version("1.12"):
sub_element(spec_layer, "gpus", "%d" % (layer.get_arg("gpus")))
else:
_warning_spec_version(spec_version, "gpus")

if layer.get_arg("gpu_memory"):
if spec_version >= Version("1.12"):
sub_element(spec_layer, "gpu_memory", "%s" % (layer.get_arg("gpu_memory")))
else:
_warning_spec_version(spec_version, "gpu_memory")

if os.environ.get("OL_TAG_OVERRIDE", False):
sub_element(spec_layer, "tags",
scrub_tags(os.environ["OL_TAG_OVERRIDE"]))
Expand Down
3 changes: 2 additions & 1 deletion pyoutline/outline/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def setup_local_cores(e):
"cores": str(USE_LOCAL_CORES),
"memory": get_half_host_memory(),
"threads": str(threads),
"gpu": str(0)})
"gpus": str(0),
"gpu_memory": str(0)})


def get_half_host_memory():
Expand Down
21 changes: 21 additions & 0 deletions pyoutline/tests/specver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,24 @@ def test_1_11(self):
self.assertEqual(root.find("job/layers/layer/timeout").text, "420")
self.assertEqual(root.find("job/layers/layer/timeout_llu").text, "4200")
self.assertEqual(root.find("job/priority").text, "42")

def _makeGpuSpec(self):
ol = outline.Outline(name="spec_version_test")
layer = outline.modules.shell.Shell("test_layer", command=["/bin/ls"])
layer.set_arg("gpus", 4)
layer.set_arg("gpu_memory", 8 * 1024 * 1024)
ol.add_layer(layer)
l = outline.cuerun.OutlineLauncher(ol)
return Et.fromstring(l.serialize())

def test_gpu_1_11(self):
outline.config.set("outline", "spec_version", "1.11")
root = self._makeGpuSpec()
self.assertIsNone(root.find("job/layers/layer/gpus"))
self.assertIsNone(root.find("job/layers/layer/gpus_memory"))

def test_gpu_1_12(self):
outline.config.set("outline", "spec_version", "1.12")
root = self._makeGpuSpec()
self.assertEqual(root.find("job/layers/layer/gpus").text, "4")
self.assertEqual(root.find("job/layers/layer/gpu_memory").text, "8388608")

0 comments on commit bae92f4

Please sign in to comment.