From 2da6cafabeb680a14ad473bf3960349c2cf27551 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Fri, 17 Jan 2025 12:10:08 +0100 Subject: [PATCH] Fix notebook create on Windows (#1996) * Add unit test * test more cases * fix stage path handling * update release notes --- RELEASE-NOTES.md | 1 + .../cli/_plugins/notebook/manager.py | 6 ++-- tests/notebook/test_notebook_commands.py | 31 +++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 1a57f27c67..817e1fd0fe 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -48,6 +48,7 @@ ## Fixes and improvements * Fixed inability to add patches to lowercase quoted versions * Fixes label being set to blank instead of None when not provided. +* Fixed stage path handling for notebook commands. # v3.2.2 ## Backward incompatibility diff --git a/src/snowflake/cli/_plugins/notebook/manager.py b/src/snowflake/cli/_plugins/notebook/manager.py index f0e9ed881f..0e711d96d6 100644 --- a/src/snowflake/cli/_plugins/notebook/manager.py +++ b/src/snowflake/cli/_plugins/notebook/manager.py @@ -21,6 +21,7 @@ from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.sql_execution import SqlExecutionMixin +from snowflake.cli.api.stage_path import StagePath class NotebookManager(SqlExecutionMixin): @@ -40,8 +41,9 @@ def parse_stage_as_path(notebook_file: str) -> Path: """Parses notebook file path to pathlib.Path.""" if not notebook_file.endswith(".ipynb"): raise NotebookStagePathError(notebook_file) - stage_path = Path(notebook_file) - if len(stage_path.parts) < 2: + # we don't perform any operations on the path, so we don't need to differentiate git repository paths + stage_path = StagePath.from_stage_str(notebook_file) + if len(stage_path.parts) < 1: raise NotebookStagePathError(notebook_file) return stage_path diff --git a/tests/notebook/test_notebook_commands.py b/tests/notebook/test_notebook_commands.py index 23c9b2a076..34f377a927 100644 --- a/tests/notebook/test_notebook_commands.py +++ b/tests/notebook/test_notebook_commands.py @@ -14,6 +14,7 @@ from unittest import mock +import pytest import typer from snowflake.cli._plugins.notebook.manager import NotebookManager from snowflake.cli.api.identifiers import FQN @@ -65,3 +66,33 @@ def test_create(mock_create, runner): notebook_name=FQN.from_string("my_notebook"), notebook_file=notebook_file, ) + + +@pytest.mark.parametrize( + "stage_path", + ["@db.schema.stage", "@stage/dir/subdir", "@git_repo_stage/branch/main"], +) +@mock.patch("snowflake.connector.connect") +@mock.patch("snowflake.cli._plugins.notebook.manager.make_snowsight_url") +def test_create_query( + mock_make_snowsight_url, mock_connector, mock_ctx, runner, stage_path +): + ctx = mock_ctx() + mock_connector.return_value = ctx + mock_make_snowsight_url.return_value = "mocked_snowsight.url" + notebook_name = "my_notebook" + notebook_file = f"{stage_path}/notebook.ipynb" + result = runner.invoke( + ["notebook", "create", notebook_name, "--notebook-file", notebook_file] + ) + assert result.exit_code == 0, result.output + assert ctx.get_query() == ( + "\n" + "CREATE OR REPLACE NOTEBOOK " + f"IDENTIFIER('MockDatabase.MockSchema.{notebook_name}')\n" + f"FROM '{stage_path}'\n" + "QUERY_WAREHOUSE = 'MockWarehouse'\n" + "MAIN_FILE = 'notebook.ipynb';\n" + "// Cannot use IDENTIFIER(...)\n" + f"ALTER NOTEBOOK MockDatabase.MockSchema.{notebook_name} ADD LIVE VERSION FROM LAST;\n" + )