Skip to content

Commit

Permalink
Refactor GCEClient: wrap compute.Service in an interface for mocking …
Browse files Browse the repository at this point in the history
…GCP compute

This change creates a ComputeService implementation which has a runtime
implementation that wraps the compute.Service. The MachineActuator is
changed to make use of the ComputeService through a new interface named
GCEClientComputeService. This will enable creating tests that mock GCP Compute
Service calls to control MachineActuator behavior.
  • Loading branch information
spew committed Apr 27, 2018
1 parent 70c282d commit 061a664
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 14 deletions.
66 changes: 66 additions & 0 deletions cloud/google/clients/computeservice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package clients

import (
compute "google.golang.org/api/compute/v1"
"net/http"
"net/url"
)

// ComputeService is a pass through wrapper for google.golang.org/api/compute/v1/compute
// The purpose of the ComputeService's wrap of the GCE client is to enable tests to mock this struct and control behavior.
type ComputeService struct {
service *compute.Service
}

func NewComputeService(client *http.Client) (*ComputeService, error) {
service, err := compute.New(client)
if err != nil {
return nil, err
}
return &ComputeService{
service: service,
}, nil
}

func NewComputeServiceForURL(client *http.Client, baseURL string) (*ComputeService, error) {
computeService, err := NewComputeService(client)
if err != nil {
return nil, err
}
url, err := url.Parse(computeService.service.BasePath)
if err != nil {
return nil, err
}
computeService.service.BasePath = baseURL + url.Path
return computeService, err
}

// A pass through wrapper for compute.Service.Images.Get(...)
func (c *ComputeService) ImagesGet(project string, image string) (*compute.Image, error) {
return c.service.Images.Get(project, image).Do()
}

// A pass through wrapper for compute.Service.Images.GetFromFamily(...)
func (c *ComputeService) ImagesGetFromFamily(project string, family string) (*compute.Image, error) {
return c.service.Images.GetFromFamily(project, family).Do()
}

// A pass through wrapper for compute.Service.Instances.Delete(...)
func (c *ComputeService) InstancesDelete(project string, zone string, targetInstance string) (*compute.Operation, error) {
return c.service.Instances.Delete(project, zone, targetInstance).Do()
}

// A pass through wrapper for compute.Service.Instances.Get(...)
func (c *ComputeService) InstancesGet(project string, zone string, instance string) (*compute.Instance, error) {
return c.service.Instances.Get(project, zone, instance).Do()
}

// A pass through wrapper for compute.Service.Instances.Insert(...)
func (c *ComputeService) InstancesInsert(project string, zone string, instance *compute.Instance) (*compute.Operation, error) {
return c.service.Instances.Insert(project, zone, instance).Do()
}

// A pass through wrapper for compute.Service.ZoneOperations.Get(...)
func (c *ComputeService) ZoneOperationsGet(project string, zone string, operation string) (*compute.Operation, error) {
return c.service.ZoneOperations.Get(project, zone, operation).Do()
}
156 changes: 156 additions & 0 deletions cloud/google/clients/computeservice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package clients_test

import (
compute "google.golang.org/api/compute/v1"
"encoding/json"
"google.golang.org/api/googleapi"
"net/http"
"net/http/httptest"
"sigs.k8s.io/cluster-api/cloud/google/clients"
"testing"
)

func TestImagesGet(t *testing.T) {
mux, server, client := createMuxServerAndClient()
defer server.Close()
responseImage := compute.Image{
Name: "imageName",
ArchiveSizeBytes: 544,
}
mux.Handle("/compute/v1/projects/projectName/global/images/imageName", handler(nil, &responseImage))
image, err := client.ImagesGet("projectName", "imageName")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if image == nil {
t.Error("expected a valid image")
}
if "imageName" != image.Name {
t.Errorf("expected imageName got %v", image.Name)
}
if image.ArchiveSizeBytes != int64(544) {
t.Errorf("expected %v got %v", image.ArchiveSizeBytes, 544)
}
}

