diff --git a/src/__mocks__/vscode.ts b/src/__mocks__/vscode.ts index ceabba2..15cb5b4 100644 --- a/src/__mocks__/vscode.ts +++ b/src/__mocks__/vscode.ts @@ -16,6 +16,14 @@ export function createVSCodeMock(vi: VitestUtils) { }; }); + vscode.env = vscode.env || {}; + vscode.env.asExternalUri = vi.fn().mockImplementation(async (uri) => uri); + enum QuickPickItemKind { + Separator = -1, + Default = 0, + } + vscode.QuickPickItemKind = QuickPickItemKind; + vscode.notebooks = vscode.notebooks || {}; vscode.notebooks.createNotebookController = vi .fn() diff --git a/src/browser/__tests__/panel.test.ts b/src/browser/__tests__/panel.test.ts index 52ba48f..3186201 100644 --- a/src/browser/__tests__/panel.test.ts +++ b/src/browser/__tests__/panel.test.ts @@ -29,6 +29,7 @@ describe("Panel", () => { diff --git a/src/commands/__tests__/show-commands.test.ts b/src/commands/__tests__/show-commands.test.ts index c3cbe4f..bdcdeda 100644 --- a/src/commands/__tests__/show-commands.test.ts +++ b/src/commands/__tests__/show-commands.test.ts @@ -16,24 +16,34 @@ describe("showCommands", () => { [ "$(split-horizontal) Open outputs in embedded browser", "$(link-external) Open outputs in system browser", + "", "$(refresh) Restart kernel", - "$(question) Show documentation", "$(export) Export notebook as...", + "", + "$(question) View marimo documentation", + "$(comment-discussion) Join Discord community", + "$(settings) Edit settings", + "$(info) Status: stopped", ] `); }); it("should show commands for non active Controller", async () => { - const commands = showMarimoControllerCommands( - await createMockController(), + const commands = ( + await showMarimoControllerCommands(await createMockController()) ).filter((index) => index.if !== false); expect(commands.map((c) => c.label)).toMatchInlineSnapshot(` [ "$(notebook) Start as VSCode notebook", "$(zap) Start in marimo editor (edit)", - "$(remote-explorer-documentation) Start in marimo editor (run)", - "$(question) Show documentation", + "$(preview) Start in marimo editor (run)", + "", "$(export) Export notebook as...", + "", + "$(question) View marimo documentation", + "$(comment-discussion) Join Discord community", + "$(settings) Edit settings", + "$(info) Status: stopped", ] `); }); @@ -42,19 +52,24 @@ describe("showCommands", () => { const controller = await createMockController(); controller.active = true; controller.currentMode = "run"; - const commands = showMarimoControllerCommands(controller).filter( + const commands = (await showMarimoControllerCommands(controller)).filter( (index) => index.if !== false, ); expect(commands.map((c) => c.label)).toMatchInlineSnapshot(` [ + "", "$(split-horizontal) Open in embedded browser", "$(link-external) Open in system browser", "$(refresh) Restart marimo kernel", "$(package) Switch to edit mode", "$(terminal) Show Terminal", "$(close) Stop kernel", - "$(question) Show documentation", "$(export) Export notebook as...", + "", + "$(question) View marimo documentation", + "$(comment-discussion) Join Discord community", + "$(settings) Edit settings", + "$(info) Status: stopped", ] `); }); @@ -63,19 +78,24 @@ describe("showCommands", () => { const controller = await createMockController(); controller.active = true; controller.currentMode = "edit"; - const commands = showMarimoControllerCommands(controller).filter( + const commands = (await showMarimoControllerCommands(controller)).filter( (index) => index.if !== false, ); expect(commands.map((c) => c.label)).toMatchInlineSnapshot(` [ + "", "$(split-horizontal) Open in embedded browser", "$(link-external) Open in system browser", "$(refresh) Restart marimo kernel", "$(package) Switch to run mode", "$(terminal) Show Terminal", "$(close) Stop kernel", - "$(question) Show documentation", "$(export) Export notebook as...", + "", + "$(question) View marimo documentation", + "$(comment-discussion) Join Discord community", + "$(settings) Edit settings", + "$(info) Status: stopped", ] `); }); diff --git a/src/commands/show-commands.ts b/src/commands/show-commands.ts index 6968847..6b6805a 100644 --- a/src/commands/show-commands.ts +++ b/src/commands/show-commands.ts @@ -85,9 +85,9 @@ export function showKernelCommands(kernel: Kernel): CommandPickItem[] { ]; } -export function showMarimoControllerCommands( +export async function showMarimoControllerCommands( controller: MarimoController, -): CommandPickItem[] { +): Promise { return [ // Non-active commands { @@ -117,7 +117,7 @@ export function showMarimoControllerCommands( // Active commands { label: "$(split-horizontal) Open in embedded browser", - description: controller.url, + description: await controller.url(), handler() { controller.open("embedded"); }, @@ -125,7 +125,7 @@ export function showMarimoControllerCommands( }, { label: "$(link-external) Open in system browser", - description: controller.url, + description: await controller.url(), handler() { controller.open("system"); }, diff --git a/src/config.ts b/src/config.ts index 689a8c4..cdc55ec 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,4 +1,5 @@ -import { workspace } from "vscode"; +import { Uri, env, workspace } from "vscode"; +import { logger } from "./logger"; export function getConfig(key: string): T | undefined; export function getConfig(key: string, v: T): T; @@ -48,6 +49,32 @@ export const Config = { }, }; -export function composeUrl(port: number) { - return `${Config.https ? "https" : "http"}://${Config.host}:${port}`; +export async function composeUrl(port: number): Promise { + const url = `${Config.https ? "https" : "http"}://${Config.host}:${port}/`; + try { + const externalUri = await env.asExternalUri(Uri.parse(url)); + const externalUrl = externalUri.toString(); + if (externalUrl !== url) { + logger.log("Mapping to external url", externalUrl, "from", url); + } + return externalUrl; + } catch (e) { + logger.error("Failed to create external url", url, e); + return url; + } +} + +export async function composeWsUrl(port: number): Promise { + const url = `${Config.https ? "wss" : "ws"}://${Config.host}:${port}/`; + try { + const externalUri = await env.asExternalUri(Uri.parse(url)); + const externalUrl = externalUri.toString(); + if (externalUrl !== url) { + logger.log("Mapping to external url", externalUrl, "from", url); + } + return externalUrl; + } catch (e) { + logger.error("Failed to create external url", url, e); + return url; + } } diff --git a/src/launcher/controller.ts b/src/launcher/controller.ts index bbf5ae3..1b3dfa3 100644 --- a/src/launcher/controller.ts +++ b/src/launcher/controller.ts @@ -18,6 +18,7 @@ import { statusBarManager } from "../ui/status-bar"; import { MarimoCmdBuilder } from "../utils/cmd"; import { ping } from "../utils/network"; import { getFocusedMarimoTextEditor, isMarimoApp } from "../utils/query"; +import { asURL } from "../utils/url"; import { ServerManager } from "./server-manager"; import { type IMarimoTerminal, MarimoTerminal } from "./terminal"; @@ -119,12 +120,13 @@ export class MarimoController implements Disposable { this.panel.show(); } + const url = await this.url(); if (browser === "system") { // Close the panel if opened this.panel.dispose(); - await env.openExternal(Uri.parse(this.url)); + await env.openExternal(Uri.parse(url)); } else if (browser === "embedded") { - await this.panel.create(this.url); + await this.panel.create(url); this.panel.show(); } @@ -149,7 +151,7 @@ export class MarimoController implements Disposable { return; } - const url = composeUrl(port); + const url = await composeUrl(port); if (!(await ping(url))) { return; @@ -181,11 +183,11 @@ export class MarimoController implements Disposable { return folderName ? `${folderName}/${fileName}` : fileName; } - public get url() { + public async url(): Promise { if (!this.port) { return ""; } - const url = new URL(composeUrl(this.port)); + const url = asURL(await composeUrl(this.port)); if (this.currentMode === "edit") { url.searchParams.set("file", this.file.uri.fsPath); } diff --git a/src/launcher/server-manager.ts b/src/launcher/server-manager.ts index badab67..0e266f9 100644 --- a/src/launcher/server-manager.ts +++ b/src/launcher/server-manager.ts @@ -1,3 +1,4 @@ +import { join } from "node:path"; import * as vscode from "vscode"; import { Config, composeUrl } from "../config"; import { logger as l, logger } from "../logger"; @@ -106,7 +107,8 @@ export class ServerManager { */ private async isHealthy(port: number): Promise { try { - const health = await fetch(`${composeUrl(port)}/health`); + const baseUrl = await composeUrl(port); + const health = await fetch(join(baseUrl, "health")); return health.ok; } catch { return false; diff --git a/src/launcher/utils.ts b/src/launcher/utils.ts index f2f0be3..fb6d83b 100644 --- a/src/launcher/utils.ts +++ b/src/launcher/utils.ts @@ -1,6 +1,7 @@ import { parse } from "node-html-parser"; import { composeUrl } from "../config"; import type { MarimoConfig, SkewToken } from "../notebook/marimo/types"; +import { asURL } from "../utils/url"; /** * Grabs the index.html of the marimo server and extracts @@ -14,7 +15,7 @@ export async function fetchMarimoStartupValues(port: number): Promise<{ version: string; userConfig: MarimoConfig; }> { - const url = new URL(composeUrl(port)); + const url = asURL(await composeUrl(port)); const response = await fetch(url); if (!response.ok) { @@ -24,7 +25,7 @@ export async function fetchMarimoStartupValues(port: number): Promise<{ } // If was redirected to /auth/login, then show a message that an existing server is running - if (new URL(response.url).pathname.startsWith("/auth/login")) { + if (asURL(response.url).pathname.startsWith("/auth/login")) { throw new Error( `An existing marimo server created outside of vscode is running at this url: ${url.toString()}`, ); diff --git a/src/notebook/kernel.ts b/src/notebook/kernel.ts index 100a01c..82831af 100644 --- a/src/notebook/kernel.ts +++ b/src/notebook/kernel.ts @@ -7,6 +7,7 @@ import { Deferred } from "../utils/deferred"; import { invariant } from "../utils/invariant"; import { LogMethodCalls } from "../utils/log"; import { closeNotebookEditor } from "../utils/show"; +import { asURL } from "../utils/url"; import type { KernelKey } from "./common/key"; import { getCellMetadata, setCellMetadata } from "./common/metadata"; import { MARKDOWN_LANGUAGE_ID, PYTHON_LANGUAGE_ID } from "./constants"; @@ -133,7 +134,7 @@ export class Kernel implements IKernel { this.panel.show(); } - const url = new URL(composeUrl(this.opts.port)); + const url = asURL(await composeUrl(this.opts.port)); url.searchParams.set("kiosk", "true"); url.searchParams.set("file", this.kernelKey); diff --git a/src/notebook/marimo/bridge.ts b/src/notebook/marimo/bridge.ts index cdbab41..3276668 100644 --- a/src/notebook/marimo/bridge.ts +++ b/src/notebook/marimo/bridge.ts @@ -1,7 +1,8 @@ +import { join } from "node:path"; import createClient from "openapi-fetch"; import * as vscode from "vscode"; import { WebSocket } from "ws"; -import { Config, composeUrl } from "../../config"; +import { composeUrl, composeWsUrl } from "../../config"; import { getGlobalState } from "../../ctx"; import { MarimoExplorer, @@ -19,6 +20,7 @@ import { Deferred } from "../../utils/deferred"; import { logNever } from "../../utils/invariant"; import { LogMethodCalls } from "../../utils/log"; import { SingleMessage } from "../../utils/single-promise"; +import { asURL } from "../../utils/url"; import type { KernelKey } from "../common/key"; import { type CellOp, @@ -79,14 +81,12 @@ export class MarimoBridge implements ILifecycle { @LogMethodCalls() async start(): Promise { // Create URLs - const host = Config.host; - const https = Config.https; this.sessionId = SessionId.create(); - const wsProtocol = https ? "wss" : "ws"; - const wsURL = new URL(`${wsProtocol}://${host}:${this.port}/ws`); + const wsBaseUrl = await composeWsUrl(this.port); + const wsURL = asURL(join(wsBaseUrl, "ws")); wsURL.searchParams.set("session_id", this.sessionId); wsURL.searchParams.set("file", this.kernelKey); - const httpURL = composeUrl(this.port); + const httpURL = await composeUrl(this.port); // Create WebSocket this.socket = new WebSocket(wsURL); @@ -173,9 +173,9 @@ export class MarimoBridge implements ILifecycle { this.start(); } - static getRunningNotebooks(port: number, skewToken: SkewToken) { + static async getRunningNotebooks(port: number, skewToken: SkewToken) { const client = createClient({ - baseUrl: composeUrl(port), + baseUrl: await composeUrl(port), }); client.use({ onRequest: (req) => { @@ -186,13 +186,13 @@ export class MarimoBridge implements ILifecycle { return client.POST("/api/home/running_notebooks"); } - static shutdownSession( + static async shutdownSession( port: number, skewToken: SkewToken, sessionId: string, ) { const client = createClient({ - baseUrl: composeUrl(port), + baseUrl: await composeUrl(port), }); client.use({ onRequest: (req) => { diff --git a/src/utils/network.ts b/src/utils/network.ts index 34a2bac..9917a35 100644 --- a/src/utils/network.ts +++ b/src/utils/network.ts @@ -6,7 +6,7 @@ import { composeUrl } from "../config"; * Check if a port is free */ async function isPortFree(port: number) { - const healthy = await ping(composeUrl(port)); + const healthy = await ping(await composeUrl(port)); return !healthy; } diff --git a/src/utils/url.ts b/src/utils/url.ts new file mode 100644 index 0000000..d5c48ed --- /dev/null +++ b/src/utils/url.ts @@ -0,0 +1,10 @@ +import { logger } from "../logger"; + +export function asURL(url: string): URL { + try { + return new URL(url); + } catch (e) { + logger.error("Failed to parse url", url, e); + throw e; + } +}