diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index dcf94f41a..8a122ee64 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -17,8 +17,11 @@ package main import ( "context" + "errors" "flag" + "fmt" "math/rand" + "net/url" "os" "runtime" "strings" @@ -38,7 +41,6 @@ import ( var ( cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config") endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint") - computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.") runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)") runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)") httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.") @@ -67,12 +69,13 @@ var ( maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations") formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount") + fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk") - fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk") - - enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools") - - version string + enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools") + computeEnvironment gce.Environment = gce.EnvironmentProduction + computeEndpoint *url.URL + version string + allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction} ) const ( @@ -85,6 +88,8 @@ func init() { // Use V(4) for general debug information logging // Use V(5) for GCE Cloud Provider Call informational logging // Use V(6) for extra repeated/polling information + enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment") + urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint") klog.InitFlags(flag.CommandLine) flag.Set("logtostderr", "true") } @@ -92,6 +97,7 @@ func init() { func main() { flag.Parse() rand.Seed(time.Now().UnixNano()) + klog.Infof("Operating compute environment set to: %s and computeEndpoint is set to: %v", computeEnvironment, computeEndpoint) handle() os.Exit(0) } @@ -156,7 +162,7 @@ func handle() { // Initialize requirements for the controller service var controllerServer *driver.GCEControllerServer if *runControllerService { - cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint) + cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment) if err != nil { klog.Fatalf("Failed to get cloud provider: %v", err.Error()) } @@ -205,3 +211,29 @@ func handle() { gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing) } + +func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) { + flag.Func(name, usage, func(flagValue string) error { + for _, allowedValue := range allowedComputeEnvironment { + if gce.Environment(flagValue) == allowedValue { + *target = gce.Environment(flagValue) + return nil + } + } + errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment) + return errors.New(errMsg) + }) + +} + +func urlFlag(target **url.URL, name string, usage string) { + flag.Func(name, usage, func(flagValue string) error { + computeURL, err := url.ParseRequestURI(flagValue) + if err == nil { + *target = computeURL + return nil + } + klog.Infof("Error parsing endpoint compute endpoint %v", err) + return err + }) +} diff --git a/pkg/gce-cloud-provider/compute/gce.go b/pkg/gce-cloud-provider/compute/gce.go index e2a998b16..2c769cedc 100644 --- a/pkg/gce-cloud-provider/compute/gce.go +++ b/pkg/gce-cloud-provider/compute/gce.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" "runtime" "time" @@ -37,6 +38,9 @@ import ( "k8s.io/klog/v2" ) +type Environment string +type Version string + const ( TokenURL = "https://accounts.google.com/o/oauth2/token" diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}" @@ -46,7 +50,12 @@ const ( regionURITemplate = "projects/%s/regions/%s" - replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone} + replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone} + versionV1 Version = "v1" + versionBeta Version = "beta" + versionAlpha Version = "alpha" + EnvironmentStaging Environment = "staging" + EnvironmentProduction Environment = "production" ) type CloudProvider struct { @@ -72,7 +81,7 @@ type ConfigGlobal struct { Zone string `gcfg:"zone"` } -func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) { +func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment) (*CloudProvider, error) { configFile, err := readConfig(configPath) if err != nil { return nil, err @@ -87,20 +96,23 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s return nil, err } - svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) + svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment) if err != nil { return nil, err } + klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath) - betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) + betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment) if err != nil { return nil, err } + klog.Infof("Compute endpoint for Beta version: %s", betasvc.BasePath) - alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) + alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment) if err != nil { return nil, err } + klog.Infof("Compute endpoint for Alpha version: %s", alphasvc.BasePath) project, zone, err := getProjectAndZone(configFile) if err != nil { @@ -164,16 +176,23 @@ func readConfig(configPath string) (*ConfigFile, error) { return cfg, nil } -func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) { - client, err := newOauthClient(ctx, tokenSource) +func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computealpha.Service, error) { + computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha) + if err != nil { + klog.Errorf("Failed to get compute endpoint: %s", err) + } + service, err := computealpha.NewService(ctx, computeOpts...) if err != nil { return nil, err } + service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH) + return service, nil +} - computeOpts := []option.ClientOption{option.WithHTTPClient(client)} - if computeEndpoint != "" { - betaEndpoint := fmt.Sprintf("%s/compute/beta/", computeEndpoint) - computeOpts = append(computeOpts, option.WithEndpoint(betaEndpoint)) +func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computebeta.Service, error) { + computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta) + if err != nil { + klog.Errorf("Failed to get compute endpoint: %s", err) } service, err := computebeta.NewService(ctx, computeOpts...) if err != nil { @@ -183,18 +202,12 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour return service, nil } -func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computealpha.Service, error) { - client, err := newOauthClient(ctx, tokenSource) +func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*compute.Service, error) { + computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1) if err != nil { - return nil, err - } - - computeOpts := []option.ClientOption{option.WithHTTPClient(client)} - if computeEndpoint != "" { - alphaEndpoint := fmt.Sprintf("%s/compute/alpha/", computeEndpoint) - computeOpts = append(computeOpts, option.WithEndpoint(alphaEndpoint)) + klog.Errorf("Failed to get compute endpoint: %s", err) } - service, err := computealpha.NewService(ctx, computeOpts...) + service, err := compute.NewService(ctx, computeOpts...) if err != nil { return nil, err } @@ -202,28 +215,28 @@ func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSou return service, nil } -func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) { - svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint) - return svc, err -} - -func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) { +func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) { client, err := newOauthClient(ctx, tokenSource) if err != nil { return nil, err } - computeOpts := []option.ClientOption{option.WithHTTPClient(client)} - if computeEndpoint != "" { - v1Endpoint := fmt.Sprintf("%s/compute/v1/", computeEndpoint) - computeOpts = append(computeOpts, option.WithEndpoint(v1Endpoint)) + + if computeEndpoint != nil { + computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion) + computeEndpoint.Path = computeEnvironmentSuffix + endpoint := computeEndpoint.String() + computeOpts = append(computeOpts, option.WithEndpoint(endpoint)) } - service, err := compute.NewService(ctx, computeOpts...) - if err != nil { - return nil, err + return computeOpts, nil +} + +func constructComputeEndpointPath(env Environment, version Version) string { + prefix := "" + if env == EnvironmentStaging { + prefix = fmt.Sprintf("%s_", env) } - service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH) - return service, nil + return fmt.Sprintf("compute/%s%s/", prefix, version) } func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) { diff --git a/pkg/gce-cloud-provider/compute/gce_test.go b/pkg/gce-cloud-provider/compute/gce_test.go index 5bb2aed89..49f85221d 100644 --- a/pkg/gce-cloud-provider/compute/gce_test.go +++ b/pkg/gce-cloud-provider/compute/gce_test.go @@ -18,14 +18,30 @@ limitations under the License. package gcecloudprovider import ( + "context" "errors" "fmt" "net/http" + "net/url" "testing" + "time" + "golang.org/x/oauth2" + + "google.golang.org/api/compute/v1" "google.golang.org/api/googleapi" ) +type mockTokenSource struct{} + +func (*mockTokenSource) Token() (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "access", + TokenType: "Bearer", + RefreshToken: "refresh", + Expiry: time.Now().Add(1 * time.Hour), + }, nil +} func TestIsGCEError(t *testing.T) { testCases := []struct { name string @@ -84,3 +100,61 @@ func TestIsGCEError(t *testing.T) { } } } + +func TestGetComputeVersion(t *testing.T) { + testCases := []struct { + name string + computeEndpoint *url.URL + computeEnvironment Environment + computeVersion Version + expectedEndpoint string + expectError bool + }{ + + { + name: "check for production environment", + computeEndpoint: convertStringToURL("https://compute.googleapis.com"), + computeEnvironment: EnvironmentProduction, + computeVersion: versionBeta, + expectedEndpoint: "https://compute.googleapis.com/compute/beta/", + expectError: false, + }, + { + name: "check for staging environment", + computeEndpoint: convertStringToURL("https://compute.googleapis.com"), + computeEnvironment: EnvironmentStaging, + computeVersion: versionV1, + expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/", + expectError: false, + }, + { + name: "check for random string as endpoint", + computeEndpoint: convertStringToURL(""), + computeEnvironment: "prod", + computeVersion: "v1", + expectedEndpoint: "compute/v1/", + expectError: true, + }, + } + for _, tc := range testCases { + ctx := context.Background() + computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion) + service, _ := compute.NewService(ctx, computeOpts...) + gotEndpoint := service.BasePath + if err != nil && !tc.expectError { + t.Fatalf("Got error %v", err) + } + if gotEndpoint != tc.expectedEndpoint && !tc.expectError { + t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint) + } + } + +} + +func convertStringToURL(urlString string) *url.URL { + parsedURL, err := url.ParseRequestURI(urlString) + if err != nil { + return nil + } + return parsedURL +} diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index fc8aeae66..fb4745015 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -159,7 +159,7 @@ const ( ) var ( - validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true} + validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true, "staging_v1": true, "staging_beta": true, "staging_alpha": true} ) func isDiskReady(disk *gce.CloudDisk) (bool, error) { diff --git a/test/e2e/tests/single_zone_e2e_test.go b/test/e2e/tests/single_zone_e2e_test.go index 933b7551a..58911cd34 100644 --- a/test/e2e/tests/single_zone_e2e_test.go +++ b/test/e2e/tests/single_zone_e2e_test.go @@ -1280,7 +1280,7 @@ var _ = Describe("GCE PD CSI Driver", func() { }() }) - It("Should pass/fail if valid/invalid compute endpoint is passed in", func() { + It("Should pass if valid compute endpoint is passed in", func() { // gets instance set up w/o compute-endpoint set from test setup _, err := getRandomTestContext().Client.ListVolumes() Expect(err).To(BeNil(), "no error expected when passed valid compute url") @@ -1295,15 +1295,6 @@ var _ = Describe("GCE PD CSI Driver", func() { klog.Infof("Creating new driver and client for node %s\n", i.GetName()) - // Create new driver and client w/ invalid endpoint - tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string") - if err != nil { - klog.Fatalf("Failed to set up Test Context for instance %v: %w", i.GetName(), err) - } - - _, err = tcInvalid.Client.ListVolumes() - Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url") - // Create new driver and client w/ valid, passed-in endpoint tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com") if err != nil { diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 97364bed4..c9ee708ff 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -65,7 +65,6 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, computeEndpoint stri // useful to see what's happening when debugging tests. driverRunCmd := fmt.Sprintf("sh -c '/usr/bin/nohup %s/gce-pd-csi-driver -v=6 --endpoint=%s %s 2> %s/prog.out < /dev/null > /dev/null &'", workspace, endpoint, strings.Join(extra_flags, " "), workspace) - config := &remote.ClientConfig{ PkgPath: pkgPath, BinPath: binPath,