func TestImagesGetFromFamily(t *testing.T) {
mux, server, client := createMuxServerAndClient()
defer server.Close()
responseImage := compute.Image{
Name: "imageName",
ArchiveSizeBytes: 544,
}
mux.Handle("/compute/v1/projects/projectName/global/images/family/familyName", handler(nil, &responseImage))
image, err := client.ImagesGetFromFamily("projectName", "familyName")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if image == nil {
t.Error("expected a valid image")
}
if "imageName" != image.Name {
t.Errorf("expected imageName got %v", image.Name)
}
if image.ArchiveSizeBytes != int64(544) {
t.Errorf("expected %v got %v", image.ArchiveSizeBytes, 544)
}
}

func TestInstancesDelete(t *testing.T) {
mux, server, client := createMuxServerAndClient()
defer server.Close()
responseOperation := compute.Operation{
Id: 4501,
}
mux.Handle("/compute/v1/projects/projectName/zones/zoneName/instances/instanceName", handler(nil, &responseOperation))
op, err := client.InstancesDelete("projectName", "zoneName", "instanceName")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if op == nil {
t.Error("expected a valid operation")
}
if responseOperation.Id != uint64(4501) {
t.Errorf("expected %v got %v", responseOperation.Id, 4501)
}
}

func TestInstancesGet(t *testing.T) {
mux, server, client := createMuxServerAndClient()
defer server.Close()
responseInstance := compute.Instance{
Name: "instanceName",
Zone: "zoneName",
}
mux.Handle("/compute/v1/projects/projectName/zones/zoneName/instances/instanceName", handler(nil, &responseInstance))
instance, err := client.InstancesGet("projectName", "zoneName", "instanceName")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if instance == nil {
t.Error("expected a valid instance")
}
if "instanceName" != instance.Name {
t.Errorf("expected instanceName got %v", instance.Name)
}
if "zoneName" != instance.Zone {
t.Errorf("expected zoneName got %v", instance.Zone)
}
}

func TestInstancesInsert(t *testing.T) {
mux, server, client := createMuxServerAndClient()
defer server.Close()
responseOperation := compute.Operation{
Id: 3001,
}
mux.Handle("/compute/v1/projects/projectName/zones/zoneName/instances", handler(nil, &responseOperation))
op, err := client.InstancesInsert("projectName", "zoneName", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if op == nil {
t.Error("expected a valid operation")
}
if responseOperation.Id != uint64(3001) {
t.Errorf("expected %v got %v", responseOperation.Id, 3001)
}
}

func createMuxServerAndClient() (*http.ServeMux, *httptest.Server, *clients.ComputeService) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
client, _ := clients.NewComputeServiceForURL(server.Client(), server.URL)
return mux, server, client
}

func handler(err *googleapi.Error, obj interface{}) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
handleTestRequest(w, err, obj)
}
}

func handleTestRequest(w http.ResponseWriter, handleErr *googleapi.Error, obj interface{}) {
if handleErr != nil {
http.Error(w, errMsg(handleErr), handleErr.Code)
return
}
res, err := json.Marshal(obj)
if err != nil {
http.Error(w, "json marshal error", http.StatusInternalServerError)
return
}
w.Write(res)
}

func errMsg(e *googleapi.Error) string {
res, err := json.Marshal(&errorReply{e})
if err != nil {
return "json marshal error"
}
return string(res)
}

type errorReply struct {
Error *googleapi.Error `json:"error"`
}
38 changes: 24 additions & 14 deletions cloud/google/machineactuator.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
apierrors "sigs.k8s.io/cluster-api/errors"
clusterv1 "sigs.k8s.io/cluster-api/pkg/apis/cluster/v1alpha1"
client "sigs.k8s.io/cluster-api/pkg/client/clientset_generated/clientset/typed/cluster/v1alpha1"
"sigs.k8s.io/cluster-api/cloud/google/clients"
"sigs.k8s.io/cluster-api/util"
)

Expand All @@ -66,8 +67,17 @@ type SshCreds struct {
privateKeyPath string
}

type GCEClientComputeService interface {
ImagesGet(project string, image string) (*compute.Image, error)
ImagesGetFromFamily(project string, family string) (*compute.Image, error)
InstancesDelete(project string, zone string, targetInstance string) (*compute.Operation, error)
InstancesGet(project string, zone string, instance string) (*compute.Instance, error)
InstancesInsert(project string, zone string, instance *compute.Instance) (*compute.Operation, error)
ZoneOperationsGet(project string, zone string, operation string) (*compute.Operation, error)
}

