diff --git a/main.go b/main.go index 411b48b..0c9ed36 100644 --- a/main.go +++ b/main.go @@ -8,16 +8,17 @@ import ( "context" "device-volume-driver/internal/cgroup" "fmt" - "github.com/docker/docker/api/types" - "github.com/docker/docker/api/types/filters" - "github.com/docker/docker/client" - _ "github.com/opencontainers/runtime-spec/specs-go" - "golang.org/x/sys/unix" "log" "os" "path" "path/filepath" "strings" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/client" + _ "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" ) const pluginId = "dvd" @@ -28,7 +29,18 @@ func Ptr[T any](v T) *T { } func main() { - listenForMounts() + log.Printf("Starting\n") + + cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) + + if err != nil { + log.Fatal(err) + } + + defer cli.Close() + + checkExistingContainers(cli) + listenForMounts(cli) } func getDeviceInfo(devicePath string) (string, int64, int64, error) { @@ -59,19 +71,9 @@ func getDeviceInfo(devicePath string) (string, int64, int64, error) { return deviceType, major, minor, nil } -func listenForMounts() { - ctx := context.Background() - - cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) - - if err != nil { - log.Fatal(err) - } - - defer cli.Close() - +func listenForMounts(cli *client.Client) { msgs, errs := cli.Events( - ctx, + context.Background(), types.EventsOptions{Filters: filters.NewArgs(filters.Arg("event", "start"))}, ) @@ -80,78 +82,94 @@ func listenForMounts() { case err := <-errs: log.Fatal(err) case msg := <-msgs: - info, err := cli.ContainerInspect(ctx, msg.Actor.ID) + processContainer(cli, msg.Actor.ID) + } + } +} - if err != nil { - panic(err) - } else { - pid := info.State.Pid - version, err := cgroup.GetDeviceCGroupVersion("/", pid) +func processContainer(cli *client.Client, id string) { + info, err := cli.ContainerInspect(context.Background(), id) - log.Printf("The cgroup version for process %d is: %v\n", pid, version) + if err != nil { + panic(err) + } else { + pid := info.State.Pid + version, err := cgroup.GetDeviceCGroupVersion("/", pid) - if err != nil { - log.Println(err) - break - } + log.Printf("The cgroup version for process %d is: %v\n", pid, version) - log.Printf("Checking mounts for process %d\n", pid) + if err != nil { + log.Println(err) + return + } - for _, mount := range info.Mounts { - log.Printf( - "%s/%v requested a volume mount for %s at %s\n", - msg.Actor.ID, info.State.Pid, mount.Source, mount.Destination, - ) + log.Printf("Checking mounts for process %d\n", pid) - if !strings.HasPrefix(mount.Source, "/dev") { - log.Printf("%s is not a device... skipping\n", mount.Source) - continue - } + for _, mount := range info.Mounts { + log.Printf( + "%s/%v requested a volume mount for %s at %s\n", + id, info.State.Pid, mount.Source, mount.Destination, + ) - api, err := cgroup.New(version) - cgroupPath, sysfsPath, err := api.GetDeviceCGroupMountPath("/", pid) + if !strings.HasPrefix(mount.Source, "/dev") { + log.Printf("%s is not a device... skipping\n", mount.Source) + continue + } - if err != nil { - log.Println(err) - break - } + api, err := cgroup.New(version) + cgroupPath, sysfsPath, err := api.GetDeviceCGroupMountPath("/", pid) - cgroupPath = path.Join(rootPath, sysfsPath, cgroupPath) + if err != nil { + log.Println(err) + break + } - log.Printf("The cgroup path for process %d is at %v\n", pid, cgroupPath) + cgroupPath = path.Join(rootPath, sysfsPath, cgroupPath) - if fileInfo, err := os.Stat(mount.Source); err != nil { - log.Println(err) - continue - } else { - if fileInfo.IsDir() { - err := filepath.Walk(mount.Source, - func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } else if info.IsDir() { - return nil - } else if err = applyDeviceRules(api, path, cgroupPath, pid); err != nil { - log.Println(err) - } - return nil - }) + log.Printf("The cgroup path for process %d is at %v\n", pid, cgroupPath) + + if fileInfo, err := os.Stat(mount.Source); err != nil { + log.Println(err) + continue + } else { + if fileInfo.IsDir() { + err := filepath.Walk(mount.Source, + func(path string, info os.FileInfo, err error) error { if err != nil { + return err + } else if info.IsDir() { + return nil + } else if err = applyDeviceRules(api, path, cgroupPath, pid); err != nil { log.Println(err) } - } else { - if err = applyDeviceRules(api, mount.Source, cgroupPath, pid); err != nil { - log.Println(err) - } - } + return nil + }) + if err != nil { + log.Println(err) + } + } else { + if err = applyDeviceRules(api, mount.Source, cgroupPath, pid); err != nil { + log.Println(err) } - } } } } } +func checkExistingContainers(cli *client.Client) { + containers, err := cli.ContainerList(context.Background(), types.ContainerListOptions{}) + + if err != nil { + panic(err) + } + + for _, container := range containers { + log.Printf("Checking existing container %s %s\n", container.ID[:10], container.Image) + processContainer(cli, container.ID) + } +} + func applyDeviceRules(api cgroup.Interface, mountPath string, cgroupPath string, pid int) error { deviceType, major, minor, err := getDeviceInfo(mountPath)