diff --git a/vsphere/resource_vsphere_virtual_machine.go b/vsphere/resource_vsphere_virtual_machine.go index d6b44dcfa..fea0d4501 100644 --- a/vsphere/resource_vsphere_virtual_machine.go +++ b/vsphere/resource_vsphere_virtual_machine.go @@ -244,9 +244,9 @@ func resourceVSphereVirtualMachine() *schema.Resource { Elem: &schema.Schema{Type: schema.TypeString}, }, "shared_pci_device_id": { - Type: schema.TypeString, + Type: schema.TypeSet, Optional: true, - Description: "Id of Shared PCI passthrough device, 'grid_rtx8000-8q'", + Description: "A list of Shared PCI passthrough device, 'grid_rtx8000-8q'", Elem: &schema.Schema{Type: schema.TypeString}, }, "clone": { @@ -515,21 +515,23 @@ func resourceVSphereVirtualMachineRead(d *schema.ResourceData, meta interface{}) // Read the virtual machine PCI passthrough devices var pciDevs []string + var sharedPciDevs []string for _, dev := range vprops.Config.Hardware.Device { if pci, ok := dev.(*types.VirtualPCIPassthrough); ok { - if pciBacking, ok := pci.Backing.(*types.VirtualPCIPassthroughDeviceBackingInfo); ok { - devId := pciBacking.Id + switch t := pci.Backing.(type) { + case *types.VirtualPCIPassthroughDeviceBackingInfo: + devId := t.Id pciDevs = append(pciDevs, devId) - } else { - if pciBacking, ok := pci.Backing.(*types.VirtualPCIPassthroughVmiopBackingInfo); ok { - err = d.Set("shared_pci_device_id", pciBacking.Vgpu) - if err != nil { - return err - } - } else { - log.Printf("[WARN] Ignoring VM %q VirtualPCIPassthrough device with backing type of %T", - vm.InventoryPath, pci.Backing) - } + log.Printf("[DEBUG] Identified VM %q VirtualPCIPassthrough device %s with backing type of %T", + vm.InventoryPath, devId, pci.Backing) + case *types.VirtualPCIPassthroughVmiopBackingInfo: + dev := t.Vgpu + sharedPciDevs = append(sharedPciDevs, dev) + log.Printf("[DEBUG] Identified VM %q VirtualPCIPassthrough device %s with backing type of %T", + vm.InventoryPath, dev, pci.Backing) + default: + log.Printf("[WARN] Ignoring VM %q VirtualPCIPassthrough device with backing type of %T", + vm.InventoryPath, pci.Backing) } } } @@ -537,6 +539,10 @@ func resourceVSphereVirtualMachineRead(d *schema.ResourceData, meta interface{}) if err != nil { return err } + err = d.Set("shared_pci_device_id", sharedPciDevs) + if err != nil { + return err + } // Perform pending device read operations. devices := object.VirtualDeviceList(vprops.Config.Hardware.Device) @@ -1675,6 +1681,17 @@ func resourceVSphereVirtualMachinePostDeployChanges(d *schema.ResourceData, meta ) } cfgSpec.DeviceChange = virtualdevice.AppendDeviceChangeSpec(cfgSpec.DeviceChange, delta...) + // Shared PCI devices + devices, delta, err = virtualdevice.SharedPciPostCloneOperation(d, client, devices) + if err != nil { + return resourceVSphereVirtualMachineRollbackCreate( + d, + meta, + vm, + fmt.Errorf("error processing shared PCI device changes post-clone: %s", err), + ) + } + cfgSpec.DeviceChange = virtualdevice.AppendDeviceChangeSpec(cfgSpec.DeviceChange, delta...) log.Printf("[DEBUG] %s: Final device list: %s", resourceVSphereVirtualMachineIDString(d), virtualdevice.DeviceListString(devices)) log.Printf("[DEBUG] %s: Final device change cfgSpec: %s", resourceVSphereVirtualMachineIDString(d), virtualdevice.DeviceChangeString(cfgSpec.DeviceChange)) @@ -1978,8 +1995,8 @@ func applyVirtualDevices(d *schema.ResourceData, c *govmomi.Client, l object.Vir return nil, err } spec = virtualdevice.AppendDeviceChangeSpec(spec, delta...) - // Shared PCI passthrough device - l, delta, err = virtualdevice.SharedPciPassthroughApplyOperation(d, c, l) + // Shared PCI device + l, delta, err = virtualdevice.SharedPciApplyOperation(d, c, l) if err != nil { return nil, err }