type GCEClient struct {
service *compute.Service
computeService GCEClientComputeService
scheme *runtime.Scheme
codecFactory *serializer.CodecFactory
kubeadmToken string
Expand All @@ -89,7 +99,7 @@ func NewMachineActuator(kubeadmToken string, machineClient client.MachineInterfa
return nil, err
}

service, err := compute.New(client)
computeService, err := clients.NewComputeService(client)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -122,10 +132,10 @@ func NewMachineActuator(kubeadmToken string, machineClient client.MachineInterfa
}

return &GCEClient{
service: service,
scheme: scheme,
codecFactory: codecFactory,
kubeadmToken: kubeadmToken,
computeService: computeService,
scheme: scheme,
codecFactory: codecFactory,
kubeadmToken: kubeadmToken,
sshCreds: SshCreds{
privateKeyPath: privateKeyPath,
user: user,
Expand Down Expand Up @@ -257,7 +267,7 @@ func (gce *GCEClient) Create(cluster *clusterv1.Cluster, machine *clusterv1.Mach
labels[BootstrapLabelKey] = "true"
}

op, err := gce.service.Instances.Insert(project, zone, &compute.Instance{
op, err := gce.computeService.InstancesInsert(project, zone, &compute.Instance{
Name: name,
MachineType: fmt.Sprintf("zones/%s/machineTypes/%s", zone, config.MachineType),
NetworkInterfaces: []*compute.NetworkInterface{
Expand Down Expand Up @@ -288,7 +298,7 @@ func (gce *GCEClient) Create(cluster *clusterv1.Cluster, machine *clusterv1.Mach
Items: []string{"https-server"},
},
Labels: labels,
}).Do()
})

if err == nil {
err = gce.waitForOperation(config, op)
Expand Down Expand Up @@ -347,7 +357,7 @@ func (gce *GCEClient) Delete(machine *clusterv1.Machine) error {
name = machine.ObjectMeta.Name
}

op, err := gce.service.Instances.Delete(project, zone, name).Do()
op, err := gce.computeService.InstancesDelete(project, zone, name)
if err == nil {
err = gce.waitForOperation(config, op)
}
Expand Down Expand Up @@ -442,7 +452,7 @@ func (gce *GCEClient) GetIP(machine *clusterv1.Machine) (string, error) {
return "", err
}

instance, err := gce.service.Instances.Get(config.Project, config.Zone, machine.ObjectMeta.Name).Do()
instance, err := gce.computeService.InstancesGet(config.Project, config.Zone, machine.ObjectMeta.Name)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -532,7 +542,7 @@ func (gce *GCEClient) instanceIfExists(machine *clusterv1.Machine) (*compute.Ins
return nil, err
}

instance, err := gce.service.Instances.Get(config.Project, config.Zone, identifyingMachine.ObjectMeta.Name).Do()
instance, err := gce.computeService.InstancesGet(config.Project, config.Zone, identifyingMachine.ObjectMeta.Name)
if err != nil {
// TODO: Use formal way to check for error code 404
if strings.Contains(err.Error(), "Error 404") {
Expand Down Expand Up @@ -583,7 +593,7 @@ func (gce *GCEClient) waitForOperation(c *gceconfig.GCEProviderConfig, op *compu

// getOp returns an updated operation.
func (gce *GCEClient) getOp(c *gceconfig.GCEProviderConfig, op *compute.Operation) (*compute.Operation, error) {
return gce.service.ZoneOperations.Get(c.Project, path.Base(op.Zone), op.Name).Do()
return gce.computeService.ZoneOperationsGet(c.Project, path.Base(op.Zone), op.Name)
}

func (gce *GCEClient) checkOp(op *compute.Operation, err error) error {
Expand Down Expand Up @@ -684,9 +694,9 @@ func (gce *GCEClient) getImagePath(img string) (imagePath string) {
project, family, name := matches[1], matches[2], matches[3]
var err error
if family == "" {
_, err = gce.service.Images.Get(project, name).Do()
_, err = gce.computeService.ImagesGet(project, name)
} else {
_, err = gce.service.Images.GetFromFamily(project, name).Do()
_, err = gce.computeService.ImagesGetFromFamily(project, name)
}

if err == nil {
Expand Down

0 comments on commit 061a664

Please sign in to comment.