Skip to content

Commit 1ba368a

Browse files
author
Tibor Vass
committed
container: --gpus support
Signed-off-by: Tibor Vass <[email protected]>
1 parent 91339e1 commit 1ba368a

File tree

3 files changed

+164
-0
lines changed

3 files changed

+164
-0
lines changed

cli/command/container/opts.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type containerOptions struct {
4646
labels opts.ListOpts
4747
deviceCgroupRules opts.ListOpts
4848
devices opts.ListOpts
49+
gpus opts.GpuOpts
4950
ulimits *opts.UlimitOpt
5051
sysctls *opts.MapOpts
5152
publish opts.ListOpts
@@ -166,6 +167,8 @@ func addFlags(flags *pflag.FlagSet) *containerOptions {
166167
flags.VarP(&copts.attach, "attach", "a", "Attach to STDIN, STDOUT or STDERR")
167168
flags.Var(&copts.deviceCgroupRules, "device-cgroup-rule", "Add a rule to the cgroup allowed devices list")
168169
flags.Var(&copts.devices, "device", "Add a host device to the container")
170+
flags.Var(&copts.gpus, "gpus", "GPU devices to add to the container ('all' to pass all GPUs)")
171+
flags.SetAnnotation("gpus", "version", []string{"1.40"})
169172
flags.VarP(&copts.env, "env", "e", "Set environment variables")
170173
flags.Var(&copts.envFile, "env-file", "Read in a file of environment variables")
171174
flags.StringVar(&copts.entrypoint, "entrypoint", "", "Overwrite the default ENTRYPOINT of the image")
@@ -557,6 +560,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
557560
Ulimits: copts.ulimits.GetList(),
558561
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
559562
Devices: deviceMappings,
563+
DeviceRequests: copts.gpus.Value(),
560564
}
561565

562566
config := &container.Config{

opts/gpus.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package opts
2+
3+
import (
4+
"encoding/csv"
5+
"fmt"
6+
"strconv"
7+
"strings"
8+
9+
"github.com/docker/docker/api/types/container"
10+
"github.com/pkg/errors"
11+
)
12+
13+
// GpuOpts is a Value type for parsing mounts
14+
type GpuOpts struct {
15+
values []container.DeviceRequest
16+
}
17+
18+
func parseCount(s string) (int, error) {
19+
if s == "all" {
20+
return -1, nil
21+
}
22+
i, err := strconv.Atoi(s)
23+
return i, errors.Wrap(err, "count must be an integer")
24+
}
25+
26+
// Set a new mount value
27+
// nolint: gocyclo
28+
func (o *GpuOpts) Set(value string) error {
29+
csvReader := csv.NewReader(strings.NewReader(value))
30+
fields, err := csvReader.Read()
31+
if err != nil {
32+
return err
33+
}
34+
35+
req := container.DeviceRequest{}
36+
37+
seen := map[string]struct{}{}
38+
// Set writable as the default
39+
for _, field := range fields {
40+
parts := strings.SplitN(field, "=", 2)
41+
key := parts[0]
42+
if _, ok := seen[key]; ok {
43+
return fmt.Errorf("gpu request key '%s' can be specified only once", key)
44+
}
45+
seen[key] = struct{}{}
46+
47+
if len(parts) == 1 {
48+
seen["count"] = struct{}{}
49+
req.Count, err = parseCount(key)
50+
if err != nil {
51+
return err
52+
}
53+
continue
54+
}
55+
56+
value := parts[1]
57+
switch key {
58+
case "driver":
59+
req.Driver = value
60+
case "count":
61+
req.Count, err = parseCount(value)
62+
if err != nil {
63+
return err
64+
}
65+
case "device":
66+
req.DeviceIDs = strings.Split(value, ",")
67+
case "capabilities":
68+
req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")}
69+
case "options":
70+
r := csv.NewReader(strings.NewReader(value))
71+
optFields, err := r.Read()
72+
if err != nil {
73+
return errors.Wrap(err, "failed to read gpu options")
74+
}
75+
req.Options = ConvertKVStringsToMap(optFields)
76+
default:
77+
return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
78+
}
79+
}
80+
81+
if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
82+
req.Count = 1
83+
}
84+
if req.Options == nil {
85+
req.Options = make(map[string]string)
86+
}
87+
if req.Capabilities == nil {
88+
req.Capabilities = [][]string{{"gpu"}}
89+
}
90+
91+
o.values = append(o.values, req)
92+
return nil
93+
}
94+
95+
// Type returns the type of this option
96+
func (o *GpuOpts) Type() string {
97+
return "gpu-request"
98+
}
99+
100+
// String returns a string repr of this option
101+
func (o *GpuOpts) String() string {
102+
gpus := []string{}
103+
for _, gpu := range o.values {
104+
gpus = append(gpus, fmt.Sprintf("%v", gpu))
105+
}
106+
return strings.Join(gpus, ", ")
107+
}
108+
109+
// Value returns the mounts
110+
func (o *GpuOpts) Value() []container.DeviceRequest {
111+
return o.values
112+
}

opts/gpus_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package opts
2+
3+
import (
4+
"testing"
5+
6+
"github.com/docker/docker/api/types/container"
7+
"gotest.tools/assert"
8+
is "gotest.tools/assert/cmp"
9+
)
10+
11+
func TestGpusOptAll(t *testing.T) {
12+
for _, testcase := range []string{
13+
"all",
14+
"-1",
15+
"count=all",
16+
"count=-1",
17+
} {
18+
var gpus GpuOpts
19+
gpus.Set(testcase)
20+
gpuReqs := gpus.Value()
21+
assert.Assert(t, is.Len(gpuReqs, 1))
22+
assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{
23+
Count: -1,
24+
Capabilities: [][]string{{"gpu"}},
25+
Options: map[string]string{},
26+
}))
27+
}
28+
}
29+
30+
func TestGpusOpts(t *testing.T) {
31+
for _, testcase := range []string{
32+
"driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
33+
"1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
34+
"count=1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
35+
"driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\",count=1",
36+
} {
37+
var gpus GpuOpts
38+
gpus.Set(testcase)
39+
gpuReqs := gpus.Value()
40+
assert.Assert(t, is.Len(gpuReqs, 1))
41+
assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{
42+
Driver: "nvidia",
43+
Count: 1,
44+
Capabilities: [][]string{{"compute", "utility", "gpu"}},
45+
Options: map[string]string{"foo": "bar", "baz": "qux"},
46+
}))
47+
}
48+
}

0 commit comments

Comments
 (0)