diff --git a/cmd/containerd-shim-lcow-v2/containerd-shim-lcow-v2.exe.manifest b/cmd/containerd-shim-lcow-v2/containerd-shim-lcow-v2.exe.manifest
new file mode 100644
index 0000000000..9c5ba67277
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/containerd-shim-lcow-v2.exe.manifest
@@ -0,0 +1,17 @@
+
+
+ containerd-shim-lcow-v2
+
+
+
+
+
+
+
+
+
+ true
+
+
+
+
diff --git a/cmd/containerd-shim-lcow-v2/main.go b/cmd/containerd-shim-lcow-v2/main.go
new file mode 100644
index 0000000000..7bb297810f
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/main.go
@@ -0,0 +1,93 @@
+//go:build windows
+
+// containerd-shim-lcow-v2 is a containerd shim implementation for Linux Containers on Windows (LCOW).
+package main
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+
+	_ "github.com/Microsoft/hcsshim/cmd/containerd-shim-lcow-v2/service/plugin"
+	runhcsopts "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options"
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/internal/oc"
+	"github.com/Microsoft/hcsshim/internal/shim"
+	"github.com/Microsoft/hcsshim/internal/vm/vmutils"
+
+	"github.com/containerd/errdefs"
+	"github.com/sirupsen/logrus"
+	"go.opencensus.io/trace"
+)
+
+// Add a manifest to get proper Windows version detection.
+//go:generate go tool github.com/josephspurrier/goversioninfo/cmd/goversioninfo -platform-specific
+
+func main() {
+	logrus.AddHook(log.NewHook())
+
+	// Register our OpenCensus logrus exporter so that trace spans are emitted via logrus.
+	trace.ApplyConfig(trace.Config{DefaultSampler: oc.DefaultSampler})
+	trace.RegisterExporter(&oc.LogrusExporter{})
+
+	logrus.SetFormatter(log.NopFormatter{})
+	logrus.SetOutput(io.Discard)
+
+	// Set the log configuration.
+	// If we encounter an error, we exit with non-zero code.
+	if err := setLogConfiguration(); err != nil {
+		_, _ = fmt.Fprintf(os.Stderr, "%s: %s", vmutils.LCOWShimName, err)
+		os.Exit(1)
+	}
+
+	// Start the shim manager event loop. The manager is responsible for
+	// handling containerd start/stop lifecycle calls for the shim process.
+	shim.Run(context.Background(), newShimManager(vmutils.LCOWShimName), func(c *shim.Config) {
+		// We don't want the shim package to set up logging options.
+		c.NoSetupLogger = true
+	})
+}
+
+// setLogConfiguration reads the runtime options from stdin and sets the log configuration.
+// We only set up the log configuration for serve action.
+func setLogConfiguration() error {
+	// We set up the log configuration in the serve action only.
+	// This is because we want to avoid reading the stdin in start action,
+	// so that we can pass it along to the invocation for serve action.
+	if len(os.Args) > 1 && os.Args[len(os.Args)-1] == "serve" {
+		// The serve process is started with stderr pointing to panic.log file.
+		// We want to keep that file only for pure Go panics. Any explicit writes
+		// to os.Stderr should go to stdout instead, which is connected to the parent's
+		// stderr for regular logging.
+		// We can safely redirect os.Stderr to os.Stdout because in case of panics,
+		// the Go runtime will write the panic stack trace directly to the file descriptor,
+		// bypassing os.Stderr, so it will still go to panic.log.
+		os.Stderr = os.Stdout
+
+		opts, err := shim.ReadRuntimeOptions[*runhcsopts.Options](os.Stdin)
+		if err != nil {
+			if !errors.Is(err, errdefs.ErrNotFound) {
+				return fmt.Errorf("failed to read runtime options from stdin: %w", err)
+			}
+		}
+
+		if opts != nil {
+			if opts.LogLevel != "" {
+				// If log level is specified, set the corresponding logrus logging level.
+				lvl, err := logrus.ParseLevel(opts.LogLevel)
+				if err != nil {
+					return fmt.Errorf("failed to parse shim log level %q: %w", opts.LogLevel, err)
+				}
+				logrus.SetLevel(lvl)
+			}
+
+			if opts.ScrubLogs {
+				log.SetScrubbing(true)
+			}
+		}
+		_ = os.Stdin.Close()
+	}
+	return nil
+}
diff --git a/cmd/containerd-shim-lcow-v2/manager.go b/cmd/containerd-shim-lcow-v2/manager.go
new file mode 100644
index 0000000000..e0a997dcc5
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/manager.go
@@ -0,0 +1,295 @@
+//go:build windows
+
+package main
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"syscall"
+	"time"
+
+	runhcsopts "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options"
+	"github.com/Microsoft/hcsshim/internal/hcs"
+	"github.com/Microsoft/hcsshim/internal/memory"
+	"github.com/Microsoft/hcsshim/internal/oc"
+	"github.com/Microsoft/hcsshim/internal/shim"
+	hcsversion "github.com/Microsoft/hcsshim/internal/version"
+
+	"github.com/containerd/containerd/api/types"
+	"github.com/containerd/containerd/v2/pkg/namespaces"
+	"github.com/containerd/errdefs"
+	"github.com/containerd/typeurl/v2"
+	"github.com/opencontainers/runtime-spec/specs-go"
+	"github.com/sirupsen/logrus"
+	"golang.org/x/sys/windows"
+)
+
+const (
+	// addrFmt is the format of the address used for containerd shim.
+	addrFmt = "\\\\.\\pipe\\ProtectedPrefix\\Administrators\\containerd-shim-%s-%s-pipe"
+
+	// serveReadyEventNameFormat is the format string used to construct the named Windows event
+	// that signals when the child "serve" process is ready to accept ttrpc connections.
+	// It is formatted with the namespace and shim ID (e.g. "<namespace>-<id>").
+	serveReadyEventNameFormat = "%s-%s"
+)
+
+// shimManager implements the shim.Manager interface. It is the entry-point
+// used by the containerd shim runner to create and destroy shim instances.
+type shimManager struct {
+	name string
+}
+
+// Verify that shimManager implements shim.Manager interface
+var _ shim.Manager = (*shimManager)(nil)
+
+// newShimManager returns a shimManager with the given binary name.
+func newShimManager(name string) *shimManager {
+	return &shimManager{
+		name: name,
+	}
+}
+
+// newCommand builds the exec.Cmd that will be used to spawn the long-running
+// "serve" child process.
+func newCommand(ctx context.Context,
+	id,
+	containerdAddress,
+	socketAddr string,
+	stderr io.Writer,
+) (*exec.Cmd, error) {
+	ns, err := namespaces.NamespaceRequired(ctx)
+	if err != nil {
+		return nil, err
+	}
+	self, err := os.Executable()
+	if err != nil {
+		return nil, err
+	}
+	cwd, err := os.Getwd()
+	if err != nil {
+		return nil, err
+	}
+
+	args := []string{
+		"-namespace", ns,
+		"-id", id,
+		"-address", containerdAddress,
+		"-socket", socketAddr,
+		"serve",
+	}
+	cmd := exec.Command(self, args...)
+	cmd.Dir = cwd
+	// Limit Go runtime parallelism in the child to avoid excessive CPU usage.
+	cmd.Env = append(os.Environ(), "GOMAXPROCS=4")
+	// Place the child in its own process group so OS signals (e.g. Ctrl-C)
+	// sent to the parent are not automatically forwarded to the child.
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
+	}
+	cmd.Stdin = os.Stdin
+	cmd.Stdout = os.Stderr
+	cmd.Stderr = stderr
+
+	return cmd, nil
+}
+
+// Name returns the name of the shim
+func (m *shimManager) Name() string {
+	return m.name
+}
+
+// Start starts a shim instance for 'containerd-shim-lcow-v2'.
+// This shim relies on containerd's Sandbox API to start a sandbox.
+// There can be following scenarios that will launch a shim-
+//
+// 1. Containerd Sandbox Controller calls the Start command to start
+// the sandbox for the pod. All the container create requests will
+// set the SandboxID via `WithSandbox` ContainerOpts. Thereby, the
+// container create request within the pod will be routed directly to the
+// shim without calling the start command again.
+//
+// NOTE: This shim will not support routing the create request to an existing
+// shim based on annotations like `io.kubernetes.cri.sandbox-id`.
+func (m *shimManager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ shim.BootstrapParams, retErr error) {
+	// We cant write anything to stdout/stderr for this cmd.
+	logrus.SetOutput(io.Discard)
+
+	var params shim.BootstrapParams
+	params.Version = 3
+	params.Protocol = "ttrpc"
+
+	cwd, err := os.Getwd()
+	if err != nil {
+		return params, fmt.Errorf("failed to get current working directory: %w", err)
+	}
+
+	f, err := os.Create(filepath.Join(cwd, "panic.log"))
+	if err != nil {
+		return params, fmt.Errorf("failed to create panic log file: %w", err)
+	}
+	defer f.Close()
+
+	ns, err := namespaces.NamespaceRequired(ctx)
+	if err != nil {
+		return params, fmt.Errorf("failed to get namespace from context: %w", err)
+	}
+
+	// Create an event on which we will listen to know when the shim is ready to accept connections.
+	// The child serve process signals this event once its TTRPC server is fully initialized.
+	eventName, _ := windows.UTF16PtrFromString(fmt.Sprintf(serveReadyEventNameFormat, ns, id))
+
+	// Create the named event
+	handle, err := windows.CreateEvent(nil, 0, 0, eventName)
+	if err != nil {
+		return params, fmt.Errorf("failed to create event: %w", err)
+	}
+	defer func() {
+		_ = windows.CloseHandle(handle)
+	}()
+
+	// address is the named pipe address that the shim will use to serve the ttrpc service.
+	address := fmt.Sprintf(addrFmt, ns, id)
+
+	// Create the serve command.
+	cmd, err := newCommand(ctx, id, opts.Address, address, f)
+	if err != nil {
+		return params, err
+	}
+
+	if err = cmd.Start(); err != nil {
+		return params, err
+	}
+
+	defer func() {
+		if retErr != nil {
+			_ = cmd.Process.Kill()
+		}
+	}()
+
+	// Block until the child signals the event. NOTE(review): if the child exits before signaling, this waits forever — consider also waiting on the process handle.
+	_, _ = windows.WaitForSingleObject(handle, windows.INFINITE)
+
+	params.Address = address
+	return params, nil
+}
+
+// Stop tears down a running shim instance identified by id.
+// It reads and logs any panic messages written to panic.log, then tries to
+// terminate the associated HCS compute system and waits up to 30 seconds for
+// it to exit.
+func (m *shimManager) Stop(_ context.Context, id string) (resp shim.StopStatus, err error) {
+	ctx, span := oc.StartSpan(context.Background(), "delete")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	var bundlePath string
+	if opts, ok := ctx.Value(shim.OptsKey{}).(shim.Opts); ok {
+		bundlePath = opts.BundlePath
+	}
+
+	if bundlePath == "" {
+		return resp, fmt.Errorf("bundle path not found in context")
+	}
+
+	// hcsshim shim writes panic logs in the bundle directory in a file named "panic.log"
+	// log those messages (if any) on stderr so that it shows up in containerd's log.
+	// This should be done as the first thing so that we don't miss any panic logs even if
+	// something goes wrong during delete op.
+	// The file can be very large so read only first 1MB of data.
+	readLimit := int64(memory.MiB) // 1MB
+	logBytes, err := limitedRead(filepath.Join(bundlePath, "panic.log"), readLimit)
+	if err == nil && len(logBytes) > 0 {
+		if int64(len(logBytes)) == readLimit {
+			logrus.Warnf("shim panic log file %s is larger than 1MB, logging only first 1MB", filepath.Join(bundlePath, "panic.log"))
+		}
+		logrus.WithField("log", string(logBytes)).Warn("found shim panic logs during delete")
+	} else if err != nil && !errors.Is(err, os.ErrNotExist) {
+		logrus.WithError(err).Warn("failed to open shim panic log")
+	}
+
+	// Attempt to find the hcssystem for this bundle and terminate it.
+	if sys, _ := hcs.OpenComputeSystem(ctx, id); sys != nil {
+		defer sys.Close()
+		if err := sys.Terminate(ctx); err != nil {
+			fmt.Fprintf(os.Stderr, "failed to terminate '%s': %v", id, err)
+		} else {
+			ch := make(chan error, 1)
+			go func() { ch <- sys.Wait() }()
+			t := time.NewTimer(time.Second * 30)
+			select {
+			case <-t.C:
+				sys.Close()
+				return resp, fmt.Errorf("timed out waiting for '%s' to terminate", id)
+			case err := <-ch:
+				t.Stop()
+				if err != nil {
+					fmt.Fprintf(os.Stderr, "failed to wait for '%s' to terminate: %v", id, err)
+				}
+			}
+		}
+	}
+
+	resp = shim.StopStatus{
+		ExitedAt: time.Now(),
+		// 255 exit code is used by convention to indicate unknown exit reason.
+		ExitStatus: 255,
+	}
+	return resp, nil
+}
+
+// limitedRead reads at max `readLimitBytes` bytes from the file at path `filePath`. If the file has
+// more than `readLimitBytes` bytes of data then first `readLimitBytes` will be returned.
+// Read at most readLimitBytes so delete does not flood logs.
+func limitedRead(filePath string, readLimitBytes int64) ([]byte, error) {
+	f, err := os.Open(filePath)
+	if err != nil {
+		return nil, fmt.Errorf("limited read failed to open file: %s: %w", filePath, err)
+	}
+	defer f.Close()
+	fi, err := f.Stat()
+	if err != nil {
+		return []byte{}, fmt.Errorf("limited read failed during file stat: %s: %w", filePath, err)
+	}
+	if fi.Size() < readLimitBytes {
+		readLimitBytes = fi.Size()
+	}
+	buf := make([]byte, readLimitBytes)
+	// Use io.ReadFull so a short read (fewer bytes than requested without EOF) is an error
+	// instead of silently returning a buffer padded with zero bytes.
+	_, err = io.ReadFull(f, buf)
+	if err != nil {
+		return []byte{}, fmt.Errorf("limited read failed during file read: %s: %w", filePath, err)
+	}
+	return buf, nil
+}
+
+// Info returns runtime information about this shim including its name, version,
+// git commit, OCI spec version, and any runtime options decoded from optionsR.
+func (m *shimManager) Info(_ context.Context, optionsR io.Reader) (*types.RuntimeInfo, error) {
+	info := &types.RuntimeInfo{
+		Name: m.name,
+		Version: &types.RuntimeVersion{
+			Version: fmt.Sprintf("%s\ncommit: %s\nspec: %s", hcsversion.Version, hcsversion.Commit, specs.Version),
+		},
+		Annotations: nil,
+	}
+
+	opts, err := shim.ReadRuntimeOptions[*runhcsopts.Options](optionsR)
+	if err != nil {
+		if !errors.Is(err, errdefs.ErrNotFound) {
+			return nil, fmt.Errorf("failed to read runtime options (*options.Options): %w", err)
+		}
+	}
+	if opts != nil {
+		info.Options, err = typeurl.MarshalAnyToProto(opts)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal %T: %w", opts, err)
+		}
+	}
+
+	return info, nil
+}
diff --git a/cmd/containerd-shim-lcow-v2/manager_test.go b/cmd/containerd-shim-lcow-v2/manager_test.go
new file mode 100644
index 0000000000..fd93692299
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/manager_test.go
@@ -0,0 +1,46 @@
+//go:build windows
+
+package main
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+// TestLimitedRead verifies that limitedRead correctly enforces the byte limit
+// when the file is larger than the limit, and reads the full content when the
+// file is smaller than the limit.
+func TestLimitedRead(t *testing.T) {
+	dir := t.TempDir()
+	filePath := filepath.Join(dir, "panic.log")
+	content := []byte("hello")
+	if err := os.WriteFile(filePath, content, 0o644); err != nil {
+		t.Fatalf("WriteFile: %v", err)
+	}
+
+	buf, err := limitedRead(filePath, 2)
+	if err != nil {
+		t.Fatalf("limitedRead: %v", err)
+	}
+	if string(buf) != "he" {
+		t.Fatalf("expected 'he', got %q", string(buf))
+	}
+
+	buf, err = limitedRead(filePath, 10)
+	if err != nil {
+		t.Fatalf("limitedRead: %v", err)
+	}
+	if string(buf) != "hello" {
+		t.Fatalf("expected 'hello', got %q", string(buf))
+	}
+}
+
+// TestLimitedReadMissingFile verifies that limitedRead returns an error when
+// the target file does not exist.
+func TestLimitedReadMissingFile(t *testing.T) {
+	_, err := limitedRead(filepath.Join(t.TempDir(), "missing.log"), 10)
+	if err == nil {
+		t.Fatalf("expected error for missing file")
+	}
+}
diff --git a/cmd/containerd-shim-lcow-v2/resource_windows_386.syso b/cmd/containerd-shim-lcow-v2/resource_windows_386.syso
new file mode 100644
index 0000000000..5510dc97e2
Binary files /dev/null and b/cmd/containerd-shim-lcow-v2/resource_windows_386.syso differ
diff --git a/cmd/containerd-shim-lcow-v2/resource_windows_amd64.syso b/cmd/containerd-shim-lcow-v2/resource_windows_amd64.syso
new file mode 100644
index 0000000000..2c00dedb25
Binary files /dev/null and b/cmd/containerd-shim-lcow-v2/resource_windows_amd64.syso differ
diff --git a/cmd/containerd-shim-lcow-v2/resource_windows_arm.syso b/cmd/containerd-shim-lcow-v2/resource_windows_arm.syso
new file mode 100644
index 0000000000..2706f485e1
Binary files /dev/null and b/cmd/containerd-shim-lcow-v2/resource_windows_arm.syso differ
diff --git a/cmd/containerd-shim-lcow-v2/resource_windows_arm64.syso b/cmd/containerd-shim-lcow-v2/resource_windows_arm64.syso
new file mode 100644
index 0000000000..718ad2bfb8
Binary files /dev/null and
b/cmd/containerd-shim-lcow-v2/resource_windows_arm64.syso differ
diff --git a/cmd/containerd-shim-lcow-v2/service/plugin/plugin.go b/cmd/containerd-shim-lcow-v2/service/plugin/plugin.go
new file mode 100644
index 0000000000..560b8de316
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/service/plugin/plugin.go
@@ -0,0 +1,115 @@
+//go:build windows
+
+package plugin
+
+import (
+	"context"
+	"os"
+
+	"github.com/Microsoft/hcsshim/cmd/containerd-shim-lcow-v2/service"
+	"github.com/Microsoft/hcsshim/internal/shim"
+	"github.com/Microsoft/hcsshim/internal/shimdiag"
+	hcsversion "github.com/Microsoft/hcsshim/internal/version"
+
+	"github.com/Microsoft/go-winio/pkg/etw"
+	"github.com/Microsoft/go-winio/pkg/etwlogrus"
+	"github.com/Microsoft/go-winio/pkg/guid"
+	"github.com/containerd/containerd/v2/pkg/shutdown"
+	"github.com/containerd/containerd/v2/plugins"
+	"github.com/containerd/plugin"
+	"github.com/containerd/plugin/registry"
+	"github.com/sirupsen/logrus"
+)
+
+const (
+	// etwProviderName is the ETW provider name for lcow shim.
+	etwProviderName = "Microsoft.Virtualization.RunHCSLCOW"
+)
+
+// svc holds the single Service instance created during plugin initialization.
+var svc *service.Service
+
+func init() {
+	// Provider ID: 64F6FC7F-8326-5EE8-B890-3734AE584136
+	// Provider and hook aren't closed explicitly, as they will exist until process exit.
+	provider, err := etw.NewProvider(etwProviderName, etwCallback)
+	if err != nil {
+		logrus.Error(err)
+	} else {
+		if hook, err := etwlogrus.NewHookFromProvider(provider); err == nil {
+			logrus.AddHook(hook)
+		} else {
+			logrus.Error(err)
+		}
+	}
+
+	// Write the "ShimLaunched" event with the shim's command-line arguments. NOTE(review): provider may be nil if NewProvider failed above — confirm WriteEvent is nil-safe or guard this call.
+	_ = provider.WriteEvent(
+		"ShimLaunched",
+		nil,
+		etw.WithFields(
+			etw.StringArray("Args", os.Args),
+			etw.StringField("Version", hcsversion.Version),
+			etw.StringField("GitCommit", hcsversion.Commit),
+		),
+	)
+
+	// Register the shim's TTRPC plugin with the containerd plugin registry.
+	// The plugin depends on the event publisher (for publishing task/sandbox
+	// events to containerd) and the internal shutdown service (for co-ordinated
+	// graceful teardown).
+	registry.Register(&plugin.Registration{
+		Type: plugins.TTRPCPlugin,
+		ID:   "shim-services",
+		Requires: []plugin.Type{
+			plugins.EventPlugin,
+			plugins.InternalPlugin,
+		},
+		InitFn: func(ic *plugin.InitContext) (interface{}, error) {
+			pp, err := ic.GetByID(plugins.EventPlugin, "publisher")
+			if err != nil {
+				return nil, err
+			}
+			ss, err := ic.GetByID(plugins.InternalPlugin, "shutdown")
+			if err != nil {
+				return nil, err
+			}
+			// We will register all the services namely-
+			// 1. Sandbox service
+			// 2. Task service
+			// 3. Shimdiag service
+			svc = service.NewService(
+				ic.Context,
+				pp.(shim.Publisher),
+				ss.(shutdown.Service),
+			)
+
+			return svc, nil
+		},
+	})
+}
+
+// etwCallback is the ETW callback method for this shim.
+//
+// On a CaptureState notification (triggered by tools such as wpr or xperf) it
+// dumps all goroutine stacks – both host-side Go stacks and, when available,
+// the guest Linux stacks – to the logrus logger tagged with the sandbox ID.
+// This provides an out-of-band diagnostic snapshot without requiring the shim
+// to be paused or restarted.
+func etwCallback(sourceID guid.GUID, state etw.ProviderState, level etw.Level, matchAnyKeyword uint64, matchAllKeyword uint64, filterData uintptr) {
+	if state == etw.ProviderStateCaptureState {
+		if svc == nil {
+			logrus.Warn("service not initialized")
+			return
+		}
+		resp, err := svc.DiagStacks(context.Background(), &shimdiag.StacksRequest{})
+		if err != nil {
+			return
+		}
+		log := logrus.WithField("sandboxID", svc.SandboxID())
+		log.WithField("stack", resp.Stacks).Info("goroutine stack dump")
+		if resp.GuestStacks != "" {
+			log.WithField("stack", resp.GuestStacks).Info("guest stack dump")
+		}
+	}
+}
diff --git a/cmd/containerd-shim-lcow-v2/service/service.go b/cmd/containerd-shim-lcow-v2/service/service.go
new file mode 100644
index 0000000000..bd9b33cd76
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/service/service.go
@@ -0,0 +1,115 @@
+//go:build windows
+
+package service
+
+import (
+	"context"
+	"sync"
+
+	"github.com/Microsoft/hcsshim/internal/builder/vm/lcow"
+	"github.com/Microsoft/hcsshim/internal/controller/vm"
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/internal/shim"
+	"github.com/Microsoft/hcsshim/internal/shimdiag"
+
+	sandboxsvc "github.com/containerd/containerd/api/runtime/sandbox/v1"
+	tasksvc "github.com/containerd/containerd/api/runtime/task/v3"
+	"github.com/containerd/containerd/v2/core/runtime"
+	"github.com/containerd/containerd/v2/pkg/namespaces"
+	"github.com/containerd/containerd/v2/pkg/shutdown"
+	"github.com/containerd/ttrpc"
+)
+
+// Service is the shared Service struct that implements all TTRPC Service interfaces.
+// All Service methods (sandbox, task, and shimdiag) operate on this shared struct.
+type Service struct {
+	// mu is used to synchronize access to shared state within the Service.
+	mu sync.Mutex
+
+	// publisher is used to publish events from the shim to containerd.
+	publisher shim.Publisher
+	// events is a buffered channel used to queue events before they are published to containerd.
+	events chan interface{}
+
+	// sandboxID is the unique identifier for the sandbox managed by this Service instance.
+	// For LCOW shim, sandboxID corresponds 1-1 with the UtilityVM managed by the shim.
+	sandboxID string
+
+	// sandboxOptions contains parsed, shim-level configuration for the sandbox
+	// such as architecture and confidential-compute settings.
+	sandboxOptions *lcow.SandboxOptions
+
+	// vmController is responsible for managing the lifecycle of the underlying utility VM and its associated resources.
+	vmController vm.Controller
+
+	// shutdown manages graceful shutdown operations and allows registration of cleanup callbacks.
+	shutdown shutdown.Service
+}
+
+var _ shim.TTRPCService = (*Service)(nil)
+
+// NewService creates a new instance of the Service with the shared state.
+func NewService(ctx context.Context, eventsPublisher shim.Publisher, sd shutdown.Service) *Service {
+	svc := &Service{
+		publisher:    eventsPublisher,
+		events:       make(chan interface{}, 128), // Buffered channel for events
+		vmController: vm.NewController(),
+		shutdown:     sd,
+	}
+
+	go svc.forward(ctx, eventsPublisher)
+
+	// Register a shutdown callback to close the events channel,
+	// which signals the forward goroutine to exit.
+	sd.RegisterCallback(func(context.Context) error {
+		close(svc.events)
+		return nil
+	})
+
+	// Perform best-effort VM cleanup on shutdown.
+	sd.RegisterCallback(func(ctx context.Context) error {
+		_ = svc.vmController.TerminateVM(ctx)
+		return nil
+	})
+
+	return svc
+}
+
+// RegisterTTRPC registers the Task, Sandbox, and ShimDiag TTRPC services on
+// the provided server so that containerd can call into the shim over TTRPC.
+func (s *Service) RegisterTTRPC(server *ttrpc.Server) error {
+	tasksvc.RegisterTTRPCTaskService(server, s)
+	sandboxsvc.RegisterTTRPCSandboxService(server, s)
+	shimdiag.RegisterShimDiagService(server, s)
+	return nil
+}
+
+// SandboxID returns the unique identifier for the sandbox managed by this Service.
+func (s *Service) SandboxID() string {
+	return s.sandboxID
+}
+
+// send enqueues an event onto the internal events channel so that it can be
+// forwarded to containerd asynchronously by the forward goroutine.
+//
+// TODO: wire up send() for task events once task lifecycle methods are implemented. NOTE(review): a send racing the shutdown callback that closes s.events would panic — synchronize before wiring up.
+//
+//nolint:unused
+func (s *Service) send(evt interface{}) {
+	s.events <- evt
+}
+
+// forward runs in a dedicated goroutine and publishes events from the internal
+// events channel to containerd using the provided Publisher. It exits when the
+// events channel is closed (which happens during graceful shutdown).
+func (s *Service) forward(ctx context.Context, publisher shim.Publisher) {
+	ns, _ := namespaces.Namespace(ctx)
+	ctx = namespaces.WithNamespace(context.Background(), ns)
+	for e := range s.events {
+		err := publisher.Publish(ctx, runtime.GetTopic(e), e)
+		if err != nil {
+			log.G(ctx).WithError(err).Error("post event")
+		}
+	}
+	_ = publisher.Close()
+}
diff --git a/cmd/containerd-shim-lcow-v2/service/service_sandbox.go b/cmd/containerd-shim-lcow-v2/service/service_sandbox.go
new file mode 100644
index 0000000000..a82aa6b467
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/service/service_sandbox.go
@@ -0,0 +1,170 @@
+//go:build windows
+
+package service
+
+import (
+	"context"
+
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/internal/logfields"
+	"github.com/Microsoft/hcsshim/internal/oc"
+
+	"github.com/containerd/containerd/api/runtime/sandbox/v1"
+	errdefs2 "github.com/containerd/errdefs/pkg/errgrpc"
+	"github.com/sirupsen/logrus"
+	"go.opencensus.io/trace"
+)
+
+// Ensure Service implements the TTRPCSandboxService interface at compile time.
+var _ sandbox.TTRPCSandboxService = &Service{}
+
+// CreateSandbox creates (or prepares) a new sandbox for the given SandboxID.
+// This method is part of the instrumentation layer and business logic is included in createSandboxInternal.
+func (s *Service) CreateSandbox(ctx context.Context, request *sandbox.CreateSandboxRequest) (resp *sandbox.CreateSandboxResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "CreateSandbox")
+	defer span.End()
+	defer func() {
+		oc.SetSpanStatus(span, err)
+	}()
+
+	span.AddAttributes(
+		trace.StringAttribute(logfields.SandboxID, request.SandboxID),
+		trace.StringAttribute(logfields.Bundle, request.BundlePath),
+		trace.StringAttribute(logfields.NetNsPath, request.NetnsPath),
+	)
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.createSandboxInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// StartSandbox transitions a previously created sandbox to the "running" state.
+// This method is part of the instrumentation layer and business logic is included in startSandboxInternal.
+func (s *Service) StartSandbox(ctx context.Context, request *sandbox.StartSandboxRequest) (resp *sandbox.StartSandboxResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "StartSandbox")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.startSandboxInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// Platform returns the platform details for the sandbox ("windows/amd64" or "linux/amd64").
+// This method is part of the instrumentation layer and business logic is included in platformInternal.
+func (s *Service) Platform(ctx context.Context, request *sandbox.PlatformRequest) (resp *sandbox.PlatformResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "Platform")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.platformInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// StopSandbox attempts a graceful stop of the sandbox within the specified timeout.
+// This method is part of the instrumentation layer and business logic is included in stopSandboxInternal.
+func (s *Service) StopSandbox(ctx context.Context, request *sandbox.StopSandboxRequest) (resp *sandbox.StopSandboxResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "StopSandbox")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+	span.AddAttributes(trace.Int64Attribute(logfields.Timeout, int64(request.TimeoutSecs)))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.stopSandboxInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// WaitSandbox blocks until the sandbox reaches a terminal state (stopped/errored) and returns the outcome.
+// This method is part of the instrumentation layer and business logic is included in waitSandboxInternal.
+func (s *Service) WaitSandbox(ctx context.Context, request *sandbox.WaitSandboxRequest) (resp *sandbox.WaitSandboxResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "WaitSandbox")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.waitSandboxInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// SandboxStatus returns current status for the sandbox, optionally verbose.
+// This method is part of the instrumentation layer and business logic is included in sandboxStatusInternal.
+func (s *Service) SandboxStatus(ctx context.Context, request *sandbox.SandboxStatusRequest) (resp *sandbox.SandboxStatusResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "SandboxStatus")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+	span.AddAttributes(trace.BoolAttribute(logfields.Verbose, request.Verbose))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.sandboxStatusInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// PingSandbox performs a minimal liveness check on the sandbox and returns quickly.
+// This method is part of the instrumentation layer and business logic is included in pingSandboxInternal.
+func (s *Service) PingSandbox(ctx context.Context, request *sandbox.PingRequest) (resp *sandbox.PingResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "PingSandbox")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.pingSandboxInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// ShutdownSandbox requests a full shim + sandbox shutdown (stronger than StopSandbox),
+// typically used by the higher-level controller to tear down resources and exit the shim.
+// This method is part of the instrumentation layer and business logic is included in shutdownSandboxInternal.
+func (s *Service) ShutdownSandbox(ctx context.Context, request *sandbox.ShutdownSandboxRequest) (resp *sandbox.ShutdownSandboxResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "ShutdownSandbox")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.shutdownSandboxInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
+
+// SandboxMetrics returns runtime metrics for the sandbox (e.g., CPU/memory/IO),
+// suitable for monitoring and autoscaling decisions.
+// This method is part of the instrumentation layer and business logic is included in sandboxMetricsInternal.
+func (s *Service) SandboxMetrics(ctx context.Context, request *sandbox.SandboxMetricsRequest) (resp *sandbox.SandboxMetricsResponse, err error) {
+	ctx, span := oc.StartSpan(ctx, "SandboxMetrics")
+	defer span.End()
+	defer func() { oc.SetSpanStatus(span, err) }()
+
+	span.AddAttributes(trace.StringAttribute(logfields.SandboxID, request.SandboxID))
+
+	// Set the sandbox ID in the logger context for all subsequent logs in this request.
+	ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.SandboxID, request.SandboxID))
+
+	r, e := s.sandboxMetricsInternal(ctx, request)
+	return r, errdefs2.ToGRPC(e)
+}
diff --git a/cmd/containerd-shim-lcow-v2/service/service_sandbox_internal.go b/cmd/containerd-shim-lcow-v2/service/service_sandbox_internal.go
new file mode 100644
index 0000000000..364c5807d9
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/service/service_sandbox_internal.go
@@ -0,0 +1,321 @@
+//go:build windows
+
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/Microsoft/hcsshim/internal/builder/vm/lcow"
+	"github.com/Microsoft/hcsshim/internal/controller/vm"
+	"github.com/Microsoft/hcsshim/internal/gcs/prot"
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/internal/protocol/guestresource"
+	"github.com/Microsoft/hcsshim/internal/vm/vmutils"
+	vmsandbox "github.com/Microsoft/hcsshim/sandbox-spec/vm/v2"
+
+	"github.com/Microsoft/go-winio"
+	"github.com/containerd/containerd/api/runtime/sandbox/v1"
+	"github.com/containerd/containerd/api/types"
+	"github.com/containerd/errdefs"
+	"github.com/containerd/typeurl/v2"
+	"golang.org/x/sys/windows"
+	"google.golang.org/protobuf/types/known/timestamppb"
+)
+
+const (
+	// linuxPlatform refers to the Linux guest OS platform.
+	linuxPlatform = "linux"
+
+	// SandboxStateReady indicates the sandbox is ready.
+	SandboxStateReady = "SANDBOX_READY"
+	// SandboxStateNotReady indicates the sandbox is not ready.
+	SandboxStateNotReady = "SANDBOX_NOTREADY"
+)
+
+// createSandboxInternal is the implementation for CreateSandbox.
+//
+// It enforces that only one sandbox can exist per shim instance (this shim
+// follows a one-sandbox-per-shim model). It builds the HCS compute-system
+// document from the sandbox spec and delegates VM creation to vmController.
+func (s *Service) createSandboxInternal(ctx context.Context, request *sandbox.CreateSandboxRequest) (*sandbox.CreateSandboxResponse, error) {
+	// Decode the Sandbox spec passed along from CRI.
+	var sandboxSpec vmsandbox.Spec
+	f, err := os.Open(filepath.Join(request.BundlePath, "config.json"))
+	if err != nil {
+		return nil, err
+	}
+	if err := json.NewDecoder(f).Decode(&sandboxSpec); err != nil {
+		_ = f.Close()
+		return nil, err
+	}
+	_ = f.Close()
+
+	// Decode the runtime options.
+	shimOpts, err := vmutils.UnmarshalRuntimeOptions(ctx, request.Options)
+	if err != nil {
+		return nil, err
+	}
+
+	// We take a lock at this point so that if there are multiple parallel calls to CreateSandbox,
+	// only one will succeed in creating the sandbox. The successful caller will set the sandboxID,
+	// which will cause the other call(s) to fail with an error indicating that a sandbox already exists.
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.sandboxID != "" {
+		return nil, fmt.Errorf("failed to create sandbox: sandbox already exists with ID %s", s.sandboxID)
+	}
+
+	hcsDocument, sandboxOptions, err := lcow.BuildSandboxConfig(ctx, vmutils.LCOWShimName, request.BundlePath, shimOpts, &sandboxSpec)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse sandbox spec: %w", err)
+	}
+
+	s.sandboxOptions = sandboxOptions
+
+	err = s.vmController.CreateVM(ctx, &vm.CreateOptions{
+		ID:          fmt.Sprintf("%s@vm", request.SandboxID),
+		HCSDocument: hcsDocument,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create VM: %w", err)
+	}
+
+	// By setting the sandboxID here, we ensure that any parallel calls for CreateSandbox
+	// will fail with an error.
+	// Also, setting it here acts as a synchronization point - we know that if sandboxID is set,
+	// then the VM has been created successfully and sandboxOptions has been populated.
+	s.sandboxID = request.SandboxID
+
+	return &sandbox.CreateSandboxResponse{}, nil
+}
+
+// startSandboxInternal is the implementation for StartSandbox.
+//
+// It instructs the vmController to start the VM. If the
+// sandbox was created with confidential settings, confidential options are
+// applied to the VM after starting.
+func (s *Service) startSandboxInternal(ctx context.Context, request *sandbox.StartSandboxRequest) (*sandbox.StartSandboxResponse, error) {
+	if s.sandboxID != request.SandboxID {
+		return nil, fmt.Errorf("failed to start sandbox: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID)
+	}
+
+	// If we successfully got past the above check, it means the sandbox was created and
+	// the sandboxOptions should be populated.
+ var confidentialOpts *guestresource.ConfidentialOptions
+ if s.sandboxOptions != nil && s.sandboxOptions.ConfidentialConfig != nil {
+ uvmReferenceInfoEncoded, err := vmutils.ParseUVMReferenceInfo(
+ ctx,
+ vmutils.DefaultLCOWOSBootFilesPath(),
+ s.sandboxOptions.ConfidentialConfig.UvmReferenceInfoFile,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse UVM reference info: %w", err)
+ }
+ confidentialOpts = &guestresource.ConfidentialOptions{
+ EnforcerType: s.sandboxOptions.ConfidentialConfig.SecurityPolicyEnforcer,
+ EncodedSecurityPolicy: s.sandboxOptions.ConfidentialConfig.SecurityPolicy,
+ EncodedUVMReference: uvmReferenceInfoEncoded,
+ }
+ }
+
+ // The VM controller ensures that only one Start call goes through.
+ err := s.vmController.StartVM(ctx, &vm.StartOptions{
+ GCSServiceID: winio.VsockServiceID(prot.LinuxGcsVsockPort),
+ ConfidentialOptions: confidentialOpts,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to start VM: %w", err)
+ }
+
+ return &sandbox.StartSandboxResponse{
+ CreatedAt: timestamppb.New(s.vmController.StartTime()),
+ }, nil
+}
+
+// platformInternal is the implementation for Platform.
+//
+// It returns the guest OS and CPU architecture for the sandbox.
+// An error is returned if the sandbox has not yet been created
+// (i.e. the VM controller is still in StateNotCreated); all later
+// states (created, running, terminated) are accepted.
+func (s *Service) platformInternal(_ context.Context, request *sandbox.PlatformRequest) (*sandbox.PlatformResponse, error) {
+ if s.sandboxID != request.SandboxID {
+ return nil, fmt.Errorf("failed to get platform: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID)
+ }
+
+ if s.vmController.State() == vm.StateNotCreated {
+ return nil, fmt.Errorf("failed to get platform: sandbox has not been created (state: %s)", s.vmController.State())
+ }
+
+ return &sandbox.PlatformResponse{
+ Platform: &types.Platform{
+ OS: linuxPlatform,
+ Architecture: s.sandboxOptions.Architecture,
+ },
+ }, nil
+}
+
+// stopSandboxInternal is the implementation for StopSandbox.
+// +// It terminates the VM and performs any cleanup, if needed. +func (s *Service) stopSandboxInternal(ctx context.Context, request *sandbox.StopSandboxRequest) (*sandbox.StopSandboxResponse, error) { + if s.sandboxID != request.SandboxID { + return nil, fmt.Errorf("failed to stop sandbox: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID) + } + + err := s.vmController.TerminateVM(ctx) + if err != nil { + return nil, fmt.Errorf("failed to terminate VM: %w", err) + } + + return &sandbox.StopSandboxResponse{}, nil +} + +// waitSandboxInternal is the implementation for WaitSandbox. +// +// It blocks until the underlying VM has been terminated, then maps the exit status +// to a sandbox exit code. +func (s *Service) waitSandboxInternal(ctx context.Context, request *sandbox.WaitSandboxRequest) (*sandbox.WaitSandboxResponse, error) { + if s.sandboxID != request.SandboxID { + return nil, fmt.Errorf("failed to wait for sandbox: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID) + } + + // Wait for the VM to be terminated, then return the exit code. + // This is a blocking call that will wait until the VM is stopped. + err := s.vmController.Wait(ctx) + if err != nil { + return nil, fmt.Errorf("failed to wait for VM: %w", err) + } + + exitStatus, err := s.vmController.ExitStatus() + if err != nil { + return nil, fmt.Errorf("failed to get sandbox exit status: %w", err) + } + + exitStatusCode := 0 + // If there was an exit error, set a non-zero exit status. + if exitStatus.Err != nil { + exitStatusCode = int(windows.ERROR_INTERNAL_ERROR) + } + + return &sandbox.WaitSandboxResponse{ + ExitStatus: uint32(exitStatusCode), + ExitedAt: timestamppb.New(exitStatus.StoppedTime), + }, nil +} + +// sandboxStatusInternal is the implementation for SandboxStatus. +// +// It synthesizes a status response from the current vmController state. +// When verbose is true, the response may be extended with additional +// diagnostic information. 
+func (s *Service) sandboxStatusInternal(_ context.Context, request *sandbox.SandboxStatusRequest) (*sandbox.SandboxStatusResponse, error) { + if s.sandboxID != request.SandboxID { + return nil, fmt.Errorf("failed to get sandbox status: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID) + } + + resp := &sandbox.SandboxStatusResponse{ + SandboxID: request.SandboxID, + } + + switch vmState := s.vmController.State(); vmState { + case vm.StateNotCreated, vm.StateCreated, vm.StateInvalid: + // VM has not started yet or is in invalid state; return the default not-ready response. + resp.State = SandboxStateNotReady + return resp, nil + case vm.StateRunning: + // VM is running, so we can report the created time and ready state. + resp.State = SandboxStateReady + resp.CreatedAt = timestamppb.New(s.vmController.StartTime()) + case vm.StateTerminated: + // VM has stopped, so we can report the created time, exited time, and not-ready state. + resp.State = SandboxStateNotReady + resp.CreatedAt = timestamppb.New(s.vmController.StartTime()) + stoppedStatus, err := s.vmController.ExitStatus() + if err != nil { + return nil, fmt.Errorf("failed to get sandbox stopped status: %w", err) + } + resp.ExitedAt = timestamppb.New(stoppedStatus.StoppedTime) + } + + if request.Verbose { //nolint:staticcheck + // TODO: Add compat info and any other details. + } + + return resp, nil +} + +// pingSandboxInternal is the implementation for PingSandbox. +// +// Ping is not yet implemented for this shim. +func (s *Service) pingSandboxInternal(_ context.Context, _ *sandbox.PingRequest) (*sandbox.PingResponse, error) { + // This functionality is not yet applicable for this shim. + // Best scenario, we can return true if the VM is running. + return nil, errdefs.ErrNotImplemented +} + +// shutdownSandboxInternal is used to trigger sandbox shutdown when the shim receives +// a shutdown request from containerd. 
+//
+// Shutdown is accepted in any VM state: if the VM is not yet terminated,
+// a best-effort TerminateVM is attempted (errors are logged, not returned)
+// before the shim itself is shut down asynchronously.
+func (s *Service) shutdownSandboxInternal(ctx context.Context, request *sandbox.ShutdownSandboxRequest) (*sandbox.ShutdownSandboxResponse, error) {
+ if s.sandboxID != request.SandboxID {
+ // Return nil on error, consistent with the other internal methods.
+ return nil, fmt.Errorf("failed to shutdown sandbox: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID)
+ }
+
+ // Ensure the VM is terminated. If the VM is already terminated,
+ // TerminateVM is a no-op, so this is safe to call regardless of the current VM state.
+ if state := s.vmController.State(); state != vm.StateTerminated {
+ err := s.vmController.TerminateVM(ctx)
+ if err != nil {
+ // Just log the error instead of returning it since this is a best effort cleanup.
+ log.G(ctx).WithError(err).Error("failed to terminate VM during shutdown")
+ }
+ }
+
+ // With gRPC/TTRPC, the transport later creates a child context for each incoming request,
+ // and cancels that context when the handler returns or the client-side connection is dropped.
+ // For the shutdown request, if we call shutdown.Shutdown() directly, the shim process exits
+ // prior to the response being sent back to containerd, which causes the shutdown call to fail.
+ // Therefore, use a goroutine to wait for the RPC context to be done after which
+ // we can safely call shutdown.Shutdown() without risking an early process exit.
+ go func() {
+ <-ctx.Done()
+ time.Sleep(20 * time.Millisecond) // tiny cushion to avoid edge races
+
+ s.shutdown.Shutdown()
+ }()
+
+ return &sandbox.ShutdownSandboxResponse{}, nil
+}
+
+// sandboxMetricsInternal is the implementation for SandboxMetrics.
+//
+// It collects and returns runtime statistics from the vmController.
+func (s *Service) sandboxMetricsInternal(ctx context.Context, request *sandbox.SandboxMetricsRequest) (*sandbox.SandboxMetricsResponse, error) {
+ if s.sandboxID != request.SandboxID {
+ // Return nil on error, consistent with the other internal methods.
+ return nil, fmt.Errorf("failed to get sandbox metrics: sandbox ID mismatch, expected %s, got %s", s.sandboxID, request.SandboxID)
+ }
+
+ stats, err := s.vmController.Stats(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get sandbox metrics: %w", err)
+ }
+
+ anyStat, err := typeurl.MarshalAny(stats)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal sandbox metrics: %w", err)
+ }
+
+ return &sandbox.SandboxMetricsResponse{
+ Metrics: &types.Metric{
+ Timestamp: timestamppb.Now(),
+ ID: request.SandboxID,
+ Data: typeurl.MarshalProto(anyStat),
+ },
+ }, nil
+}
diff --git a/cmd/containerd-shim-lcow-v2/service/service_shimdiag.go b/cmd/containerd-shim-lcow-v2/service/service_shimdiag.go
new file mode 100644
index 0000000000..503982d59e
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/service/service_shimdiag.go
@@ -0,0 +1,97 @@
+//go:build windows
+
+package service
+
+import (
+ "context"
+ "os"
+ "strings"
+
+ "github.com/Microsoft/hcsshim/internal/logfields"
+ "github.com/Microsoft/hcsshim/internal/oc"
+ "github.com/Microsoft/hcsshim/internal/shimdiag"
+
+ "github.com/containerd/errdefs/pkg/errgrpc"
+ "go.opencensus.io/trace"
+)
+
+// Ensure Service implements the ShimDiagService interface at compile time.
+var _ shimdiag.ShimDiagService = &Service{}
+
+// DiagExecInHost executes a process in the host namespace for diagnostic purposes.
+// This method is part of the instrumentation layer and business logic is included in diagExecInHostInternal.
+func (s *Service) DiagExecInHost(ctx context.Context, request *shimdiag.ExecProcessRequest) (resp *shimdiag.ExecProcessResponse, err error) { + ctx, span := oc.StartSpan(ctx, "DiagExecInHost") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.Args, strings.Join(request.Args, " ")), + trace.StringAttribute(logfields.Workdir, request.Workdir), + trace.BoolAttribute(logfields.Terminal, request.Terminal), + trace.StringAttribute(logfields.Stdin, request.Stdin), + trace.StringAttribute(logfields.Stdout, request.Stdout), + trace.StringAttribute(logfields.Stderr, request.Stderr)) + + r, e := s.diagExecInHostInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// DiagTasks returns information about all tasks in the shim. +// This method is part of the instrumentation layer and business logic is included in diagTasksInternal. +func (s *Service) DiagTasks(ctx context.Context, request *shimdiag.TasksRequest) (resp *shimdiag.TasksResponse, err error) { + ctx, span := oc.StartSpan(ctx, "DiagTasks") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.BoolAttribute(logfields.Execs, request.Execs)) + + r, e := s.diagTasksInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// DiagShare shares a directory from the host into the sandbox. +// This method is part of the instrumentation layer and business logic is included in diagShareInternal. 
+func (s *Service) DiagShare(ctx context.Context, request *shimdiag.ShareRequest) (resp *shimdiag.ShareResponse, err error) { + ctx, span := oc.StartSpan(ctx, "DiagShare") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.HostPath, request.HostPath), + trace.StringAttribute(logfields.UVMPath, request.UvmPath), + trace.BoolAttribute(logfields.ReadOnly, request.ReadOnly)) + + r, e := s.diagShareInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// DiagStacks returns the stack traces of all goroutines in the shim. +// This method is part of the instrumentation layer and business logic is included in diagStacksInternal. +func (s *Service) DiagStacks(ctx context.Context, request *shimdiag.StacksRequest) (resp *shimdiag.StacksResponse, err error) { + ctx, span := oc.StartSpan(ctx, "DiagStacks") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes(trace.StringAttribute(logfields.SandboxID, s.sandboxID)) + + r, e := s.diagStacksInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// DiagPid returns the process ID (PID) of the shim for diagnostic purposes. 
+func (s *Service) DiagPid(ctx context.Context, _ *shimdiag.PidRequest) (resp *shimdiag.PidResponse, err error) { + _, span := oc.StartSpan(ctx, "DiagPid") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes(trace.StringAttribute(logfields.SandboxID, s.sandboxID)) + + return &shimdiag.PidResponse{ + Pid: int32(os.Getpid()), + }, nil +} diff --git a/cmd/containerd-shim-lcow-v2/service/service_shimdiag_internal.go b/cmd/containerd-shim-lcow-v2/service/service_shimdiag_internal.go new file mode 100644 index 0000000000..a835ade320 --- /dev/null +++ b/cmd/containerd-shim-lcow-v2/service/service_shimdiag_internal.go @@ -0,0 +1,35 @@ +//go:build windows + +package service + +import ( + "context" + "fmt" + + "github.com/Microsoft/hcsshim/internal/shimdiag" + "github.com/containerd/errdefs" +) + +// diagExecInHostInternal is the implementation for DiagExecInHost. +// +// It is used to create an exec session into the hosting UVM. +func (s *Service) diagExecInHostInternal(ctx context.Context, request *shimdiag.ExecProcessRequest) (*shimdiag.ExecProcessResponse, error) { + ec, err := s.vmController.ExecIntoHost(ctx, request) + if err != nil { + return nil, fmt.Errorf("failed to exec into host: %w", err) + } + + return &shimdiag.ExecProcessResponse{ExitCode: int32(ec)}, nil +} + +func (s *Service) diagTasksInternal(_ context.Context, _ *shimdiag.TasksRequest) (*shimdiag.TasksResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) diagShareInternal(_ context.Context, _ *shimdiag.ShareRequest) (*shimdiag.ShareResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) diagStacksInternal(_ context.Context, _ *shimdiag.StacksRequest) (*shimdiag.StacksResponse, error) { + return nil, errdefs.ErrNotImplemented +} diff --git a/cmd/containerd-shim-lcow-v2/service/service_task.go b/cmd/containerd-shim-lcow-v2/service/service_task.go new file mode 100644 index 0000000000..f7f7dda5af --- /dev/null +++ 
b/cmd/containerd-shim-lcow-v2/service/service_task.go @@ -0,0 +1,339 @@ +//go:build windows + +package service + +import ( + "context" + + "github.com/Microsoft/hcsshim/internal/logfields" + "github.com/Microsoft/hcsshim/internal/oc" + + "github.com/containerd/containerd/api/runtime/task/v3" + "github.com/containerd/errdefs/pkg/errgrpc" + "go.opencensus.io/trace" + "google.golang.org/protobuf/types/known/emptypb" +) + +// Ensure Service implements the TTRPCTaskService interface at compile time. +var _ task.TTRPCTaskService = &Service{} + +// State returns the current state of a task or process. +// This method is part of the instrumentation layer and business logic is included in stateInternal. +func (s *Service) State(ctx context.Context, request *task.StateRequest) (resp *task.StateResponse, err error) { + ctx, span := oc.StartSpan(ctx, "State") + defer span.End() + defer func() { + if resp != nil { + span.AddAttributes( + trace.StringAttribute(logfields.Status, resp.Status.String()), + trace.Int64Attribute(logfields.ExitStatus, int64(resp.ExitStatus)), + trace.StringAttribute(logfields.ExitedAt, resp.ExitedAt.String())) + } + oc.SetSpanStatus(span, err) + }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID)) + + r, e := s.stateInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Create creates a new task. +// This method is part of the instrumentation layer and business logic is included in createInternal. 
+func (s *Service) Create(ctx context.Context, request *task.CreateTaskRequest) (resp *task.CreateTaskResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Create") + defer span.End() + defer func() { + if resp != nil { + span.AddAttributes(trace.Int64Attribute(logfields.ProcessID, int64(resp.Pid))) + } + oc.SetSpanStatus(span, err) + }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.Bundle, request.Bundle), + trace.BoolAttribute(logfields.Terminal, request.Terminal), + trace.StringAttribute(logfields.Stdin, request.Stdin), + trace.StringAttribute(logfields.Stdout, request.Stdout), + trace.StringAttribute(logfields.Stderr, request.Stderr), + trace.StringAttribute(logfields.Checkpoint, request.Checkpoint), + trace.StringAttribute(logfields.ParentCheckpoint, request.ParentCheckpoint)) + + r, e := s.createInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Start starts a previously created task. +// This method is part of the instrumentation layer and business logic is included in startInternal. +func (s *Service) Start(ctx context.Context, request *task.StartRequest) (resp *task.StartResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Start") + defer span.End() + defer func() { + if resp != nil { + span.AddAttributes(trace.Int64Attribute(logfields.ProcessID, int64(resp.Pid))) + } + oc.SetSpanStatus(span, err) + }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID)) + + r, e := s.startInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Delete deletes a task and returns its exit status. +// This method is part of the instrumentation layer and business logic is included in deleteInternal. 
+func (s *Service) Delete(ctx context.Context, request *task.DeleteRequest) (resp *task.DeleteResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Delete") + defer span.End() + defer func() { + if resp != nil { + span.AddAttributes( + trace.Int64Attribute(logfields.ProcessID, int64(resp.Pid)), + trace.Int64Attribute(logfields.ExitStatus, int64(resp.ExitStatus)), + trace.StringAttribute(logfields.ExitedAt, resp.ExitedAt.String())) + } + oc.SetSpanStatus(span, err) + }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID)) + + r, e := s.deleteInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Pids returns all process IDs for a task. +// This method is part of the instrumentation layer and business logic is included in pidsInternal. +func (s *Service) Pids(ctx context.Context, request *task.PidsRequest) (resp *task.PidsResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Pids") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.pidsInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Pause pauses a task. +// This method is part of the instrumentation layer and business logic is included in pauseInternal. +func (s *Service) Pause(ctx context.Context, request *task.PauseRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Pause") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.pauseInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Resume resumes a previously paused task. 
+// This method is part of the instrumentation layer and business logic is included in resumeInternal. +func (s *Service) Resume(ctx context.Context, request *task.ResumeRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Resume") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.resumeInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Checkpoint creates a checkpoint of a task. +// This method is part of the instrumentation layer and business logic is included in checkpointInternal. +func (s *Service) Checkpoint(ctx context.Context, request *task.CheckpointTaskRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Checkpoint") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.Path, request.Path)) + + r, e := s.checkpointInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Kill sends a signal to a task or process. +// This method is part of the instrumentation layer and business logic is included in killInternal. 
+func (s *Service) Kill(ctx context.Context, request *task.KillRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Kill") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID), + trace.Int64Attribute(logfields.Signal, int64(request.Signal)), + trace.BoolAttribute(logfields.All, request.All)) + + r, e := s.killInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Exec executes an additional process inside a task. +// This method is part of the instrumentation layer and business logic is included in execInternal. +func (s *Service) Exec(ctx context.Context, request *task.ExecProcessRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Exec") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID), + trace.BoolAttribute(logfields.Terminal, request.Terminal), + trace.StringAttribute(logfields.Stdin, request.Stdin), + trace.StringAttribute(logfields.Stdout, request.Stdout), + trace.StringAttribute(logfields.Stderr, request.Stderr)) + + r, e := s.execInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// ResizePty resizes the terminal of a process. +// This method is part of the instrumentation layer and business logic is included in resizePtyInternal. 
+func (s *Service) ResizePty(ctx context.Context, request *task.ResizePtyRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "ResizePty") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID), + trace.Int64Attribute(logfields.Width, int64(request.Width)), + trace.Int64Attribute(logfields.Height, int64(request.Height))) + + r, e := s.resizePtyInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// CloseIO closes the IO for a process. +// This method is part of the instrumentation layer and business logic is included in closeIOInternal. +func (s *Service) CloseIO(ctx context.Context, request *task.CloseIORequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "CloseIO") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID), + trace.BoolAttribute(logfields.Stdin, request.Stdin)) + + r, e := s.closeIOInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Update updates a running task with new resource constraints. +// This method is part of the instrumentation layer and business logic is included in updateInternal. +func (s *Service) Update(ctx context.Context, request *task.UpdateTaskRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Update") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.updateInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Wait waits for a task or process to exit. 
+// This method is part of the instrumentation layer and business logic is included in waitInternal. +func (s *Service) Wait(ctx context.Context, request *task.WaitRequest) (resp *task.WaitResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Wait") + defer span.End() + defer func() { + if resp != nil { + span.AddAttributes( + trace.Int64Attribute(logfields.ExitStatus, int64(resp.ExitStatus)), + trace.StringAttribute(logfields.ExitedAt, resp.ExitedAt.String())) + } + oc.SetSpanStatus(span, err) + }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID), + trace.StringAttribute(logfields.ExecID, request.ExecID)) + + r, e := s.waitInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Stats returns resource usage statistics for a task. +// This method is part of the instrumentation layer and business logic is included in statsInternal. +func (s *Service) Stats(ctx context.Context, request *task.StatsRequest) (resp *task.StatsResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Stats") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.statsInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Connect reconnects to a running task. +// This method is part of the instrumentation layer and business logic is included in connectInternal. 
+func (s *Service) Connect(ctx context.Context, request *task.ConnectRequest) (resp *task.ConnectResponse, err error) { + ctx, span := oc.StartSpan(ctx, "Connect") + defer span.End() + defer func() { + if resp != nil { + span.AddAttributes( + trace.Int64Attribute(logfields.ShimPid, int64(resp.ShimPid)), + trace.Int64Attribute(logfields.TaskPid, int64(resp.TaskPid)), + trace.StringAttribute(logfields.Version, resp.Version)) + } + oc.SetSpanStatus(span, err) + }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.connectInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} + +// Shutdown gracefully shuts down the Service. +// This method is part of the instrumentation layer and business logic is included in shutdownInternal. +func (s *Service) Shutdown(ctx context.Context, request *task.ShutdownRequest) (resp *emptypb.Empty, err error) { + ctx, span := oc.StartSpan(ctx, "Shutdown") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.StringAttribute(logfields.SandboxID, s.sandboxID), + trace.StringAttribute(logfields.ID, request.ID)) + + r, e := s.shutdownInternal(ctx, request) + return r, errgrpc.ToGRPC(e) +} diff --git a/cmd/containerd-shim-lcow-v2/service/service_task_internal.go b/cmd/containerd-shim-lcow-v2/service/service_task_internal.go new file mode 100644 index 0000000000..254199873b --- /dev/null +++ b/cmd/containerd-shim-lcow-v2/service/service_task_internal.go @@ -0,0 +1,79 @@ +//go:build windows + +package service + +import ( + "context" + + "github.com/containerd/containerd/api/runtime/task/v3" + "github.com/containerd/errdefs" + "google.golang.org/protobuf/types/known/emptypb" +) + +func (s *Service) stateInternal(_ context.Context, _ *task.StateRequest) (*task.StateResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) createInternal(_ context.Context, _ *task.CreateTaskRequest) 
(*task.CreateTaskResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) startInternal(_ context.Context, _ *task.StartRequest) (*task.StartResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) deleteInternal(_ context.Context, _ *task.DeleteRequest) (*task.DeleteResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) pidsInternal(_ context.Context, _ *task.PidsRequest) (*task.PidsResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) pauseInternal(_ context.Context, _ *task.PauseRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) resumeInternal(_ context.Context, _ *task.ResumeRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) checkpointInternal(_ context.Context, _ *task.CheckpointTaskRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) killInternal(_ context.Context, _ *task.KillRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) execInternal(_ context.Context, _ *task.ExecProcessRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) resizePtyInternal(_ context.Context, _ *task.ResizePtyRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) closeIOInternal(_ context.Context, _ *task.CloseIORequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) updateInternal(_ context.Context, _ *task.UpdateTaskRequest) (*emptypb.Empty, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) waitInternal(_ context.Context, _ *task.WaitRequest) (*task.WaitResponse, error) { + return nil, errdefs.ErrNotImplemented +} + +func (s *Service) statsInternal(_ context.Context, _ *task.StatsRequest) (*task.StatsResponse, error) { + return nil, 
errdefs.ErrNotImplemented
+}
+
+func (s *Service) connectInternal(_ context.Context, _ *task.ConnectRequest) (*task.ConnectResponse, error) {
+	return nil, errdefs.ErrNotImplemented
+}
+
+func (s *Service) shutdownInternal(_ context.Context, _ *task.ShutdownRequest) (*emptypb.Empty, error) {
+	return nil, errdefs.ErrNotImplemented
+}
diff --git a/cmd/containerd-shim-lcow-v2/versioninfo.json b/cmd/containerd-shim-lcow-v2/versioninfo.json
new file mode 100644
index 0000000000..11316902d5
--- /dev/null
+++ b/cmd/containerd-shim-lcow-v2/versioninfo.json
@@ -0,0 +1,44 @@
+{
+    "FixedFileInfo": {
+        "FileVersion": {
+            "Major": 1,
+            "Minor": 0,
+            "Patch": 0,
+            "Build": 0
+        },
+        "ProductVersion": {
+            "Major": 1,
+            "Minor": 0,
+            "Patch": 0,
+            "Build": 0
+        },
+        "FileFlagsMask": "3f",
+        "FileFlags": "00",
+        "FileOS": "040004",
+        "FileType": "01",
+        "FileSubType": "00"
+    },
+    "StringFileInfo": {
+        "Comments": "",
+        "CompanyName": "Microsoft",
+        "FileDescription": "",
+        "FileVersion": "",
+        "InternalName": "",
+        "LegalCopyright": "",
+        "LegalTrademarks": "",
+        "OriginalFilename": "containerd-shim-lcow-v2.exe",
+        "PrivateBuild": "",
+        "ProductName": "lcow shim",
+        "ProductVersion": "v1.0.0.0",
+        "SpecialBuild": ""
+    },
+    "VarFileInfo": {
+        "Translation": {
+            "LangID": "0409",
+            "CharsetID": "04B0"
+        }
+    },
+    "IconPath": "",
+    "ManifestPath": "containerd-shim-lcow-v2.exe.manifest"
+}
+
diff --git a/internal/controller/vm/doc.go b/internal/controller/vm/doc.go
new file mode 100644
index 0000000000..304e117157
--- /dev/null
+++ b/internal/controller/vm/doc.go
@@ -0,0 +1,76 @@
+//go:build windows
+
+// Package vm provides a controller for managing the lifecycle of a Utility VM (UVM).
+//
+// A Utility VM is a lightweight virtual machine used to host Linux (LCOW) or
+// Windows (WCOW) containers. This package abstracts the VM lifecycle —
+// creation, startup, stats collection, and termination — behind the [Controller]
+// interface, with [Manager] as the primary implementation.
//
// # Lifecycle
//
// A VM progresses through the states below; see [State] for the full
// transition table.
//
//	StateNotCreated ── CreateVM ok ──────────────────► StateCreated
//	StateCreated ───── StartVM ok ───────────────────► StateRunning
//	StateCreated ───── TerminateVM ok ───────────────► StateTerminated
//	StateRunning ───── VM exits / TerminateVM ok ────► StateTerminated
//	StateCreated ───── StartVM fails ────────────────► StateInvalid
//	StateCreated/StateRunning ── TerminateVM fails ──► StateInvalid
//	StateInvalid ───── TerminateVM ok ───────────────► StateTerminated
//
// State descriptions:
//
//   - [StateNotCreated]: initial state after [NewController] is called.
//   - [StateCreated]: after [Controller.CreateVM] succeeds; the VM exists but has not started.
//   - [StateRunning]: after [Controller.StartVM] succeeds; the guest OS is up and the
//     Guest Compute Service (GCS) connection is established.
//   - [StateTerminated]: terminal state reached after the VM exits naturally or
//     [Controller.TerminateVM] completes successfully.
//   - [StateInvalid]: error state entered when [Controller.StartVM] fails after the underlying
//     HCS VM has already started, or when [Controller.TerminateVM] fails during uvm.Close
//     (from either [StateCreated] or [StateRunning]).
//     A VM in this state can only be cleaned up by calling [Controller.TerminateVM].
//
// # Platform Variants
//
// Certain behaviors differ between LCOW and WCOW guests and are implemented in
// platform-specific source files selected via build tags (the default build is
// for the LCOW shim; the "wcow" tag selects the WCOW shim).
//
// # Usage
//
//	ctrl := vm.NewController()
//
//	if err := ctrl.CreateVM(ctx, &vm.CreateOptions{
//		ID:          "my-uvm",
//		HCSDocument: doc,
//	}); err != nil {
//		// handle error
//	}
//
//	if err := ctrl.StartVM(ctx, &vm.StartOptions{
//		GCSServiceID: serviceGUID,
//	}); err != nil {
//		// handle error
//	}
//
//	// ... use ctrl.Guest() for guest interactions ...
//
//	if err := ctrl.TerminateVM(ctx); err != nil {
//		// handle error
//	}
//
//	_ = ctrl.Wait(ctx)
package vm

