diff --git a/test/e2e/sparkapplication_test.go b/test/e2e/sparkapplication_test.go index a3e8829a0..1f64049a8 100644 --- a/test/e2e/sparkapplication_test.go +++ b/test/e2e/sparkapplication_test.go @@ -21,7 +21,6 @@ import ( "os" "path/filepath" "strings" - "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -29,18 +28,12 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/yaml" "github.com/kubeflow/spark-operator/api/v1beta2" "github.com/kubeflow/spark-operator/pkg/util" ) -const ( - PollInterval = 1 * time.Second - WaitTimeout = 300 * time.Second -) - var _ = Describe("Example SparkApplication", func() { Context("spark-pi", func() { ctx := context.Background() @@ -72,15 +65,7 @@ var _ = Describe("Example SparkApplication", func() { It("should complete successfully", func() { By("Waiting for SparkApplication to complete") key := types.NamespacedName{Namespace: app.Namespace, Name: app.Name} - cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) - defer cancelFunc() - Expect(wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (done bool, err error) { - err = k8sClient.Get(ctx, key, app) - if app.Status.AppState.State == v1beta2.ApplicationStateCompleted { - return true, nil - } - return false, err - })).NotTo(HaveOccurred()) + Expect(waitForSparkApplicationCompleted(ctx, key)).NotTo(HaveOccurred()) By("Checking out driver logs") driverPodName := util.GetDriverPodName(app) @@ -148,15 +133,7 @@ var _ = Describe("Example SparkApplication", func() { It("Should complete successfully", func() { By("Waiting for SparkApplication to complete") key := types.NamespacedName{Namespace: app.Namespace, Name: app.Name} - cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) - defer cancelFunc() - Expect(wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (done bool, err error) { - err = k8sClient.Get(ctx, key, app) - if app.Status.AppState.State == v1beta2.ApplicationStateCompleted { - return true, nil - } - return false, err - })).NotTo(HaveOccurred()) + Expect(waitForSparkApplicationCompleted(ctx, key)).NotTo(HaveOccurred()) By("Checking out driver logs") driverPodName := util.GetDriverPodName(app) @@ -197,15 +174,7 @@ var _ = Describe("Example SparkApplication", func() { It("Should complete successfully", func() { By("Waiting for SparkApplication to complete") key := types.NamespacedName{Namespace: app.Namespace, Name: app.Name} - cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) - defer cancelFunc() - Expect(wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (done bool, err error) { - err = k8sClient.Get(ctx, key, app) - if app.Status.AppState.State == v1beta2.ApplicationStateCompleted { - return true, nil - } - return false, err - })).NotTo(HaveOccurred()) + Expect(waitForSparkApplicationCompleted(ctx, key)).NotTo(HaveOccurred()) By("Checking out driver logs") driverPodName := util.GetDriverPodName(app) @@ -246,15 +215,7 @@ var _ = Describe("Example SparkApplication", func() { It("Should complete successfully", func() { By("Waiting for SparkApplication to complete") key := types.NamespacedName{Namespace: app.Namespace, Name: app.Name} - cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) - defer cancelFunc() - Expect(wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (done bool, err error) { - err = k8sClient.Get(ctx, key, app) - if app.Status.AppState.State == v1beta2.ApplicationStateCompleted { - return true, nil - } - return false, err - })).NotTo(HaveOccurred()) + Expect(waitForSparkApplicationCompleted(ctx, key)).NotTo(HaveOccurred()) By("Checking out driver logs") driverPodName := util.GetDriverPodName(app) diff --git a/test/e2e/suit_test.go b/test/e2e/suit_test.go index 5c8a21dd4..e409a1cbd 100644 --- a/test/e2e/suit_test.go +++ b/test/e2e/suit_test.go @@ -31,8 +31,11 @@ import ( "helm.sh/helm/v3/pkg/chart/loader" "helm.sh/helm/v3/pkg/chartutil" "helm.sh/helm/v3/pkg/cli" + admissionregistrationv1 "k8s.io/api/admissionregistration/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" @@ -53,6 +56,12 @@ import ( const ( ReleaseName = "spark-operator" ReleaseNamespace = "spark-operator" + + MutatingWebhookName = "spark-operator-webhook" + ValidatingWebhookName = "spark-operator-webhook" + + PollInterval = 1 * time.Second + WaitTimeout = 5 * time.Minute ) var ( @@ -123,7 +132,7 @@ var _ = BeforeSuite(func() { installAction.ReleaseName = ReleaseName installAction.Namespace = envSettings.Namespace() installAction.Wait = true - installAction.Timeout = 5 * time.Minute + installAction.Timeout = WaitTimeout chartPath := filepath.Join("..", "..", "charts", "spark-operator-chart") chart, err := loader.Load(chartPath) Expect(err).NotTo(HaveOccurred()) @@ -134,6 +143,12 @@ var _ = BeforeSuite(func() { release, err := installAction.Run(chart, values) Expect(err).NotTo(HaveOccurred()) Expect(release).NotTo(BeNil()) + + By("Waiting for the webhooks to be ready") + mutatingWebhookKey := types.NamespacedName{Name: MutatingWebhookName} + validatingWebhookKey := types.NamespacedName{Name: ValidatingWebhookName} + Expect(waitForMutatingWebhookReady(context.Background(), mutatingWebhookKey)).NotTo(HaveOccurred()) + Expect(waitForValidatingWebhookReady(context.Background(), validatingWebhookKey)).NotTo(HaveOccurred()) }) var _ = AfterSuite(func() { @@ -147,7 +162,7 @@ var _ = AfterSuite(func() { uninstallAction := action.NewUninstall(actionConfig) Expect(uninstallAction).NotTo(BeNil()) uninstallAction.Wait = true - uninstallAction.Timeout = 5 * time.Minute + uninstallAction.Timeout = WaitTimeout resp, err := uninstallAction.Run(ReleaseName) Expect(err).To(BeNil()) Expect(resp).NotTo(BeNil()) @@ -160,3 +175,95 @@ var _ = AfterSuite(func() { err = testEnv.Stop() Expect(err).ToNot(HaveOccurred()) }) + +func waitForMutatingWebhookReady(ctx context.Context, key types.NamespacedName) error { + cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) + defer cancelFunc() + + mutatingWebhook := admissionregistrationv1.MutatingWebhookConfiguration{} + err := wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (bool, error) { + if err := k8sClient.Get(ctx, key, &mutatingWebhook); err != nil { + return false, err + } + + for _, wh := range mutatingWebhook.Webhooks { + // Checkout webhook CA certificate + if wh.ClientConfig.CABundle == nil { + return false, nil + } + + // Checkout webhook service endpoints + svcRef := wh.ClientConfig.Service + if svcRef == nil { + return false, fmt.Errorf("webhook service is nil") + } + endpoints := corev1.Endpoints{} + endpointsKey := types.NamespacedName{Namespace: svcRef.Namespace, Name: svcRef.Name} + if err := k8sClient.Get(ctx, endpointsKey, &endpoints); err != nil { + return false, err + } + if len(endpoints.Subsets) == 0 { + return false, nil + } + } + + return true, nil + }) + return err +} + +func waitForValidatingWebhookReady(ctx context.Context, key types.NamespacedName) error { + cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) + defer cancelFunc() + + validatingWebhook := admissionregistrationv1.ValidatingWebhookConfiguration{} + err := wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (bool, error) { + if err := k8sClient.Get(ctx, key, &validatingWebhook); err != nil { + return false, err + } + + for _, wh := range validatingWebhook.Webhooks { + // Checkout webhook CA certificate + if wh.ClientConfig.CABundle == nil { + return false, nil + } + + // Checkout webhook service endpoints + svcRef := wh.ClientConfig.Service + if svcRef == nil { + return false, fmt.Errorf("webhook service is nil") + } + endpoints := corev1.Endpoints{} + endpointsKey := types.NamespacedName{Namespace: svcRef.Namespace, Name: svcRef.Name} + if err := k8sClient.Get(ctx, endpointsKey, &endpoints); err != nil { + return false, err + } + if len(endpoints.Subsets) == 0 { + return false, nil + } + } + + return true, nil + }) + return err +} + +func waitForSparkApplicationCompleted(ctx context.Context, key types.NamespacedName) error { + cancelCtx, cancelFunc := context.WithTimeout(ctx, WaitTimeout) + defer cancelFunc() + + app := &v1beta2.SparkApplication{} + err := wait.PollUntilContextCancel(cancelCtx, PollInterval, true, func(ctx context.Context) (bool, error) { + if err := k8sClient.Get(ctx, key, app); err != nil { + return false, err + } + switch app.Status.AppState.State { + case v1beta2.ApplicationStateFailedSubmission, v1beta2.ApplicationStateFailed: + return false, fmt.Errorf(app.Status.AppState.ErrorMessage) + case v1beta2.ApplicationStateCompleted: + return true, nil + } + return false, nil + }) + return err +}