Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device requests #2471

Merged
merged 7 commits into from
Aug 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docker/api/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,9 @@ def create_host_config(self, *args, **kwargs):
For example, ``/dev/sda:/dev/xvda:rwm`` allows the container
to have read-write access to the host's ``/dev/sda`` via a
node named ``/dev/xvda`` inside the container.
device_requests (:py:class:`list`): Expose host resources such as
GPUs to the container, as a list of
:py:class:`docker.types.DeviceRequest` instances.
dns (:py:class:`list`): Set custom DNS servers.
dns_opt (:py:class:`list`): Additional options to be added to the
container's ``resolv.conf`` file
Expand Down
4 changes: 4 additions & 0 deletions docker/models/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,9 @@ def run(self, image, command=None, stdout=True, stderr=False,
For example, ``/dev/sda:/dev/xvda:rwm`` allows the container
to have read-write access to the host's ``/dev/sda`` via a
node named ``/dev/xvda`` inside the container.
device_requests (:py:class:`list`): Expose host resources such as
GPUs to the container, as a list of
:py:class:`docker.types.DeviceRequest` instances.
dns (:py:class:`list`): Set custom DNS servers.
dns_opt (:py:class:`list`): Additional options to be added to the
container's ``resolv.conf`` file.
Expand Down Expand Up @@ -998,6 +1001,7 @@ def prune(self, filters=None):
'device_write_bps',
'device_write_iops',
'devices',
'device_requests',
'dns_opt',
'dns_search',
'dns',
Expand Down
4 changes: 3 additions & 1 deletion docker/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# flake8: noqa
from .containers import ContainerConfig, HostConfig, LogConfig, Ulimit
from .containers import (
ContainerConfig, HostConfig, LogConfig, Ulimit, DeviceRequest
Lucidiot marked this conversation as resolved.
Show resolved Hide resolved
)
from .daemon import CancellableStream
from .healthcheck import Healthcheck
from .networks import EndpointConfig, IPAMConfig, IPAMPool, NetworkingConfig
Expand Down
113 changes: 112 additions & 1 deletion docker/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,104 @@ def hard(self, value):
self['Hard'] = value


class DeviceRequest(DictType):
"""
Create a device request to be used with
:py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`.

Args:

driver (str): Which driver to use for this device. Optional.
count (int): Number or devices to request. Optional.
Set to -1 to request all available devices.
device_ids (list): List of strings for device IDs. Optional.
Set either ``count`` or ``device_ids``.
capabilities (list): List of lists of strings to request
capabilities. Optional. The global list acts like an OR,
and the sub-lists are AND. The driver will try to satisfy
one of the sub-lists.
Available capabilities for the ``nvidia`` driver can be found
`here <https://github.com/NVIDIA/nvidia-container-runtime>`_.
options (dict): Driver-specific options. Optional.
"""

def __init__(self, **kwargs):
driver = kwargs.get('driver', kwargs.get('Driver'))
count = kwargs.get('count', kwargs.get('Count'))
device_ids = kwargs.get('device_ids', kwargs.get('DeviceIDs'))
capabilities = kwargs.get('capabilities', kwargs.get('Capabilities'))
options = kwargs.get('options', kwargs.get('Options'))

if driver is None:
driver = ''
elif not isinstance(driver, six.string_types):
raise ValueError('DeviceRequest.driver must be a string')
if count is None:
count = 0
elif not isinstance(count, int):
raise ValueError('DeviceRequest.count must be an integer')
if device_ids is None:
device_ids = []
elif not isinstance(device_ids, list):
raise ValueError('DeviceRequest.device_ids must be a list')
if capabilities is None:
capabilities = []
elif not isinstance(capabilities, list):
raise ValueError('DeviceRequest.capabilities must be a list')
if options is None:
options = {}
elif not isinstance(options, dict):
raise ValueError('DeviceRequest.options must be a dict')

super(DeviceRequest, self).__init__({
'Driver': driver,
'Count': count,
'DeviceIDs': device_ids,
'Capabilities': capabilities,
'Options': options
})

@property
def driver(self):
return self['Driver']

@driver.setter
def driver(self, value):
self['Driver'] = value

@property
def count(self):
return self['Count']

@count.setter
def count(self, value):
self['Count'] = value

@property
def device_ids(self):
return self['DeviceIDs']

@device_ids.setter
def device_ids(self, value):
self['DeviceIDs'] = value

@property
def capabilities(self):
return self['Capabilities']

@capabilities.setter
def capabilities(self, value):
self['Capabilities'] = value

@property
def options(self):
return self['Options']

@options.setter
def options(self, value):
self['Options'] = value


class HostConfig(dict):
def __init__(self, version, binds=None, port_bindings=None,
lxc_conf=None, publish_all_ports=False, links=None,
Expand All @@ -176,7 +274,7 @@ def __init__(self, version, binds=None, port_bindings=None,
volume_driver=None, cpu_count=None, cpu_percent=None,
nano_cpus=None, cpuset_mems=None, runtime=None, mounts=None,
cpu_rt_period=None, cpu_rt_runtime=None,
device_cgroup_rules=None):
device_cgroup_rules=None, device_requests=None):

if mem_limit is not None:
self['Memory'] = parse_bytes(mem_limit)
Expand Down Expand Up @@ -536,6 +634,19 @@ def __init__(self, version, binds=None, port_bindings=None,
)
self['DeviceCgroupRules'] = device_cgroup_rules

if device_requests is not None:
if version_lt(version, '1.40'):
raise host_config_version_error('device_requests', '1.40')
if not isinstance(device_requests, list):
raise host_config_type_error(
'device_requests', device_requests, 'list'
)
self['DeviceRequests'] = []
for req in device_requests:
if not isinstance(req, DeviceRequest):
req = DeviceRequest(**req)
self['DeviceRequests'].append(req)


def host_config_type_error(param, param_value, expected):
error_msg = 'Invalid type for {0} param: expected {1} but found {2}'
Expand Down
64 changes: 63 additions & 1 deletion tests/unit/api_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import signal

import docker
from docker.api import APIClient
import pytest
import six

from . import fake_api
from ..helpers import requires_api_version
from .api_test import (
BaseAPIClientTest, url_prefix, fake_request, DEFAULT_TIMEOUT_SECONDS,
fake_inspect_container
fake_inspect_container, url_base
)

try:
Expand Down Expand Up @@ -767,6 +768,67 @@ def test_create_container_with_devices(self):
assert args[1]['headers'] == {'Content-Type': 'application/json'}
assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS

def test_create_container_with_device_requests(self):
client = APIClient(version='1.40')
fake_api.fake_responses.setdefault(
'{0}/v1.40/containers/create'.format(fake_api.prefix),
fake_api.post_fake_create_container,
)
client.create_container(
'busybox', 'true', host_config=client.create_host_config(
device_requests=[
{
'device_ids': [
'0',
'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a'
]
},
{
'driver': 'nvidia',
'Count': -1,
'capabilities': [
['gpu', 'utility']
],
'options': {
'key': 'value'
}
}
]
)
)

args = fake_request.call_args
assert args[0][1] == url_base + 'v1.40/' + 'containers/create'
expected_payload = self.base_create_payload()
expected_payload['HostConfig'] = client.create_host_config()
expected_payload['HostConfig']['DeviceRequests'] = [
{
'Driver': '',
'Count': 0,
'DeviceIDs': [
'0',
'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a'
],
'Capabilities': [],
'Options': {}
},
{
'Driver': 'nvidia',
'Count': -1,
'DeviceIDs': [],
'Capabilities': [
['gpu', 'utility']
],
'Options': {
'key': 'value'
}
}
]
assert json.loads(args[1]['data']) == expected_payload
assert args[1]['headers']['Content-Type'] == 'application/json'
assert set(args[1]['headers']) <= {'Content-Type', 'User-Agent'}
assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS

def test_create_container_with_labels_dict(self):
labels_dict = {
six.text_type('foo'): six.text_type('1'),
Expand Down