//go:build windows

package vm

import (
	"github.com/Microsoft/hcsshim/internal/vm/vmmanager"
)

// SetManagerForTest configures a Manager with injected dependencies for testing.
// This is only available in test builds.
//
// MUST be called during test setup before any concurrent operations on the Manager.
+func (c *Manager) SetManagerForTest(uvm vmmanager.LifetimeManager, guest GuestManager, state State) { + c.uvm = uvm + c.guest = guest + c.vmState = state +} diff --git a/internal/controller/vm/interface.go b/internal/controller/vm/interface.go new file mode 100644 index 0000000000..ce5e618b02 --- /dev/null +++ b/internal/controller/vm/interface.go @@ -0,0 +1,99 @@ +//go:build windows + +package vm + +//go:generate go tool mockgen -source=interface.go -build_constraint=windows -package=mockvmcontroller -destination=../../test/mock/vmcontroller/mock_interface.go + +import ( + "context" + "time" + + "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/shimdiag" + "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + + "github.com/Microsoft/go-winio/pkg/guid" +) + +// GuestManager defines the guest operations required by the VM Controller. +// It combines the core guest manager, security policy, and HvSocket interfaces +// from the guestmanager package into a single interface for dependency injection. +type GuestManager interface { + guestmanager.Manager + guestmanager.SecurityPolicyManager + guestmanager.HVSocketManager +} + +type Controller interface { + // Guest returns the guest manager instance for this VM. + Guest() GuestManager + + // State returns the current VM state. + State() State + + // CreateVM creates and initializes a new VM with the specified options. + // This prepares the VM but does not start it. + CreateVM(ctx context.Context, opts *CreateOptions) error + + // StartVM starts the created VM with the specified options. + // This establishes the guest connection, sets up necessary listeners for + // guest-host communication, and transitions the VM to StateRunning. 
+ StartVM(context.Context, *StartOptions) error + + // ExecIntoHost executes a command in the running UVM. + ExecIntoHost(ctx context.Context, request *shimdiag.ExecProcessRequest) (int, error) + + // DumpStacks dumps the GCS stacks associated with the VM. + DumpStacks(ctx context.Context) (string, error) + + // Wait blocks until the VM exits or the context is cancelled. + // It also waits for log output processing to complete. + Wait(ctx context.Context) error + + Stats(ctx context.Context) (*stats.VirtualMachineStatistics, error) + + TerminateVM(context.Context) error + + // StartTime returns the timestamp when the VM was started. + // Returns zero value of time.time, if the VM is not in StateRunning or StateTerminated. + StartTime() time.Time + + // ExitStatus returns information about the stopped VM, including when it + // stopped and any exit error. Returns an error if the VM is not in StateTerminated. + ExitStatus() (*ExitStatus, error) +} + +// CreateOptions contains the configuration needed to create a new VM. +type CreateOptions struct { + // ID specifies the unique identifier for the VM. + ID string + + // HCSDocument specifies the HCS schema document used to create the VM. + HCSDocument *hcsschema.ComputeSystem +} + +// StartOptions contains the configuration needed to start a VM and establish +// the Guest Compute Service (GCS) connection. +type StartOptions struct { + // GCSServiceID specifies the GUID for the GCS vsock service. + GCSServiceID guid.GUID + + // ConfigOptions specifies additional configuration options for the guest config. + ConfigOptions []guestmanager.ConfigOption + + // ConfidentialOptions specifies security policy and confidential computing + // options for the VM. This is optional and only used for confidential VMs. + ConfidentialOptions *guestresource.ConfidentialOptions +} + +// ExitStatus contains information about a stopped VM's final state. +type ExitStatus struct { + // StoppedTime is the timestamp when the VM stopped. 
+ StoppedTime time.Time + + // Err is the error that caused the VM to stop, if any. + // This will be nil if the VM exited cleanly. + Err error +} diff --git a/internal/controller/vm/state.go b/internal/controller/vm/state.go new file mode 100644 index 0000000000..6e98eb4ae1 --- /dev/null +++ b/internal/controller/vm/state.go @@ -0,0 +1,78 @@ +//go:build windows + +package vm + +// State represents the current state of the VM lifecycle. +// +// The normal progression is: +// +// StateNotCreated → StateCreated → StateRunning → StateTerminated +// +// If an unrecoverable error occurs during [Controller.StartVM] or +// [Controller.TerminateVM], the VM transitions to [StateInvalid] instead. +// A VM in [StateInvalid] can only be cleaned up via [Controller.TerminateVM]. +// +// Full state-transition table: +// +// Current State │ Trigger │ Next State +// ─────────────────┼────────────────────────────────────┼───────────────── +// StateNotCreated │ CreateVM succeeds │ StateCreated +// StateCreated │ StartVM succeeds │ StateRunning +// StateCreated │ TerminateVM succeeds │ StateTerminated +// StateCreated │ StartVM fails │ StateInvalid +// StateCreated │ TerminateVM fails │ StateInvalid +// StateRunning │ VM exits or TerminateVM succeeds │ StateTerminated +// StateRunning │ TerminateVM fails (uvm.Close) │ StateInvalid +// StateInvalid │ TerminateVM called │ StateTerminated +// StateTerminated │ (terminal — no further transitions)│ — +type State int32 + +const ( + // StateNotCreated indicates the VM has not been created yet. + // This is the initial state when a Controller is first instantiated via [NewController]. + // Valid transitions: StateNotCreated → StateCreated (via [Controller.CreateVM]) + StateNotCreated State = iota + + // StateCreated indicates the VM has been created but not yet started. 
+ // Valid transitions: + // - StateCreated → StateRunning (via [Controller.StartVM], on success) + // - StateCreated → StateTerminated (via [Controller.TerminateVM], on success) + // - StateCreated → StateInvalid (via [Controller.StartVM], on failure) + StateCreated + + // StateRunning indicates the VM has been started and is running. + // The guest OS is up and the Guest Compute Service (GCS) connection is established. + // Valid transitions: + // - StateRunning → StateTerminated (VM exits naturally or [Controller.TerminateVM] succeeds) + // - StateRunning → StateInvalid ([Controller.TerminateVM] fails during uvm.Close) + StateRunning + + // StateTerminated indicates the VM has exited or been successfully terminated. + // This is a terminal state — once reached, no further state transitions are possible. + StateTerminated + + // StateInvalid indicates that an unrecoverable error has occurred. + // The VM transitions to this state when: + // - [Controller.StartVM] fails after the underlying HCS VM has already started, or + // - [Controller.TerminateVM] fails during uvm.Close (from either [StateCreated] or [StateRunning]). + // A VM in this state can only be cleaned up by calling [Controller.TerminateVM]. + StateInvalid +) + +// String returns a human-readable string representation of the VM State. 
+func (s State) String() string { + switch s { + case StateNotCreated: + return "NotCreated" + case StateCreated: + return "Created" + case StateRunning: + return "Running" + case StateTerminated: + return "Terminated" + case StateInvalid: + return "Invalid" + default: + return "Unknown" + } +} diff --git a/internal/controller/vm/state_test.go b/internal/controller/vm/state_test.go new file mode 100644 index 0000000000..507b3c3448 --- /dev/null +++ b/internal/controller/vm/state_test.go @@ -0,0 +1,33 @@ +//go:build windows + +package vm_test + +import ( + "testing" + + vm "github.com/Microsoft/hcsshim/internal/controller/vm" +) + +func TestStateString(t *testing.T) { + tests := []struct { + state vm.State + want string + }{ + {vm.StateNotCreated, "NotCreated"}, + {vm.StateCreated, "Created"}, + {vm.StateRunning, "Running"}, + {vm.StateTerminated, "Terminated"}, + {vm.StateInvalid, "Invalid"}, + {vm.State(99), "Unknown"}, + {vm.State(-1), "Unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + t.Parallel() + if got := tt.state.String(); got != tt.want { + t.Errorf("State(%d).String() = %q, want %q", tt.state, got, tt.want) + } + }) + } +} diff --git a/internal/controller/vm/vm.go b/internal/controller/vm/vm.go new file mode 100644 index 0000000000..18e7fdf25a --- /dev/null +++ b/internal/controller/vm/vm.go @@ -0,0 +1,442 @@ +//go:build windows + +package vm + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" + "github.com/Microsoft/hcsshim/internal/cmd" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" + "github.com/Microsoft/hcsshim/internal/shimdiag" + "github.com/Microsoft/hcsshim/internal/timeout" + "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + "github.com/Microsoft/hcsshim/internal/vm/vmmanager" + 
"github.com/Microsoft/hcsshim/internal/vm/vmutils" + iwin "github.com/Microsoft/hcsshim/internal/windows" + "github.com/containerd/errdefs" + + "github.com/Microsoft/go-winio/pkg/process" + "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/windows" +) + +// Manager is the VM controller implementation that manages the lifecycle of a Utility VM +// and its associated resources. +type Manager struct { + vmID string + uvm vmmanager.LifetimeManager + guest GuestManager + + // vmState tracks the current state of the VM lifecycle. + // Access must be guarded by mu. + vmState State + + // mu guards the concurrent access to the Manager's fields and operations. + mu sync.RWMutex + + // logOutputDone is closed when the GCS log output processing goroutine completes. + logOutputDone chan struct{} + + // Handle to the vmmem process associated with this UVM. Used to look up + // memory metrics for the UVM. + vmmemProcess windows.Handle + + // activeExecCount tracks the number of ongoing ExecIntoHost calls. + activeExecCount atomic.Int64 + + // isPhysicallyBacked indicates whether the VM is using physical backing for its memory. + isPhysicallyBacked bool +} + +// Ensure both the Controller, and it's subset Handle are implemented by Manager. +var _ Controller = (*Manager)(nil) + +// NewController creates a new Manager instance in the [StateNotCreated] state. +func NewController() *Manager { + return &Manager{ + logOutputDone: make(chan struct{}), + vmState: StateNotCreated, + } +} + +// Guest returns the guest manager instance for this VM. +// The guest manager provides access to guest-host communication. +func (c *Manager) Guest() GuestManager { + return c.guest +} + +// State returns the current VM state. +func (c *Manager) State() State { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.vmState +} + +// CreateVM creates the VM using the HCS document and initializes device state. 
+func (c *Manager) CreateVM(ctx context.Context, opts *CreateOptions) error { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "CreateVM")) + + c.mu.Lock() + defer c.mu.Unlock() + + // In case of duplicate CreateVM call for the same controller, we want to fail. + if c.vmState != StateNotCreated { + return fmt.Errorf("cannot create VM: VM is in incorrect state %s", c.vmState) + } + + // Create the VM via vmmanager. + uvm, err := vmmanager.Create(ctx, opts.ID, opts.HCSDocument) + if err != nil { + return fmt.Errorf("failed to create VM: %w", err) + } + + // Set the Manager parameters after successful creation. + c.vmID = opts.ID + c.uvm = uvm + // Determine if the VM is physically backed based on the HCS document configuration. + // We need this while extracting memory metrics, as some of them are only relevant for physically backed VMs. + c.isPhysicallyBacked = !opts.HCSDocument.VirtualMachine.ComputeTopology.Memory.AllowOvercommit + + // Initialize the GuestManager for managing guest interactions. + // We will create the guest connection via GuestManager during StartVM. + c.guest = guestmanager.New(ctx, uvm) + + c.vmState = StateCreated + return nil +} + +// StartVM starts the VM that was previously created via CreateVM. +// It starts the underlying HCS VM, establishes the GCS connection, +// and transitions the VM to [StateRunning]. +// On any failure the VM is transitioned to [StateInvalid]. +func (c *Manager) StartVM(ctx context.Context, opts *StartOptions) (err error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "StartVM")) + + c.mu.Lock() + defer c.mu.Unlock() + + // If the VM is already running, we can skip the start operation and just return. + // This makes StartVM idempotent in the case of duplicate calls. + if c.vmState == StateRunning { + return nil + } + // However, if the VM is in any other state than Created, + // we should fail as StartVM is only valid on a created VM. 
+ if c.vmState != StateCreated { + return fmt.Errorf("cannot start VM: VM is in incorrect state %s", c.vmState) + } + + defer func() { + if err != nil { + // If starting the VM fails, we transition to Invalid state to prevent any further operations on the VM. + // The VM can be terminated by invoking TerminateVM. + c.vmState = StateInvalid + } + }() + + // save parent context, without timeout to use for wait. + pCtx := ctx + // For remaining operations, we expect them to complete within the GCS connection timeout, + // otherwise we want to fail. + ctx, cancel := context.WithTimeout(pCtx, timeout.GCSConnectionTimeout) + log.G(ctx).Debugf("using gcs connection timeout: %s\n", timeout.GCSConnectionTimeout) + + g, gctx := errgroup.WithContext(ctx) + defer func() { + _ = g.Wait() + }() + defer cancel() + + // we should set up the necessary listeners for guest-host communication. + // The guest needs to connect to predefined vsock ports. + // The host must already be listening on these ports before the guest attempts to connect, + // otherwise the connection would fail. + c.setupEntropyListener(gctx, g) + c.setupLoggingListener(gctx, g) + + err = c.uvm.Start(ctx) + if err != nil { + return fmt.Errorf("failed to start VM: %w", err) + } + + // Start waiting on the utility VM in the background. + // This goroutine will complete when the VM exits. + go c.waitForVMExit(pCtx) + + // Collect any errors from writing entropy or establishing the log + // connection. + if err = g.Wait(); err != nil { + return err + } + + err = c.guest.CreateConnection(ctx, opts.GCSServiceID, opts.ConfigOptions...) + if err != nil { + return fmt.Errorf("failed to create guest connection: %w", err) + } + + err = c.finalizeGCSConnection(ctx) + if err != nil { + return fmt.Errorf("failed to finalize GCS connection: %w", err) + } + + // Set the confidential options if applicable. 
+ if opts.ConfidentialOptions != nil { + if err := c.guest.AddSecurityPolicy(ctx, *opts.ConfidentialOptions); err != nil { + return fmt.Errorf("failed to set confidential options: %w", err) + } + } + + // If all goes well, we can transition the VM to Running state. + c.vmState = StateRunning + + return nil +} + +// waitForVMExit blocks until the VM exits and then transitions the VM state to [StateTerminated]. +// This is called in StartVM in a background goroutine. +func (c *Manager) waitForVMExit(ctx context.Context) { + // The original context may have timeout or propagate a cancellation + // copy the original to prevent it affecting the background wait go routine + ctx = context.WithoutCancel(ctx) + _ = c.uvm.Wait(ctx) + // Once the VM has exited, attempt to transition to Terminated. + // This may be a no-op if TerminateVM already ran concurrently and + // transitioned the state first — log the discarded error so that + // concurrent-termination races remain observable. + c.mu.Lock() + if c.vmState != StateTerminated { + c.vmState = StateTerminated + } else { + log.G(ctx).WithField("currentState", c.vmState).Debug("waitForVMExit: state transition to Terminated was a no-op") + } + c.mu.Unlock() +} + +// ExecIntoHost executes a command in the running UVM. +func (c *Manager) ExecIntoHost(ctx context.Context, request *shimdiag.ExecProcessRequest) (int, error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "ExecIntoHost")) + + if request.Terminal && request.Stderr != "" { + return -1, fmt.Errorf("if using terminal, stderr must be empty: %w", errdefs.ErrFailedPrecondition) + } + + // Validate that the VM is running before allowing exec into it. + c.mu.RLock() + if c.vmState != StateRunning { + c.mu.RUnlock() + return -1, fmt.Errorf("cannot exec into VM: VM is in incorrect state %s", c.vmState) + } + c.mu.RUnlock() + + // Keep a count of active exec sessions. 
+ // This will be used to disallow LM with existing exec sessions, + // as that can lead to orphaned processes within UVM. + c.activeExecCount.Add(1) + defer c.activeExecCount.Add(-1) + + cmdReq := &cmd.CmdProcessRequest{ + Args: request.Args, + Workdir: request.Workdir, + Terminal: request.Terminal, + Stdin: request.Stdin, + Stdout: request.Stdout, + Stderr: request.Stderr, + } + return c.guest.ExecIntoUVM(ctx, cmdReq) +} + +// DumpStacks dumps the GCS stacks associated with the VM +func (c *Manager) DumpStacks(ctx context.Context) (string, error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "DumpStacks")) + + // Take read lock at this place. + // The state change cannot happen until we release the lock, + // so we are sure that the state remains consistent throughout the method. + c.mu.RLock() + defer c.mu.RUnlock() + + // Validate that the VM is running before sending dump stacks request to GCS. + if c.vmState != StateRunning { + return "", fmt.Errorf("cannot dump stacks: VM is in incorrect state %s", c.vmState) + } + + if c.guest.Capabilities().IsDumpStacksSupported() { + return c.guest.DumpStacks(ctx) + } + + return "", nil +} + +// Wait blocks until the VM exits and all log output processing has completed. +func (c *Manager) Wait(ctx context.Context) error { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "Wait")) + + // Validate that the VM has been created and can be waited on. + // Terminated VMs can also be waited on where we return immediately. + c.mu.RLock() + if c.vmState == StateNotCreated { + c.mu.RUnlock() + return fmt.Errorf("cannot wait on VM: VM is in incorrect state %s", c.vmState) + } + c.mu.RUnlock() + + // Wait for the utility VM to exit. + // This will be unblocked when the VM exits or if the context is cancelled. + err := c.uvm.Wait(ctx) + + // Wait for the log output processing to complete, + // which ensures all logs are processed before we return. 
+ select { + case <-ctx.Done(): + ctxErr := fmt.Errorf("failed to wait on uvm output processing: %w", ctx.Err()) + err = errors.Join(err, ctxErr) + case <-c.logOutputDone: + } + + return err +} + +// Stats returns runtime statistics for the VM including processor runtime and +// memory usage. The VM must be in [StateRunning]. +func (c *Manager) Stats(ctx context.Context) (*stats.VirtualMachineStatistics, error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "Stats")) + + // Take read lock at this place. + // The state change cannot happen until we release the lock, + // so we are sure that the state remains consistent throughout the method. + c.mu.RLock() + defer c.mu.RUnlock() + + if c.vmState != StateRunning { + return nil, fmt.Errorf("cannot get stats: VM is in incorrect state %s", c.vmState) + } + + // Initialization of vmmemProcess to calculate stats properly for VA-backed UVMs. + if c.vmmemProcess == 0 { + vmmemHandle, err := vmutils.LookupVMMEM(ctx, c.uvm.RuntimeID(), &iwin.WinAPI{}) + if err != nil { + return nil, fmt.Errorf("cannot get stats: %w", err) + } + c.vmmemProcess = vmmemHandle + } + + s := &stats.VirtualMachineStatistics{} + props, err := c.uvm.PropertiesV2(ctx, hcsschema.PTStatistics, hcsschema.PTMemory) + if err != nil { + return nil, fmt.Errorf("failed to get VM properties: %w", err) + } + s.Processor = &stats.VirtualMachineProcessorStatistics{} + s.Processor.TotalRuntimeNS = uint64(props.Statistics.Processor.TotalRuntime100ns * 100) + + s.Memory = &stats.VirtualMachineMemoryStatistics{} + if !c.isPhysicallyBacked { + // The HCS properties does not return sufficient information to calculate + // working set size for a VA-backed UVM. To work around this, we instead + // locate the vmmem process for the VM, and query that process's working set + // instead, which will be the working set for the VM. 
+ memCounters, err := process.GetProcessMemoryInfo(c.vmmemProcess) + if err != nil { + return nil, err + } + s.Memory.WorkingSetBytes = uint64(memCounters.WorkingSetSize) + } + + if props.Memory != nil { + if c.isPhysicallyBacked { + // If the uvm is physically backed we set the working set to the total amount allocated + // to the UVM. AssignedMemory returns the number of 4KB pages. Will always be 4KB + // regardless of what the UVMs actual page size is so we don't need that information. + s.Memory.WorkingSetBytes = props.Memory.VirtualMachineMemory.AssignedMemory * 4096 + } + s.Memory.VirtualNodeCount = props.Memory.VirtualNodeCount + s.Memory.VmMemory = &stats.VirtualMachineMemory{} + s.Memory.VmMemory.AvailableMemory = props.Memory.VirtualMachineMemory.AvailableMemory + s.Memory.VmMemory.AvailableMemoryBuffer = props.Memory.VirtualMachineMemory.AvailableMemoryBuffer + s.Memory.VmMemory.ReservedMemory = props.Memory.VirtualMachineMemory.ReservedMemory + s.Memory.VmMemory.AssignedMemory = props.Memory.VirtualMachineMemory.AssignedMemory + s.Memory.VmMemory.SlpActive = props.Memory.VirtualMachineMemory.SlpActive + s.Memory.VmMemory.BalancingEnabled = props.Memory.VirtualMachineMemory.BalancingEnabled + s.Memory.VmMemory.DmOperationInProgress = props.Memory.VirtualMachineMemory.DmOperationInProgress + } + return s, nil +} + +// TerminateVM forcefully terminates a running VM, closes the guest connection, +// and releases HCS resources. +// +// The context is used for all operations, including waits, so timeouts/cancellations may prevent +// proper UVM cleanup. +func (c *Manager) TerminateVM(ctx context.Context) (err error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.Operation, "TerminateVM")) + + c.mu.Lock() + defer c.mu.Unlock() + + // If the VM has already terminated, we can skip termination and just return. + // Alternatively, if the VM was never created, we can also skip termination. + // This makes the TerminateVM operation idempotent. 
+ if c.vmState == StateTerminated || c.vmState == StateNotCreated { + return nil + } + + // Best effort attempt to clean up the open vmmem handle. + _ = windows.Close(c.vmmemProcess) + // Terminate the utility VM. This will also cause the Wait() call in the background goroutine to unblock. + _ = c.uvm.Terminate(ctx) + + if err := c.guest.CloseConnection(); err != nil { + log.G(ctx).Errorf("close guest connection failed: %s", err) + } + + err = c.uvm.Close(ctx) + if err != nil { + // Transition to Invalid so no further active operations can be performed on the VM. + c.vmState = StateInvalid + return fmt.Errorf("failed to close utility VM: %w", err) + } + + // Set the Terminated status at the end. + c.vmState = StateTerminated + return nil +} + +// StartTime returns the timestamp when the VM was started. +// Returns zero value of time.Time if the VM has not yet reached +// [StateRunning] or [StateTerminated]. +func (c *Manager) StartTime() (startTime time.Time) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.vmState == StateRunning || c.vmState == StateTerminated { + return c.uvm.StartedTime() + } + + return startTime +} + +// ExitStatus returns the final status of the VM once it has reached +// [StateTerminated], including the time it stopped and any exit error. +// Returns an error if the VM has not yet stopped. 
+func (c *Manager) ExitStatus() (*ExitStatus, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.vmState != StateTerminated { + return nil, fmt.Errorf("cannot get exit status: VM is in incorrect state %s", c.vmState) + } + + return &ExitStatus{ + StoppedTime: c.uvm.StoppedTime(), + Err: c.uvm.ExitError(), + }, nil +} diff --git a/internal/controller/vm/vm_lcow.go b/internal/controller/vm/vm_lcow.go new file mode 100644 index 0000000000..269871a316 --- /dev/null +++ b/internal/controller/vm/vm_lcow.go @@ -0,0 +1,96 @@ +//go:build windows && !wcow + +package vm + +import ( + "context" + "crypto/rand" + "fmt" + "io" + + "github.com/Microsoft/hcsshim/internal/vm/vmmanager" + "github.com/Microsoft/hcsshim/internal/vm/vmutils" + + "github.com/Microsoft/go-winio" + "golang.org/x/sync/errgroup" +) + +// setupEntropyListener sets up entropy for LCOW UVMs. +// +// Linux VMs require entropy to initialize their random number generators during boot. +// This method listens on a predefined vsock port and provides cryptographically secure +// random data to the Linux init process when it connects. +func (c *Manager) setupEntropyListener(ctx context.Context, group *errgroup.Group) { + group.Go(func() error { + // The Linux guest will connect to this port during init to receive entropy. + entropyConn, err := winio.ListenHvsock(&winio.HvsockAddr{ + VMID: c.uvm.RuntimeID(), + ServiceID: winio.VsockServiceID(vmutils.LinuxEntropyVsockPort), + }) + if err != nil { + return fmt.Errorf("failed to listen on hvSocket for entropy: %w", err) + } + + // Prepare to provide entropy to the init process in the background. This + // must be done in a goroutine since, when using the internal bridge, the + // call to Start() will block until the GCS launches, and this cannot occur + // until the host accepts and closes the entropy connection. 
+ conn, err := vmmanager.AcceptConnection(ctx, c.uvm, entropyConn, true) + if err != nil { + return fmt.Errorf("failed to accept connection on hvSocket for entropy: %w", err) + } + defer conn.Close() + + // Write the required amount of entropy to the connection. + // The init process will read this data and use it to seed the kernel's + // random number generator (CRNG). + _, err = io.CopyN(conn, rand.Reader, vmutils.LinuxEntropyBytes) + if err != nil { + return fmt.Errorf("failed to write entropy to connection: %w", err) + } + + return nil + }) +} + +// setupLoggingListener sets up logging for LCOW UVMs. +// +// This method establishes a vsock connection to receive log output from GCS +// running inside the Linux VM. The logs are parsed and +// forwarded to the host's logging system for monitoring and debugging. +func (c *Manager) setupLoggingListener(ctx context.Context, group *errgroup.Group) { + group.Go(func() error { + // The GCS will connect to this port to stream log output. + logConn, err := winio.ListenHvsock(&winio.HvsockAddr{ + VMID: c.uvm.RuntimeID(), + ServiceID: winio.VsockServiceID(vmutils.LinuxLogVsockPort), + }) + if err != nil { + return fmt.Errorf("failed to listen on hvSocket for logs: %w", err) + } + + // Accept the connection from the GCS. + conn, err := vmmanager.AcceptConnection(ctx, c.uvm, logConn, true) + if err != nil { + return fmt.Errorf("failed to accept connection on hvSocket for logs: %w", err) + } + + // Launch a separate goroutine to process logs for the lifetime of the VM. + go func() { + // Parse GCS log output and forward it to the host logging system. + vmutils.ParseGCSLogrus(c.uvm.ID())(conn) + + // Signal that log output processing has completed. + // This allows Wait() to ensure all logs are processed before returning. + close(c.logOutputDone) + }() + + return nil + }) +} + +// finalizeGCSConnection finalizes the GCS connection for LCOW VMs. +// For LCOW, no additional finalization is needed. 
+func (c *Manager) finalizeGCSConnection(_ context.Context) error { + return nil +} diff --git a/internal/controller/vm/vm_test.go b/internal/controller/vm/vm_test.go new file mode 100644 index 0000000000..aa42edc82f --- /dev/null +++ b/internal/controller/vm/vm_test.go @@ -0,0 +1,435 @@ +//go:build windows + +package vm_test + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + vm "github.com/Microsoft/hcsshim/internal/controller/vm" + gcsmock "github.com/Microsoft/hcsshim/internal/gcs/mock" + "github.com/Microsoft/hcsshim/internal/shimdiag" + ctrlmock "github.com/Microsoft/hcsshim/internal/test/mock/vmcontroller" + vmmock "github.com/Microsoft/hcsshim/internal/vm/vmmanager/mock" + + "go.uber.org/mock/gomock" +) + +var errTest = errors.New("test error") + +// newTestManager creates a Manager in the given state with mocked dependencies. +func newTestManager(t *testing.T, ctrl *gomock.Controller, state vm.State) (*vm.Manager, *vmmock.MockLifetimeManager, *ctrlmock.MockGuestManager) { + t.Helper() + + mockUVM := vmmock.NewMockLifetimeManager(ctrl) + mockGuest := ctrlmock.NewMockGuestManager(ctrl) + + m := vm.NewController() + m.SetManagerForTest(mockUVM, mockGuest, state) + + return m, mockUVM, mockGuest +} + +func TestTerminateVM(t *testing.T) { + tests := []struct { + name string + initialState vm.State + setupMock func(*vmmock.MockLifetimeManager, *ctrlmock.MockGuestManager) + expectError bool + errorContains string + expectState vm.State + }{ + { + name: "idempotent/NotCreated", + initialState: vm.StateNotCreated, + expectState: vm.StateNotCreated, + }, + { + name: "idempotent/Terminated", + initialState: vm.StateTerminated, + expectState: vm.StateTerminated, + }, + { + name: "success/Created_to_Terminated", + initialState: vm.StateCreated, + setupMock: func(uvm *vmmock.MockLifetimeManager, guest *ctrlmock.MockGuestManager) { + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + 
uvm.EXPECT().Close(gomock.Any()).Return(nil) + }, + expectState: vm.StateTerminated, + }, + { + name: "success/Running_to_Terminated", + initialState: vm.StateRunning, + setupMock: func(uvm *vmmock.MockLifetimeManager, guest *ctrlmock.MockGuestManager) { + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + }, + expectState: vm.StateTerminated, + }, + { + name: "error/Close_fails_transitions_to_Invalid", + initialState: vm.StateRunning, + setupMock: func(uvm *vmmock.MockLifetimeManager, guest *ctrlmock.MockGuestManager) { + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(errTest) + }, + expectError: true, + errorContains: "failed to close utility VM", + expectState: vm.StateInvalid, + }, + { + name: "recovery/Invalid_to_Terminated", + initialState: vm.StateInvalid, + setupMock: func(uvm *vmmock.MockLifetimeManager, guest *ctrlmock.MockGuestManager) { + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + }, + expectState: vm.StateTerminated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, mockUVM, mockGuest := newTestManager(t, ctrl, tt.initialState) + + if tt.setupMock != nil { + tt.setupMock(mockUVM, mockGuest) + } + + err := m.TerminateVM(context.Background()) + + if tt.expectError { + if err == nil { + t.Fatal("expected error but got nil") + } + if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("expected error containing %q, got: %v", tt.errorContains, err) + } + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if m.State() != tt.expectState { + t.Errorf("expected state %v, got %v", tt.expectState, m.State()) + } + }) + } +} 
+ +func TestGuardChecks(t *testing.T) { + tests := []struct { + name string + state vm.State + call func(*vm.Manager) error + errorContains string + }{ + {"ExecIntoHost/NotCreated", vm.StateNotCreated, func(m *vm.Manager) error { + _, err := m.ExecIntoHost(context.Background(), &shimdiag.ExecProcessRequest{Args: []string{"ls"}}) + return err + }, "incorrect state"}, + {"ExecIntoHost/Created", vm.StateCreated, func(m *vm.Manager) error { + _, err := m.ExecIntoHost(context.Background(), &shimdiag.ExecProcessRequest{Args: []string{"ls"}}) + return err + }, "incorrect state"}, + {"ExecIntoHost/Terminated", vm.StateTerminated, func(m *vm.Manager) error { + _, err := m.ExecIntoHost(context.Background(), &shimdiag.ExecProcessRequest{Args: []string{"ls"}}) + return err + }, "incorrect state"}, + {"DumpStacks/NotCreated", vm.StateNotCreated, func(m *vm.Manager) error { + _, err := m.DumpStacks(context.Background()) + return err + }, "incorrect state"}, + {"DumpStacks/Terminated", vm.StateTerminated, func(m *vm.Manager) error { + _, err := m.DumpStacks(context.Background()) + return err + }, "incorrect state"}, + {"Stats/NotCreated", vm.StateNotCreated, func(m *vm.Manager) error { + _, err := m.Stats(context.Background()) + return err + }, "incorrect state"}, + {"Stats/Created", vm.StateCreated, func(m *vm.Manager) error { + _, err := m.Stats(context.Background()) + return err + }, "incorrect state"}, + {"Wait/NotCreated", vm.StateNotCreated, func(m *vm.Manager) error { + return m.Wait(context.Background()) + }, "incorrect state"}, + {"ExitStatus/NotCreated", vm.StateNotCreated, func(m *vm.Manager) error { + _, err := m.ExitStatus() + return err + }, "incorrect state"}, + {"ExitStatus/Running", vm.StateRunning, func(m *vm.Manager) error { + _, err := m.ExitStatus() + return err + }, "incorrect state"}, + {"ExitStatus/Created", vm.StateCreated, func(m *vm.Manager) error { + _, err := m.ExitStatus() + return err + }, "incorrect state"}, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, _, _ := newTestManager(t, ctrl, tt.state) + + err := tt.call(m) + if err == nil { + t.Fatal("expected error but got nil") + } + if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("expected error containing %q, got: %v", tt.errorContains, err) + } + }) + } +} + +func TestExecIntoHost(t *testing.T) { + tests := []struct { + name string + request *shimdiag.ExecProcessRequest + setupMock func(*ctrlmock.MockGuestManager) + expectError bool + expectExitCode int + }{ + { + name: "terminal_with_stderr_rejected", + request: &shimdiag.ExecProcessRequest{ + Args: []string{"ls"}, + Terminal: true, + Stderr: "/some/path", + }, + expectError: true, + }, + { + name: "success/zero_exit_code", + request: &shimdiag.ExecProcessRequest{Args: []string{"ls"}}, + setupMock: func(guest *ctrlmock.MockGuestManager) { + guest.EXPECT().ExecIntoUVM(gomock.Any(), gomock.Any()).Return(0, nil) + }, + expectExitCode: 0, + }, + { + name: "success/non_zero_exit_code", + request: &shimdiag.ExecProcessRequest{Args: []string{"false"}}, + setupMock: func(guest *ctrlmock.MockGuestManager) { + guest.EXPECT().ExecIntoUVM(gomock.Any(), gomock.Any()).Return(1, nil) + }, + expectExitCode: 1, + }, + { + name: "error/exec_fails", + request: &shimdiag.ExecProcessRequest{Args: []string{"ls"}}, + setupMock: func(guest *ctrlmock.MockGuestManager) { + guest.EXPECT().ExecIntoUVM(gomock.Any(), gomock.Any()).Return(-1, errTest) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, _, mockGuest := newTestManager(t, ctrl, vm.StateRunning) + + if tt.setupMock != nil { + tt.setupMock(mockGuest) + } + + exitCode, err := m.ExecIntoHost(context.Background(), tt.request) + + if tt.expectError { + if err == nil { + t.Fatal("expected error but got nil") + 
} + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != tt.expectExitCode { + t.Errorf("expected exit code %d, got %d", tt.expectExitCode, exitCode) + } + }) + } +} + +func TestDumpStacks(t *testing.T) { + tests := []struct { + name string + supported bool + dumpError error + expectResult string + expectError bool + }{ + { + name: "supported", + supported: true, + expectResult: "stack trace", + }, + { + name: "not_supported", + supported: false, + expectResult: "", + }, + { + name: "error/dump_fails", + supported: true, + dumpError: errTest, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, _, mockGuest := newTestManager(t, ctrl, vm.StateRunning) + + mockCaps := gcsmock.NewMockGuestDefinedCapabilities(ctrl) + mockCaps.EXPECT().IsDumpStacksSupported().Return(tt.supported) + mockGuest.EXPECT().Capabilities().Return(mockCaps) + if tt.supported { + mockGuest.EXPECT().DumpStacks(gomock.Any()).Return(tt.expectResult, tt.dumpError) + } + + result, err := m.DumpStacks(context.Background()) + + if tt.expectError { + if err == nil { + t.Fatal("expected error but got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != tt.expectResult { + t.Errorf("expected %q, got %q", tt.expectResult, result) + } + }) + } +} + +func TestStartTime(t *testing.T) { + tests := []struct { + name string + state vm.State + expectZero bool + }{ + {"NotCreated", vm.StateNotCreated, true}, + {"Created", vm.StateCreated, true}, + {"Running", vm.StateRunning, false}, + {"Terminated", vm.StateTerminated, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, mockUVM, _ := newTestManager(t, ctrl, tt.state) + + if !tt.expectZero { + mockUVM.EXPECT().StartedTime().Return(time.Now()) + } + + st := 
m.StartTime() + if tt.expectZero && !st.IsZero() { + t.Errorf("expected zero time, got %v", st) + } + if !tt.expectZero && st.IsZero() { + t.Error("expected non-zero time, got zero") + } + }) + } +} + +func TestExitStatus(t *testing.T) { + tests := []struct { + name string + exitError error + }{ + {"clean_exit", nil}, + {"with_error", errTest}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, mockUVM, _ := newTestManager(t, ctrl, vm.StateTerminated) + + now := time.Now() + mockUVM.EXPECT().StoppedTime().Return(now) + mockUVM.EXPECT().ExitError().Return(tt.exitError) + + status, err := m.ExitStatus() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status.StoppedTime != now { + t.Errorf("expected stopped time %v, got %v", now, status.StoppedTime) + } + if !errors.Is(status.Err, tt.exitError) { + t.Errorf("expected exit error %v, got %v", tt.exitError, status.Err) + } + }) + } +} + +func TestStartVM_IdempotentWhenRunning(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, _, _ := newTestManager(t, ctrl, vm.StateRunning) + + err := m.StartVM(context.Background(), &vm.StartOptions{}) + if err != nil { + t.Fatalf("expected nil for idempotent StartVM, got: %v", err) + } + if m.State() != vm.StateRunning { + t.Errorf("expected state Running, got %v", m.State()) + } +} + +func TestStartVM_ErrorWhenNotCreated(t *testing.T) { + tests := []struct { + name string + state vm.State + }{ + {"NotCreated", vm.StateNotCreated}, + {"Terminated", vm.StateTerminated}, + {"Invalid", vm.StateInvalid}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + m, _, _ := newTestManager(t, ctrl, tt.state) + + err := m.StartVM(context.Background(), &vm.StartOptions{}) + if err == nil { + t.Errorf("expected error for StartVM in state %v", tt.state) + } + }) + } +} 
diff --git a/internal/controller/vm/vm_wcow.go b/internal/controller/vm/vm_wcow.go new file mode 100644 index 0000000000..de6053be8e --- /dev/null +++ b/internal/controller/vm/vm_wcow.go @@ -0,0 +1,116 @@ +//go:build windows && wcow + +package vm + +import ( + "context" + "fmt" + "sync" + + "github.com/Microsoft/go-winio" + "github.com/Microsoft/hcsshim/internal/gcs/prot" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/vm/vmmanager" + "github.com/Microsoft/hcsshim/internal/vm/vmutils" + + "github.com/sirupsen/logrus" + "golang.org/x/net/netutil" + "golang.org/x/sync/errgroup" +) + +// setupEntropyListener sets up entropy for WCOW (Windows Containers on Windows) VMs. +// +// For WCOW, entropy setup is not required. Windows VMs have their own internal +// random number generation that does not depend on host-provided entropy. +// This is a no-op implementation to satisfy the platform-specific interface. +// +// For comparison, LCOW VMs require entropy to be provided during boot. +func (c *Manager) setupEntropyListener(_ context.Context, _ *errgroup.Group) {} + +// setupLoggingListener sets up logging for WCOW UVMs. +// +// Unlike LCOW, where the log connection must be established before the VM starts, +// WCOW allows the GCS to connect to the logging socket at any time after the VM +// is running. This method sets up a persistent listener that can accept connections +// even if the GCS restarts or reconnects. +// +// The listener is configured to accept only one concurrent connection at a time +// to prevent resource exhaustion, but will accept new connections if the current one is closed. +// This supports scenarios where the logging service inside the VM needs to restart. +func (c *Manager) setupLoggingListener(ctx context.Context, _ *errgroup.Group) { + // For Windows, the listener can receive a connection later (after VM starts), + // so we start the output handler in a goroutine with a non-timeout context. 
+ // This allows the output handler to run independently of the VM creation lifecycle. + // This is useful for the case when the logging service is restarted. + go func() { + baseListener, err := winio.ListenHvsock(&winio.HvsockAddr{ + VMID: c.uvm.RuntimeID(), + ServiceID: prot.WindowsLoggingHvsockServiceID, + }) + if err != nil { + // Close the output done channel to signal that logging setup + // has failed and no logs will be processed. + close(c.logOutputDone) + logrus.WithError(err).Error("failed to listen for windows logging connections") + + // Return early due to error. + return + } + + // Use a WaitGroup to track active log processing goroutines. + // This ensures we wait for all log processing to complete before closing logOutputDone. + var wg sync.WaitGroup + + // Limit the listener to accept at most 1 concurrent connection. + limitedListener := netutil.LimitListener(baseListener, 1) + + for { + // Accept a connection from the GCS. + conn, err := vmmanager.AcceptConnection(context.WithoutCancel(ctx), c.uvm, limitedListener, false) + if err != nil { + logrus.WithError(err).Error("failed to connect to log socket") + break + } + + // Launch a goroutine to process logs from this connection. + wg.Add(1) + go func() { + defer wg.Done() + logrus.Info("uvm output handler starting") + + // Parse GCS log output and forward it to the host logging system. + // The parser handles logrus-formatted logs from the GCS. + vmutils.ParseGCSLogrus(c.uvm.ID())(conn) + + logrus.Info("uvm output handler finished") + }() + } + + // Wait for all log processing goroutines to complete. + wg.Wait() + + // Signal that log output processing has completed. + close(c.logOutputDone) + }() +} + +// finalizeGCSConnection finalizes the GCS connection for WCOW UVMs. +// This is called after CreateConnection succeeds and before the VM is considered fully started. 
+func (c *Manager) finalizeGCSConnection(ctx context.Context) error { + // Prepare the HvSocket address configuration for the external GCS connection. + // The LocalAddress is the VM's runtime ID, and the ParentAddress is the + // predefined host ID for Windows GCS communication. + hvsocketAddress := &hcsschema.HvSocketAddress{ + LocalAddress: c.uvm.RuntimeID().String(), + ParentAddress: prot.WindowsGcsHvHostID.String(), + } + + // Update the guest manager with the HvSocket address configuration. + // This enables the GCS to establish proper bidirectional communication. + err := c.guest.UpdateHvSocketAddress(ctx, hvsocketAddress) + if err != nil { + return fmt.Errorf("failed to create GCS connection: %w", err) + } + + return nil +} diff --git a/internal/gcs/guestcaps.go b/internal/gcs/guestcaps.go index 58b1f8ebb7..47143ab35e 100644 --- a/internal/gcs/guestcaps.go +++ b/internal/gcs/guestcaps.go @@ -2,6 +2,8 @@ package gcs +//go:generate go tool mockgen -source=guestcaps.go -package=mock -destination=mock/mock_guestcaps.go + import ( "encoding/json" "fmt" diff --git a/internal/gcs/mock/mock_guestcaps.go b/internal/gcs/mock/mock_guestcaps.go new file mode 100644 index 0000000000..7018bb54f4 --- /dev/null +++ b/internal/gcs/mock/mock_guestcaps.go @@ -0,0 +1,96 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: guestcaps.go +// +// Generated by this command: +// +// mockgen -source=guestcaps.go -package=mock -destination=mock/mock_guestcaps.go +// + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockGuestDefinedCapabilities is a mock of GuestDefinedCapabilities interface. +type MockGuestDefinedCapabilities struct { + ctrl *gomock.Controller + recorder *MockGuestDefinedCapabilitiesMockRecorder + isgomock struct{} +} + +// MockGuestDefinedCapabilitiesMockRecorder is the mock recorder for MockGuestDefinedCapabilities. 
+type MockGuestDefinedCapabilitiesMockRecorder struct { + mock *MockGuestDefinedCapabilities +} + +// NewMockGuestDefinedCapabilities creates a new mock instance. +func NewMockGuestDefinedCapabilities(ctrl *gomock.Controller) *MockGuestDefinedCapabilities { + mock := &MockGuestDefinedCapabilities{ctrl: ctrl} + mock.recorder = &MockGuestDefinedCapabilitiesMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockGuestDefinedCapabilities) EXPECT() *MockGuestDefinedCapabilitiesMockRecorder { + return m.recorder +} + +// IsDeleteContainerStateSupported mocks base method. +func (m *MockGuestDefinedCapabilities) IsDeleteContainerStateSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsDeleteContainerStateSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsDeleteContainerStateSupported indicates an expected call of IsDeleteContainerStateSupported. +func (mr *MockGuestDefinedCapabilitiesMockRecorder) IsDeleteContainerStateSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDeleteContainerStateSupported", reflect.TypeOf((*MockGuestDefinedCapabilities)(nil).IsDeleteContainerStateSupported)) +} + +// IsDumpStacksSupported mocks base method. +func (m *MockGuestDefinedCapabilities) IsDumpStacksSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsDumpStacksSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsDumpStacksSupported indicates an expected call of IsDumpStacksSupported. +func (mr *MockGuestDefinedCapabilitiesMockRecorder) IsDumpStacksSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDumpStacksSupported", reflect.TypeOf((*MockGuestDefinedCapabilities)(nil).IsDumpStacksSupported)) +} + +// IsNamespaceAddRequestSupported mocks base method. 
+func (m *MockGuestDefinedCapabilities) IsNamespaceAddRequestSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNamespaceAddRequestSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsNamespaceAddRequestSupported indicates an expected call of IsNamespaceAddRequestSupported. +func (mr *MockGuestDefinedCapabilitiesMockRecorder) IsNamespaceAddRequestSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNamespaceAddRequestSupported", reflect.TypeOf((*MockGuestDefinedCapabilities)(nil).IsNamespaceAddRequestSupported)) +} + +// IsSignalProcessSupported mocks base method. +func (m *MockGuestDefinedCapabilities) IsSignalProcessSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsSignalProcessSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsSignalProcessSupported indicates an expected call of IsSignalProcessSupported. +func (mr *MockGuestDefinedCapabilitiesMockRecorder) IsSignalProcessSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsSignalProcessSupported", reflect.TypeOf((*MockGuestDefinedCapabilities)(nil).IsSignalProcessSupported)) +} diff --git a/internal/logfields/fields.go b/internal/logfields/fields.go index cceb3e2d18..dac5a708e5 100644 --- a/internal/logfields/fields.go +++ b/internal/logfields/fields.go @@ -8,12 +8,12 @@ const ( Operation = "operation" ID = "id" - SandboxID = "sid" ContainerID = "cid" ExecID = "eid" ProcessID = "pid" TaskID = "tid" UVMID = "uvm-id" + SandboxID = "sandbox-id" // networking and IO @@ -50,6 +50,40 @@ const ( Uint32 = "uint32" Uint64 = "uint64" + // task / process lifecycle + + Bundle = "bundle" + Terminal = "terminal" + Stdin = "stdin" + Stdout = "stdout" + Stderr = "stderr" + Checkpoint = "checkpoint" + ParentCheckpoint = "parent-checkpoint" + Status = "status" + ExitStatus = "exit-status" + ExitedAt = "exited-at" + Signal = "signal" + All = "all" + Width = 
"width" + Height = "height" + Version = "version" + ShimPid = "shim-pid" + TaskPid = "task-pid" + + // sandbox + + NetNsPath = "net-ns-path" + Verbose = "verbose" + + // shimdiag + + Args = "args" + Workdir = "workdir" + HostPath = "host-path" + UVMPath = "uvm-path" + ReadOnly = "readonly" + Execs = "execs" + // runhcs VMShimOperation = "vmshim-op" diff --git a/internal/test/mock/vmcontroller/mock_interface.go b/internal/test/mock/vmcontroller/mock_interface.go new file mode 100644 index 0000000000..294f888ee9 --- /dev/null +++ b/internal/test/mock/vmcontroller/mock_interface.go @@ -0,0 +1,399 @@ +//go:build windows + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go +// +// Generated by this command: +// +// mockgen -source=interface.go -build_constraint=windows -package=mockvmcontroller -destination=../../test/mock/vmcontroller/mock_interface.go +// + +// Package mockvmcontroller is a generated GoMock package. +package mockvmcontroller + +import ( + context "context" + reflect "reflect" + time "time" + + guid "github.com/Microsoft/go-winio/pkg/guid" + stats "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" + cmd "github.com/Microsoft/hcsshim/internal/cmd" + vm "github.com/Microsoft/hcsshim/internal/controller/vm" + cow "github.com/Microsoft/hcsshim/internal/cow" + gcs "github.com/Microsoft/hcsshim/internal/gcs" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + guestresource "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + shimdiag "github.com/Microsoft/hcsshim/internal/shimdiag" + guestmanager "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + gomock "go.uber.org/mock/gomock" +) + +// MockGuestManager is a mock of GuestManager interface. +type MockGuestManager struct { + ctrl *gomock.Controller + recorder *MockGuestManagerMockRecorder + isgomock struct{} +} + +// MockGuestManagerMockRecorder is the mock recorder for MockGuestManager. 
+type MockGuestManagerMockRecorder struct { + mock *MockGuestManager +} + +// NewMockGuestManager creates a new mock instance. +func NewMockGuestManager(ctrl *gomock.Controller) *MockGuestManager { + mock := &MockGuestManager{ctrl: ctrl} + mock.recorder = &MockGuestManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockGuestManager) EXPECT() *MockGuestManagerMockRecorder { + return m.recorder +} + +// AddSecurityPolicy mocks base method. +func (m *MockGuestManager) AddSecurityPolicy(ctx context.Context, settings guestresource.ConfidentialOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddSecurityPolicy", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddSecurityPolicy indicates an expected call of AddSecurityPolicy. +func (mr *MockGuestManagerMockRecorder) AddSecurityPolicy(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSecurityPolicy", reflect.TypeOf((*MockGuestManager)(nil).AddSecurityPolicy), ctx, settings) +} + +// Capabilities mocks base method. +func (m *MockGuestManager) Capabilities() gcs.GuestDefinedCapabilities { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Capabilities") + ret0, _ := ret[0].(gcs.GuestDefinedCapabilities) + return ret0 +} + +// Capabilities indicates an expected call of Capabilities. +func (mr *MockGuestManagerMockRecorder) Capabilities() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Capabilities", reflect.TypeOf((*MockGuestManager)(nil).Capabilities)) +} + +// CloseConnection mocks base method. +func (m *MockGuestManager) CloseConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseConnection indicates an expected call of CloseConnection. 
+func (mr *MockGuestManagerMockRecorder) CloseConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseConnection", reflect.TypeOf((*MockGuestManager)(nil).CloseConnection)) +} + +// CreateConnection mocks base method. +func (m *MockGuestManager) CreateConnection(ctx context.Context, GCSServiceID guid.GUID, opts ...guestmanager.ConfigOption) error { + m.ctrl.T.Helper() + varargs := []any{ctx, GCSServiceID} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateConnection", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateConnection indicates an expected call of CreateConnection. +func (mr *MockGuestManagerMockRecorder) CreateConnection(ctx, GCSServiceID any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, GCSServiceID}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConnection", reflect.TypeOf((*MockGuestManager)(nil).CreateConnection), varargs...) +} + +// CreateContainer mocks base method. +func (m *MockGuestManager) CreateContainer(ctx context.Context, cid string, config any) (*gcs.Container, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateContainer", ctx, cid, config) + ret0, _ := ret[0].(*gcs.Container) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateContainer indicates an expected call of CreateContainer. +func (mr *MockGuestManagerMockRecorder) CreateContainer(ctx, cid, config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateContainer", reflect.TypeOf((*MockGuestManager)(nil).CreateContainer), ctx, cid, config) +} + +// CreateProcess mocks base method. 
+func (m *MockGuestManager) CreateProcess(ctx context.Context, settings any) (cow.Process, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateProcess", ctx, settings) + ret0, _ := ret[0].(cow.Process) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateProcess indicates an expected call of CreateProcess. +func (mr *MockGuestManagerMockRecorder) CreateProcess(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProcess", reflect.TypeOf((*MockGuestManager)(nil).CreateProcess), ctx, settings) +} + +// DeleteContainerState mocks base method. +func (m *MockGuestManager) DeleteContainerState(ctx context.Context, cid string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteContainerState", ctx, cid) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteContainerState indicates an expected call of DeleteContainerState. +func (mr *MockGuestManagerMockRecorder) DeleteContainerState(ctx, cid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteContainerState", reflect.TypeOf((*MockGuestManager)(nil).DeleteContainerState), ctx, cid) +} + +// DumpStacks mocks base method. +func (m *MockGuestManager) DumpStacks(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DumpStacks", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DumpStacks indicates an expected call of DumpStacks. +func (mr *MockGuestManagerMockRecorder) DumpStacks(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DumpStacks", reflect.TypeOf((*MockGuestManager)(nil).DumpStacks), ctx) +} + +// ExecIntoUVM mocks base method. 
+func (m *MockGuestManager) ExecIntoUVM(ctx context.Context, request *cmd.CmdProcessRequest) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExecIntoUVM", ctx, request) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecIntoUVM indicates an expected call of ExecIntoUVM. +func (mr *MockGuestManagerMockRecorder) ExecIntoUVM(ctx, request any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecIntoUVM", reflect.TypeOf((*MockGuestManager)(nil).ExecIntoUVM), ctx, request) +} + +// InjectPolicyFragment mocks base method. +func (m *MockGuestManager) InjectPolicyFragment(ctx context.Context, settings guestresource.SecurityPolicyFragment) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InjectPolicyFragment", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// InjectPolicyFragment indicates an expected call of InjectPolicyFragment. +func (mr *MockGuestManagerMockRecorder) InjectPolicyFragment(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InjectPolicyFragment", reflect.TypeOf((*MockGuestManager)(nil).InjectPolicyFragment), ctx, settings) +} + +// UpdateHvSocketAddress mocks base method. +func (m *MockGuestManager) UpdateHvSocketAddress(ctx context.Context, settings *hcsschema.HvSocketAddress) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateHvSocketAddress", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateHvSocketAddress indicates an expected call of UpdateHvSocketAddress. +func (mr *MockGuestManagerMockRecorder) UpdateHvSocketAddress(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHvSocketAddress", reflect.TypeOf((*MockGuestManager)(nil).UpdateHvSocketAddress), ctx, settings) +} + +// MockController is a mock of Controller interface. 
+type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder + isgomock struct{} +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// CreateVM mocks base method. +func (m *MockController) CreateVM(ctx context.Context, opts *vm.CreateOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateVM", ctx, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateVM indicates an expected call of CreateVM. +func (mr *MockControllerMockRecorder) CreateVM(ctx, opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateVM", reflect.TypeOf((*MockController)(nil).CreateVM), ctx, opts) +} + +// DumpStacks mocks base method. +func (m *MockController) DumpStacks(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DumpStacks", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DumpStacks indicates an expected call of DumpStacks. +func (mr *MockControllerMockRecorder) DumpStacks(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DumpStacks", reflect.TypeOf((*MockController)(nil).DumpStacks), ctx) +} + +// ExecIntoHost mocks base method. 
+func (m *MockController) ExecIntoHost(ctx context.Context, request *shimdiag.ExecProcessRequest) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExecIntoHost", ctx, request) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecIntoHost indicates an expected call of ExecIntoHost. +func (mr *MockControllerMockRecorder) ExecIntoHost(ctx, request any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecIntoHost", reflect.TypeOf((*MockController)(nil).ExecIntoHost), ctx, request) +} + +// ExitStatus mocks base method. +func (m *MockController) ExitStatus() (*vm.ExitStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExitStatus") + ret0, _ := ret[0].(*vm.ExitStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExitStatus indicates an expected call of ExitStatus. +func (mr *MockControllerMockRecorder) ExitStatus() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExitStatus", reflect.TypeOf((*MockController)(nil).ExitStatus)) +} + +// Guest mocks base method. +func (m *MockController) Guest() vm.GuestManager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Guest") + ret0, _ := ret[0].(vm.GuestManager) + return ret0 +} + +// Guest indicates an expected call of Guest. +func (mr *MockControllerMockRecorder) Guest() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Guest", reflect.TypeOf((*MockController)(nil).Guest)) +} + +// StartTime mocks base method. +func (m *MockController) StartTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// StartTime indicates an expected call of StartTime. 
+func (mr *MockControllerMockRecorder) StartTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartTime", reflect.TypeOf((*MockController)(nil).StartTime)) +} + +// StartVM mocks base method. +func (m *MockController) StartVM(arg0 context.Context, arg1 *vm.StartOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartVM", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// StartVM indicates an expected call of StartVM. +func (mr *MockControllerMockRecorder) StartVM(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartVM", reflect.TypeOf((*MockController)(nil).StartVM), arg0, arg1) +} + +// State mocks base method. +func (m *MockController) State() vm.State { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "State") + ret0, _ := ret[0].(vm.State) + return ret0 +} + +// State indicates an expected call of State. +func (mr *MockControllerMockRecorder) State() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockController)(nil).State)) +} + +// Stats mocks base method. +func (m *MockController) Stats(ctx context.Context) (*stats.VirtualMachineStatistics, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stats", ctx) + ret0, _ := ret[0].(*stats.VirtualMachineStatistics) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Stats indicates an expected call of Stats. +func (mr *MockControllerMockRecorder) Stats(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stats", reflect.TypeOf((*MockController)(nil).Stats), ctx) +} + +// TerminateVM mocks base method. +func (m *MockController) TerminateVM(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateVM", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// TerminateVM indicates an expected call of TerminateVM. 
+func (mr *MockControllerMockRecorder) TerminateVM(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateVM", reflect.TypeOf((*MockController)(nil).TerminateVM), arg0) +} + +// Wait mocks base method. +func (m *MockController) Wait(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. +func (mr *MockControllerMockRecorder) Wait(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockController)(nil).Wait), ctx) +} diff --git a/internal/uvm/start.go b/internal/uvm/start.go index c6ba805304..2c18db56b9 100644 --- a/internal/uvm/start.go +++ b/internal/uvm/start.go @@ -129,9 +129,8 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { e.Info("uvm output handler finished") } wg.Wait() - if _, ok := <-uvm.outputProcessingDone; ok { - close(uvm.outputProcessingDone) - } + // Signal that log output processing has completed. + close(uvm.outputProcessingDone) }() default: // Default handling diff --git a/internal/vm/vmmanager/lifetime.go b/internal/vm/vmmanager/lifetime.go index b2cc737b53..2da963412b 100644 --- a/internal/vm/vmmanager/lifetime.go +++ b/internal/vm/vmmanager/lifetime.go @@ -2,6 +2,8 @@ package vmmanager +//go:generate go tool mockgen -source=lifetime.go -package=mock -destination=mock/mock_lifetime.go + import ( "context" "fmt" diff --git a/internal/vm/vmmanager/mock/mock_lifetime.go b/internal/vm/vmmanager/mock/mock_lifetime.go new file mode 100644 index 0000000000..59c9dd5c33 --- /dev/null +++ b/internal/vm/vmmanager/mock/mock_lifetime.go @@ -0,0 +1,232 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: lifetime.go +// +// Generated by this command: +// +// mockgen -source=lifetime.go -package=mock -destination=mock/mock_lifetime.go +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + time "time" + + guid "github.com/Microsoft/go-winio/pkg/guid" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + gomock "go.uber.org/mock/gomock" +) + +// MockLifetimeManager is a mock of LifetimeManager interface. +type MockLifetimeManager struct { + ctrl *gomock.Controller + recorder *MockLifetimeManagerMockRecorder + isgomock struct{} +} + +// MockLifetimeManagerMockRecorder is the mock recorder for MockLifetimeManager. +type MockLifetimeManagerMockRecorder struct { + mock *MockLifetimeManager +} + +// NewMockLifetimeManager creates a new mock instance. +func NewMockLifetimeManager(ctrl *gomock.Controller) *MockLifetimeManager { + mock := &MockLifetimeManager{ctrl: ctrl} + mock.recorder = &MockLifetimeManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLifetimeManager) EXPECT() *MockLifetimeManagerMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockLifetimeManager) Close(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockLifetimeManagerMockRecorder) Close(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLifetimeManager)(nil).Close), ctx) +} + +// ExitError mocks base method. +func (m *MockLifetimeManager) ExitError() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExitError") + ret0, _ := ret[0].(error) + return ret0 +} + +// ExitError indicates an expected call of ExitError. 
+func (mr *MockLifetimeManagerMockRecorder) ExitError() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExitError", reflect.TypeOf((*MockLifetimeManager)(nil).ExitError)) +} + +// ID mocks base method. +func (m *MockLifetimeManager) ID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(string) + return ret0 +} + +// ID indicates an expected call of ID. +func (mr *MockLifetimeManagerMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockLifetimeManager)(nil).ID)) +} + +// Pause mocks base method. +func (m *MockLifetimeManager) Pause(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Pause", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Pause indicates an expected call of Pause. +func (mr *MockLifetimeManagerMockRecorder) Pause(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pause", reflect.TypeOf((*MockLifetimeManager)(nil).Pause), ctx) +} + +// PropertiesV2 mocks base method. +func (m *MockLifetimeManager) PropertiesV2(ctx context.Context, types ...hcsschema.PropertyType) (*hcsschema.Properties, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range types { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "PropertiesV2", varargs...) + ret0, _ := ret[0].(*hcsschema.Properties) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PropertiesV2 indicates an expected call of PropertiesV2. +func (mr *MockLifetimeManagerMockRecorder) PropertiesV2(ctx any, types ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, types...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PropertiesV2", reflect.TypeOf((*MockLifetimeManager)(nil).PropertiesV2), varargs...) +} + +// Resume mocks base method. 
+func (m *MockLifetimeManager) Resume(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Resume", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Resume indicates an expected call of Resume. +func (mr *MockLifetimeManagerMockRecorder) Resume(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resume", reflect.TypeOf((*MockLifetimeManager)(nil).Resume), ctx) +} + +// RuntimeID mocks base method. +func (m *MockLifetimeManager) RuntimeID() guid.GUID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RuntimeID") + ret0, _ := ret[0].(guid.GUID) + return ret0 +} + +// RuntimeID indicates an expected call of RuntimeID. +func (mr *MockLifetimeManagerMockRecorder) RuntimeID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuntimeID", reflect.TypeOf((*MockLifetimeManager)(nil).RuntimeID)) +} + +// Save mocks base method. +func (m *MockLifetimeManager) Save(ctx context.Context, options hcsschema.SaveOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Save", ctx, options) + ret0, _ := ret[0].(error) + return ret0 +} + +// Save indicates an expected call of Save. +func (mr *MockLifetimeManagerMockRecorder) Save(ctx, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockLifetimeManager)(nil).Save), ctx, options) +} + +// Start mocks base method. +func (m *MockLifetimeManager) Start(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockLifetimeManagerMockRecorder) Start(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockLifetimeManager)(nil).Start), ctx) +} + +// StartedTime mocks base method. 
+func (m *MockLifetimeManager) StartedTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartedTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// StartedTime indicates an expected call of StartedTime. +func (mr *MockLifetimeManagerMockRecorder) StartedTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedTime", reflect.TypeOf((*MockLifetimeManager)(nil).StartedTime)) +} + +// StoppedTime mocks base method. +func (m *MockLifetimeManager) StoppedTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StoppedTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// StoppedTime indicates an expected call of StoppedTime. +func (mr *MockLifetimeManagerMockRecorder) StoppedTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoppedTime", reflect.TypeOf((*MockLifetimeManager)(nil).StoppedTime)) +} + +// Terminate mocks base method. +func (m *MockLifetimeManager) Terminate(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Terminate", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Terminate indicates an expected call of Terminate. +func (mr *MockLifetimeManagerMockRecorder) Terminate(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Terminate", reflect.TypeOf((*MockLifetimeManager)(nil).Terminate), ctx) +} + +// Wait mocks base method. +func (m *MockLifetimeManager) Wait(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. 
+func (mr *MockLifetimeManagerMockRecorder) Wait(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockLifetimeManager)(nil).Wait), ctx) +} diff --git a/internal/vm/vmutils/constants.go b/internal/vm/vmutils/constants.go index a332fb6a99..6276a17595 100644 --- a/internal/vm/vmutils/constants.go +++ b/internal/vm/vmutils/constants.go @@ -10,6 +10,9 @@ import ( ) const ( + // LCOWShimName is the name of the LCOW shim implementation. + LCOWShimName = "containerd-shim-lcow-v2" + // MaxVPMEMCount is the maximum number of VPMem devices that may be added to an LCOW // utility VM. MaxVPMEMCount = 128 diff --git a/internal/vm/vmutils/doc.go b/internal/vm/vmutils/doc.go index e78e4a5809..31ffb541ca 100644 --- a/internal/vm/vmutils/doc.go +++ b/internal/vm/vmutils/doc.go @@ -7,6 +7,6 @@ // (internal/controller). Functions in this package are designed to be decoupled from // specific UVM implementations. // -// This allows different shims (containerd-shim-runhcs-v1, containerd-shim-lcow-v1) +// This allows different shims (containerd-shim-runhcs-v1, containerd-shim-lcow-v2) // to share common logic while maintaining their own orchestration patterns. package vmutils diff --git a/internal/vm/vmutils/utils.go b/internal/vm/vmutils/utils.go index cd710a6bc3..f609f975ff 100644 --- a/internal/vm/vmutils/utils.go +++ b/internal/vm/vmutils/utils.go @@ -9,7 +9,12 @@ import ( "os" "path/filepath" + runhcsoptions "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" "github.com/Microsoft/hcsshim/internal/log" + + "github.com/containerd/typeurl/v2" + "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/anypb" ) // ParseUVMReferenceInfo reads the UVM reference info file, and base64 encodes the content if it exists. 
@@ -30,3 +35,29 @@ func ParseUVMReferenceInfo(ctx context.Context, referenceRoot, referenceName str return base64.StdEncoding.EncodeToString(content), nil } + +// UnmarshalRuntimeOptions decodes the runtime options into runhcsoptions.Options. +// When no options are provided (options == nil) it returns a non-nil, +// zero-value Options struct. +func UnmarshalRuntimeOptions(ctx context.Context, options *anypb.Any) (*runhcsoptions.Options, error) { + opts := &runhcsoptions.Options{} + if options == nil { + return opts, nil + } + + v, err := typeurl.UnmarshalAny(options) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal options: %w", err) + } + + shimOpts, ok := v.(*runhcsoptions.Options) + if !ok { + return nil, fmt.Errorf("failed to unmarshal runtime options: expected *runhcsoptions.Options, got %T", v) + } + + if entry := log.G(ctx); entry.Logger.IsLevelEnabled(logrus.DebugLevel) { + entry.WithField("options", log.Format(ctx, shimOpts)).Debug("parsed runtime options") + } + + return shimOpts, nil +}