Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated cherry pick of #1586: update driver to support staging compute #1609: fix pointer issue for GCE staging support #1614

46 changes: 39 additions & 7 deletions cmd/gce-pd-csi-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"math/rand"
"net/url"
"os"
"runtime"
"strings"
Expand All @@ -38,7 +41,6 @@ import (
var (
cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config")
endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint")
computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.")
runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)")
runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)")
httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.")
Expand Down Expand Up @@ -67,12 +69,13 @@ var (

maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")

fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")

enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")

version string
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
computeEnvironment gce.Environment = gce.EnvironmentProduction
computeEndpoint *url.URL
version string
allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction}
)

const (
Expand All @@ -85,13 +88,16 @@ func init() {
// Use V(4) for general debug information logging
// Use V(5) for GCE Cloud Provider Call informational logging
// Use V(6) for extra repeated/polling information
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint")
klog.InitFlags(flag.CommandLine)
flag.Set("logtostderr", "true")
}

func main() {
flag.Parse()
rand.Seed(time.Now().UnixNano())
klog.Infof("Operating compute environment set to: %s and computeEndpoint is set to: %v", computeEnvironment, computeEndpoint)
handle()
os.Exit(0)
}
Expand Down Expand Up @@ -156,7 +162,7 @@ func handle() {
// Initialize requirements for the controller service
var controllerServer *driver.GCEControllerServer
if *runControllerService {
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint)
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment)
if err != nil {
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
}
Expand Down Expand Up @@ -205,3 +211,29 @@ func handle() {

gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
}

func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) {
flag.Func(name, usage, func(flagValue string) error {
for _, allowedValue := range allowedComputeEnvironment {
if gce.Environment(flagValue) == allowedValue {
*target = gce.Environment(flagValue)
return nil
}
}
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
return errors.New(errMsg)
})

}

func urlFlag(target **url.URL, name string, usage string) {
flag.Func(name, usage, func(flagValue string) error {
computeURL, err := url.ParseRequestURI(flagValue)
if err == nil {
*target = computeURL
return nil
}
klog.Infof("Error parsing endpoint compute endpoint %v", err)
return err
})
}
85 changes: 49 additions & 36 deletions pkg/gce-cloud-provider/compute/gce.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"time"
Expand All @@ -37,6 +38,9 @@ import (
"k8s.io/klog/v2"
)

type Environment string
type Version string

const (
TokenURL = "https://accounts.google.com/o/oauth2/token"
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
Expand All @@ -46,7 +50,12 @@ const (

regionURITemplate = "projects/%s/regions/%s"

replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
versionV1 Version = "v1"
versionBeta Version = "beta"
versionAlpha Version = "alpha"
EnvironmentStaging Environment = "staging"
EnvironmentProduction Environment = "production"
)

type CloudProvider struct {
Expand All @@ -72,7 +81,7 @@ type ConfigGlobal struct {
Zone string `gcfg:"zone"`
}

func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) {
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment) (*CloudProvider, error) {
configFile, err := readConfig(configPath)
if err != nil {
return nil, err
Expand All @@ -87,20 +96,23 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
return nil, err
}

svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
if err != nil {
return nil, err
}
klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath)

betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
if err != nil {
return nil, err
}
klog.Infof("Compute endpoint for Beta version: %s", betasvc.BasePath)

alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
if err != nil {
return nil, err
}
klog.Infof("Compute endpoint for Alpha version: %s", alphasvc.BasePath)

project, zone, err := getProjectAndZone(configFile)
if err != nil {
Expand Down Expand Up @@ -164,16 +176,23 @@ func readConfig(configPath string) (*ConfigFile, error) {
return cfg, nil
}

func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) {
client, err := newOauthClient(ctx, tokenSource)
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computealpha.Service, error) {
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
if err != nil {
klog.Errorf("Failed to get compute endpoint: %s", err)
}
service, err := computealpha.NewService(ctx, computeOpts...)
if err != nil {
return nil, err
}
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
return service, nil
}

computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
if computeEndpoint != "" {
betaEndpoint := fmt.Sprintf("%s/compute/beta/", computeEndpoint)
computeOpts = append(computeOpts, option.WithEndpoint(betaEndpoint))
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computebeta.Service, error) {
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
if err != nil {
klog.Errorf("Failed to get compute endpoint: %s", err)
}
service, err := computebeta.NewService(ctx, computeOpts...)
if err != nil {
Expand All @@ -183,47 +202,41 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
return service, nil
}

func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computealpha.Service, error) {
client, err := newOauthClient(ctx, tokenSource)
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*compute.Service, error) {
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
if err != nil {
return nil, err
}

computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
if computeEndpoint != "" {
alphaEndpoint := fmt.Sprintf("%s/compute/alpha/", computeEndpoint)
computeOpts = append(computeOpts, option.WithEndpoint(alphaEndpoint))
klog.Errorf("Failed to get compute endpoint: %s", err)
}
service, err := computealpha.NewService(ctx, computeOpts...)
service, err := compute.NewService(ctx, computeOpts...)
if err != nil {
return nil, err
}
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
return service, nil
}

func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint)
return svc, err
}

