diff --git a/components/ide/jetbrains/launcher/main.go b/components/ide/jetbrains/launcher/main.go index da16ea97d8bae0..c39fd696ea5aad 100644 --- a/components/ide/jetbrains/launcher/main.go +++ b/components/ide/jetbrains/launcher/main.go @@ -161,22 +161,7 @@ func main() { } if launchCtx.warmup { - ctx, cancel := context.WithCancel(context.Background()) - - out: - for { - log.Debug("launcher: wait for tasks to finish before running warmup") - err := waitForTasksToFinish(ctx, cancel) - if err != nil { - log.WithError(err).Warn("launcher: failed to observe tasks completion") - } - select { - case <-ctx.Done(): - break out - case <-time.After(1 * time.Second): - } - } - + starWarmup() launch(launchCtx) return } @@ -187,6 +172,31 @@ func main() { serve(launchCtx) } +func starWarmup() { + ctx, cancel := context.WithCancel(context.Background()) + var conn *grpc.ClientConn + var err error + + for { + conn, err = dial(ctx) + if err == nil { + log.Debug("launcher: wait for tasks to finish before running warmup") + finished, err := waitForTasksToFinish(ctx, conn) + if err != nil { + log.WithError(err).Warn("launcher: failed to observe tasks completion") + } + if finished { + cancel() + } + } + select { + case <-ctx.Done(): + return + case <-time.After(1 * time.Second): + } + } +} + func serve(launchCtx *LaunchContext) { debugAgentPrefix := "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:" http.HandleFunc("/debug", func(w http.ResponseWriter, r *http.Request) { @@ -870,20 +880,16 @@ func resolveProjectContextDir(launchCtx *LaunchContext) string { return launchCtx.projectDir } -func waitForTasksToFinish(ctx context.Context, cancel context.CancelFunc) error { - conn, err := dial(ctx) - if err != nil { - return err - } +func waitForTasksToFinish(ctx context.Context, conn *grpc.ClientConn) (bool, error) { client := supervisor.NewStatusServiceClient(conn) tasksResponse, err := client.TasksStatus(ctx, &supervisor.TasksStatusRequest{Observe: true}) if err != nil { - return xerrors.Errorf("failed get tasks status client: %w", err) + return false, xerrors.Errorf("failed get tasks status client: %w", err) } for { resp, err := tasksResponse.Recv() if err != nil { - return err + return false, err } var runningTasksCounter int @@ -893,12 +899,10 @@ func waitForTasksToFinish(ctx context.Context, cancel context.CancelFunc) error } } if runningTasksCounter == 0 { - cancel() + return true, nil } - - return nil } - + return false, nil } func dial(ctx context.Context) (*grpc.ClientConn, error) {