diff --git a/cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go b/cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go index 43cf01b01..548f34e0f 100644 --- a/cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go +++ b/cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go @@ -113,13 +113,16 @@ func containerName(pod *api.PodSandbox, container *api.Container) string { return container.Name } -// Start starts the NRI plugin +// Start initializes the NRI plugin stub and starts the NRI plugin server func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error { pluginOpts := []stub.Option{ stub.WithPluginIdx(nriPluginIdx), stub.WithLogger(toNriLogger{p.logger}), stub.WithOnClose(func() { - p.logger.Infof("NRI ttrpc connection to %s is down. NRI plugin stopped.", nriSocketPath) + p.logger.Infof("NRI ttrpc connection to %s is down. NRI plugin stopped. Attempting to reconnect...", nriSocketPath) + if err := p.start(ctx); err != nil { + p.logger.Errorf("failed to restart NRI plugin: %v", err) + } }), } if len(nriSocketPath) > 0 { @@ -134,13 +137,25 @@ func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) if p.stub, err = stub.New(p, pluginOpts...); err != nil { return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err) } - err = p.stub.Start(ctx) + err = p.start(ctx) if err != nil { return fmt.Errorf("plugin exited with error: %w", err) } return nil } +// start starts the NRI plugin server +func (p *Plugin) start(ctx context.Context) error { + if p != nil { + if p.stub == nil { + p.logger.Infof("NRI plugin not initialized. Skipping plugin start") + } else { + return p.stub.Start(ctx) + } + } + return nil +} + // Stop stops the NRI plugin func (p *Plugin) Stop() { if p == nil || p.stub == nil {