func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
client, err := newOauthClient(ctx, tokenSource)
if err != nil {
return nil, err
}

computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
if computeEndpoint != "" {
v1Endpoint := fmt.Sprintf("%s/compute/v1/", computeEndpoint)
computeOpts = append(computeOpts, option.WithEndpoint(v1Endpoint))

if computeEndpoint != nil {
computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion)
computeEndpoint.Path = computeEnvironmentSuffix
endpoint := computeEndpoint.String()
computeOpts = append(computeOpts, option.WithEndpoint(endpoint))
}
service, err := compute.NewService(ctx, computeOpts...)
if err != nil {
return nil, err
return computeOpts, nil
}

func constructComputeEndpointPath(env Environment, version Version) string {
prefix := ""
if env == EnvironmentStaging {
prefix = fmt.Sprintf("%s_", env)
}
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
return service, nil
return fmt.Sprintf("compute/%s%s/", prefix, version)
}

func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
Expand Down
74 changes: 74 additions & 0 deletions pkg/gce-cloud-provider/compute/gce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,30 @@ limitations under the License.
package gcecloudprovider

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"testing"
"time"

"golang.org/x/oauth2"

"google.golang.org/api/compute/v1"
"google.golang.org/api/googleapi"
)

type mockTokenSource struct{}

func (*mockTokenSource) Token() (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "access",
TokenType: "Bearer",
RefreshToken: "refresh",
Expiry: time.Now().Add(1 * time.Hour),
}, nil
}
func TestIsGCEError(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -84,3 +100,61 @@ func TestIsGCEError(t *testing.T) {
}
}
}

func TestGetComputeVersion(t *testing.T) {
testCases := []struct {
name string
computeEndpoint *url.URL
computeEnvironment Environment
computeVersion Version
expectedEndpoint string
expectError bool
}{

{
name: "check for production environment",
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
computeEnvironment: EnvironmentProduction,
computeVersion: versionBeta,
expectedEndpoint: "https://compute.googleapis.com/compute/beta/",
expectError: false,
},
{
name: "check for staging environment",
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
computeEnvironment: EnvironmentStaging,
computeVersion: versionV1,
expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/",
expectError: false,
},
{
name: "check for random string as endpoint",
computeEndpoint: convertStringToURL(""),
computeEnvironment: "prod",
computeVersion: "v1",
expectedEndpoint: "compute/v1/",
expectError: true,
},
}
for _, tc := range testCases {
ctx := context.Background()
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
service, _ := compute.NewService(ctx, computeOpts...)
gotEndpoint := service.BasePath
if err != nil && !tc.expectError {
t.Fatalf("Got error %v", err)
}
if gotEndpoint != tc.expectedEndpoint && !tc.expectError {
t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint)
}
}

}

func convertStringToURL(urlString string) *url.URL {
parsedURL, err := url.ParseRequestURI(urlString)
if err != nil {
return nil
}
return parsedURL
}
2 changes: 1 addition & 1 deletion pkg/gce-pd-csi-driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ const (
)

var (
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true}
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true, "staging_v1": true, "staging_beta": true, "staging_alpha": true}
)

func isDiskReady(disk *gce.CloudDisk) (bool, error) {
Expand Down
11 changes: 1 addition & 10 deletions test/e2e/tests/single_zone_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ var _ = Describe("GCE PD CSI Driver", func() {
}()
})

It("Should pass/fail if valid/invalid compute endpoint is passed in", func() {
It("Should pass if valid compute endpoint is passed in", func() {
// gets instance set up w/o compute-endpoint set from test setup
_, err := getRandomTestContext().Client.ListVolumes()
Expect(err).To(BeNil(), "no error expected when passed valid compute url")
Expand All @@ -1295,15 +1295,6 @@ var _ = Describe("GCE PD CSI Driver", func() {

klog.Infof("Creating new driver and client for node %s\n", i.GetName())

// Create new driver and client w/ invalid endpoint
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
if err != nil {
klog.Fatalf("Failed to set up Test Context for instance %v: %w", i.GetName(), err)
}

_, err = tcInvalid.Client.ListVolumes()
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")

// Create new driver and client w/ valid, passed-in endpoint
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion test/e2e/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, computeEndpoint stri
// useful to see what's happening when debugging tests.
driverRunCmd := fmt.Sprintf("sh -c '/usr/bin/nohup %s/gce-pd-csi-driver -v=6 --endpoint=%s %s 2> %s/prog.out < /dev/null > /dev/null &'",
workspace, endpoint, strings.Join(extra_flags, " "), workspace)

config := &remote.ClientConfig{
PkgPath: pkgPath,
BinPath: binPath,
Expand Down