Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 87 additions & 69 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ import (
"context"
"device-volume-driver/internal/cgroup"
"fmt"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
_ "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
"log"
"os"
"path"
"path/filepath"
"strings"

"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
_ "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
)

const pluginId = "dvd"
Expand All @@ -28,7 +29,18 @@ func Ptr[T any](v T) *T {
}

func main() {
listenForMounts()
log.Printf("Starting\n")

cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())

if err != nil {
log.Fatal(err)
}

defer cli.Close()

checkExistingContainers(cli)
listenForMounts(cli)
}

func getDeviceInfo(devicePath string) (string, int64, int64, error) {
Expand Down Expand Up @@ -59,19 +71,9 @@ func getDeviceInfo(devicePath string) (string, int64, int64, error) {
return deviceType, major, minor, nil
}

func listenForMounts() {
ctx := context.Background()

cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())

if err != nil {
log.Fatal(err)
}

defer cli.Close()

func listenForMounts(cli *client.Client) {
msgs, errs := cli.Events(
ctx,
context.Background(),
types.EventsOptions{Filters: filters.NewArgs(filters.Arg("event", "start"))},
)

Expand All @@ -80,78 +82,94 @@ func listenForMounts() {
case err := <-errs:
log.Fatal(err)
case msg := <-msgs:
info, err := cli.ContainerInspect(ctx, msg.Actor.ID)
processContainer(cli, msg.Actor.ID)
}
}
}

if err != nil {
panic(err)
} else {
pid := info.State.Pid
version, err := cgroup.GetDeviceCGroupVersion("/", pid)
func processContainer(cli *client.Client, id string) {
info, err := cli.ContainerInspect(context.Background(), id)

log.Printf("The cgroup version for process %d is: %v\n", pid, version)
if err != nil {
panic(err)
} else {
pid := info.State.Pid
version, err := cgroup.GetDeviceCGroupVersion("/", pid)

if err != nil {
log.Println(err)
break
}
log.Printf("The cgroup version for process %d is: %v\n", pid, version)

log.Printf("Checking mounts for process %d\n", pid)
if err != nil {
log.Println(err)
return
}

for _, mount := range info.Mounts {
log.Printf(
"%s/%v requested a volume mount for %s at %s\n",
msg.Actor.ID, info.State.Pid, mount.Source, mount.Destination,
)
log.Printf("Checking mounts for process %d\n", pid)

if !strings.HasPrefix(mount.Source, "/dev") {
log.Printf("%s is not a device... skipping\n", mount.Source)
continue
}
for _, mount := range info.Mounts {
log.Printf(
"%s/%v requested a volume mount for %s at %s\n",
id, info.State.Pid, mount.Source, mount.Destination,
)

api, err := cgroup.New(version)
cgroupPath, sysfsPath, err := api.GetDeviceCGroupMountPath("/", pid)
if !strings.HasPrefix(mount.Source, "/dev") {
log.Printf("%s is not a device... skipping\n", mount.Source)
continue
}

if err != nil {
log.Println(err)
break
}
api, err := cgroup.New(version)
cgroupPath, sysfsPath, err := api.GetDeviceCGroupMountPath("/", pid)

cgroupPath = path.Join(rootPath, sysfsPath, cgroupPath)
if err != nil {
log.Println(err)
break
}

log.Printf("The cgroup path for process %d is at %v\n", pid, cgroupPath)
cgroupPath = path.Join(rootPath, sysfsPath, cgroupPath)

if fileInfo, err := os.Stat(mount.Source); err != nil {
log.Println(err)
continue
} else {
if fileInfo.IsDir() {
err := filepath.Walk(mount.Source,
func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
} else if info.IsDir() {
return nil
} else if err = applyDeviceRules(api, path, cgroupPath, pid); err != nil {
log.Println(err)
}
return nil
})
log.Printf("The cgroup path for process %d is at %v\n", pid, cgroupPath)

if fileInfo, err := os.Stat(mount.Source); err != nil {
log.Println(err)
continue
} else {
if fileInfo.IsDir() {
err := filepath.Walk(mount.Source,
func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
} else if info.IsDir() {
return nil
} else if err = applyDeviceRules(api, path, cgroupPath, pid); err != nil {
log.Println(err)
}
} else {
if err = applyDeviceRules(api, mount.Source, cgroupPath, pid); err != nil {
log.Println(err)
}
}
return nil
})
if err != nil {
log.Println(err)
}
} else {
if err = applyDeviceRules(api, mount.Source, cgroupPath, pid); err != nil {
log.Println(err)
}

}
}
}
}
}

func checkExistingContainers(cli *client.Client) {
containers, err := cli.ContainerList(context.Background(), types.ContainerListOptions{})

if err != nil {
panic(err)
}

for _, container := range containers {
log.Printf("Checking existing container %s %s\n", container.ID[:10], container.Image)
processContainer(cli, container.ID)
}
}

func applyDeviceRules(api cgroup.Interface, mountPath string, cgroupPath string, pid int) error {
deviceType, major, minor, err := getDeviceInfo(mountPath)

Expand Down