diff --git a/extension/src/experiments/workspace.test.ts b/extension/src/experiments/workspace.test.ts index c2365b2f46..78119ac07d 100644 --- a/extension/src/experiments/workspace.test.ts +++ b/extension/src/experiments/workspace.test.ts @@ -14,6 +14,7 @@ import { OutputChannel } from '../vscode/outputChannel' import { Title } from '../vscode/title' import { Args } from '../cli/dvc/constants' import { findOrCreateDvcYamlFile, getFileExtension } from '../fileSystem' +import { Toast } from '../vscode/toast' const mockedShowWebview = jest.fn() const mockedDisposable = jest.mocked(Disposable) @@ -595,5 +596,38 @@ describe('Experiments', () => { mockedDvcRoot ) }) + + it('should show a toast if the dvc.yaml file is invalid', async () => { + const showErrorSpy = jest.spyOn(Toast, 'showError') + + mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) + mockedListStages.mockResolvedValueOnce(undefined) + + await workspaceExperiments.getCwdThenRun(mockedCommandId) + + expect(showErrorSpy).toHaveBeenCalledWith( + 'Cannot perform task. Your dvc.yaml file is invalid.' + ) + }) + + it('should not ask to create a stage if the dvc.yaml file is invalid', async () => { + mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) + mockedListStages.mockResolvedValueOnce(undefined) + + await workspaceExperiments.getCwdThenRun(mockedCommandId) + + expect(mockedGetValidInput).not.toHaveBeenCalled() + }) + + it('should not show a toast if the dvc.yaml file is valid', async () => { + const showErrorSpy = jest.spyOn(Toast, 'showError') + + mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) + mockedListStages.mockResolvedValueOnce('train') + + await workspaceExperiments.getCwdThenRun(mockedCommandId) + + expect(showErrorSpy).not.toHaveBeenCalled() + }) }) }) diff --git a/extension/src/experiments/workspace.ts b/extension/src/experiments/workspace.ts index d434e2e610..2b1d336664 100644 --- a/extension/src/experiments/workspace.ts +++ b/extension/src/experiments/workspace.ts @@ -444,18 +444,30 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< cwd ) + if (stages === undefined) { + await Toast.showError( + 'Cannot perform task. Your dvc.yaml file is invalid.' + ) + return false + } + if (!stages) { - const stageName = await this.askForStageName() - if (!stageName) { - return false - } + return this.addPipeline(cwd) + } + return true + } - const { trainingScript, command } = await this.askForTrainingScript() - if (!trainingScript) { - return false - } - void findOrCreateDvcYamlFile(cwd, trainingScript, stageName, command) + private async addPipeline(cwd: string) { + const stageName = await this.askForStageName() + if (!stageName) { + return false + } + + const { trainingScript, command } = await this.askForTrainingScript() + if (!trainingScript) { + return false } + void findOrCreateDvcYamlFile(cwd, trainingScript, stageName, command) return true }