diff --git a/loader/loader_test.go b/loader/loader_test.go index 7b9e1677..141f3998 100644 --- a/loader/loader_test.go +++ b/loader/loader_test.go @@ -2123,7 +2123,7 @@ services: } func TestServiceDeviceRequestCountStringType(t *testing.T) { - _, err := loadYAML(` + project, err := loadYAML(` name: service-device-request-count services: hello-world: @@ -2137,6 +2137,7 @@ services: count: all `) assert.NilError(t, err) + assert.Equal(t, project.Services["hello-world"].Deploy.Resources.Reservations.Devices[0].Count, types.DeviceCount(-1), err) } func TestServiceDeviceRequestCountIntegerAsStringType(t *testing.T) { @@ -2155,6 +2156,22 @@ services: `) assert.NilError(t, err) } +func TestServiceDeviceRequestWithoutCountAndDeviceIdsType(t *testing.T) { + project, err := loadYAML(` +name: service-device-request-count-type +services: + hello-world: + image: redis:alpine + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] +`) + assert.NilError(t, err) + assert.Equal(t, project.Services["hello-world"].Deploy.Resources.Reservations.Devices[0].Count, types.DeviceCount(-1), err) +} func TestServiceDeviceRequestCountInvalidStringType(t *testing.T) { _, err := loadYAML(` @@ -2173,6 +2190,40 @@ services: assert.ErrorContains(t, err, `invalid value "some_string", the only value allowed is 'all' or a number`) } +func TestServiceDeviceRequestCountAndDeviceIdsExclusive(t *testing.T) { + _, err := loadYAML(` +name: service-device-request-count-type +services: + hello-world: + image: redis:alpine + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] + count: 2 + device_ids: ["my-device-id"] +`) + assert.ErrorContains(t, err, `invalid "count" and "device_ids" are attributes are exclusive`) +} + +func TestServiceDeviceRequestCapabilitiesMandatory(t *testing.T) { + _, err := loadYAML(` +name: service-device-request-count-type +services: + hello-world: + image: redis:alpine + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 2 +`) + assert.ErrorContains(t, err, `"capabilities" attribute is mandatory for device request definition`) +} + func TestServicePullPolicy(t *testing.T) { actual, err := loadYAML(` name: service-pull-policy diff --git a/types/device.go b/types/device.go index 240e8778..0d8d7e7a 100644 --- a/types/device.go +++ b/types/device.go @@ -50,3 +50,48 @@ func (c *DeviceCount) DecodeMapstructure(value interface{}) error { } return nil } + +func (d *DeviceRequest) DecodeMapstructure(value interface{}) error { + v, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("invalid device request type %T", value) + } + if _, okCaps := v["capabilities"]; !okCaps { + return fmt.Errorf(`"capabilities" attribute is mandatory for device request definition`) + } + if _, okCount := v["count"]; okCount { + if _, okDeviceIds := v["device_ids"]; okDeviceIds { + return fmt.Errorf(`invalid "count" and "device_ids" are attributes are exclusive`) + } + } + d.Count = DeviceCount(-1) + + capabilities := v["capabilities"] + caps := StringList{} + if err := caps.DecodeMapstructure(capabilities); err != nil { + return err + } + d.Capabilities = caps + if driver, ok := v["driver"]; ok { + if val, ok := driver.(string); ok { + d.Driver = val + } else { + return fmt.Errorf("invalid type for driver value: %T", driver) + } + } + if count, ok := v["count"]; ok { + if err := d.Count.DecodeMapstructure(count); err != nil { + return err + } + } + if deviceIDs, ok := v["device_ids"]; ok { + ids := StringList{} + if err := ids.DecodeMapstructure(deviceIDs); err != nil { + return err + } + d.IDs = ids + d.Count = DeviceCount(len(ids)) + } + return nil + +}