diff --git a/packages/jupyter-ai/src/commands.ts b/packages/jupyter-ai/src/commands.ts index 7c9418de6..5ab39273a 100644 --- a/packages/jupyter-ai/src/commands.ts +++ b/packages/jupyter-ai/src/commands.ts @@ -13,7 +13,7 @@ import { AiService } from './handler'; import { OpenTaskDialog } from './components/open-task-dialog'; import { ClosableDialog } from './widgets/closable-dialog'; import { InsertionContext, insertOutput } from './inserter'; -import { getTextSelection, getEditor } from './utils'; +import { getTextSelection, getEditor, getCellIndex } from './utils'; /** * Creates a placeholder markdown cell either above/below the currently active @@ -53,7 +53,7 @@ function insertPlaceholderCell( * Replaces a cell with a markdown cell containing a string. */ function replaceWithMarkdown(notebook: Notebook, cellId: string, body: string) { - const cellIdx = findIndex(notebook, cellId); + const cellIdx = getCellIndex(notebook, cellId); if (cellIdx === -1) { return; } @@ -64,15 +64,8 @@ function replaceWithMarkdown(notebook: Notebook, cellId: string, body: string) { NotebookActions.run(notebook); } -function findIndex(notebook: Notebook, id: string): number { - const idx = notebook.model?.sharedModel.cells.findIndex( - cell => cell.getId() === id - ); - return idx === undefined ? -1 : idx; -} - function deleteCell(notebook: Notebook, id: string): void { - const idx = findIndex(notebook, id); + const idx = getCellIndex(notebook, id); if (idx !== -1) { notebook.model?.sharedModel.deleteCell(idx); } diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index ac0103d22..431dc84d8 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -102,8 +102,13 @@ function ChatBody(): JSX.Element { setLoading(false); } - if (replaceSelection) { - replaceSelectionFn(response.output); + if (replaceSelection && selection) { + const { cellId, ...selectionProps } = selection; + replaceSelectionFn({ + ...selectionProps, + ...(cellId && { cellId }), + text: response.output + }); } setMessageGroups(messageGroups => [ ...messageGroups, @@ -147,7 +152,7 @@ function ChatBody(): JSX.Element { value={input} onChange={setInput} onSend={onSend} - hasSelection={!!selection} + hasSelection={!!selection?.text} includeSelection={includeSelection} toggleIncludeSelection={() => setIncludeSelection(includeSelection => !includeSelection) diff --git a/packages/jupyter-ai/src/contexts/selection-context.tsx b/packages/jupyter-ai/src/contexts/selection-context.tsx index 7258d98bb..3c7a78fca 100644 --- a/packages/jupyter-ai/src/contexts/selection-context.tsx +++ b/packages/jupyter-ai/src/contexts/selection-context.tsx @@ -1,10 +1,10 @@ import React, { useCallback, useContext, useEffect, useState } from 'react'; -import { SelectionWatcher } from '../selection-watcher'; +import { Selection, SelectionWatcher } from '../selection-watcher'; const SelectionContext = React.createContext< - [string, (value: string) => unknown] + [Selection | null, (value: Selection) => unknown] >([ - '', + null, () => { /* noop */ } @@ -23,7 +23,7 @@ export function SelectionContextProvider({ selectionWatcher, children }: SelectionContextProviderProps) { - const [selection, setSelection] = useState(''); + const [selection, setSelection] = useState(null); /** * Effect: subscribe to SelectionWatcher @@ -35,7 +35,7 @@ export function SelectionContextProvider({ }, []); const replaceSelection = useCallback( - (value: string) => { + (value: Selection) => { selectionWatcher.replaceSelection(value); }, [selectionWatcher] diff --git a/packages/jupyter-ai/src/selection-watcher.ts b/packages/jupyter-ai/src/selection-watcher.ts index 40e2bbbe8..f0e61cf0d 100644 --- a/packages/jupyter-ai/src/selection-watcher.ts +++ b/packages/jupyter-ai/src/selection-watcher.ts @@ -1,9 +1,95 @@ import { JupyterFrontEnd, LabShell } from '@jupyterlab/application'; import { DocumentWidget } from '@jupyterlab/docregistry'; +import { CodeEditor } from '@jupyterlab/codeeditor'; +import { CodeMirrorEditor } from '@jupyterlab/codemirror'; +import { FileEditor } from '@jupyterlab/fileeditor'; +import { Notebook } from '@jupyterlab/notebook'; + +import { find } from '@lumino/algorithm'; import { Widget } from '@lumino/widgets'; import { Signal } from '@lumino/signaling'; -import { getEditor, getTextSelection } from './utils'; +import { getCellIndex } from './utils'; + +/** + * Gets the editor instance used by a document widget. Returns `null` if unable. + */ +function getEditor(widget: Widget | null) { + if (!(widget instanceof DocumentWidget)) { + return null; + } + + let editor: CodeEditor.IEditor | undefined; + const { content } = widget; + + if (content instanceof FileEditor) { + editor = content.editor; + } else if (content instanceof Notebook) { + editor = content.activeCell?.editor; + } + + if (!(editor instanceof CodeMirrorEditor)) { + return undefined; + } + + return editor; +} + +/** + * Gets a Selection object from a document widget. Returns `null` if unable. + */ +function getTextSelection(widget: Widget | null): Selection | null { + const editor = getEditor(widget); + // widget type check is redundant but hints the type to TypeScript + if (!editor || !(widget instanceof DocumentWidget)) { + return null; + } + + let cellId: string | undefined = undefined; + if (widget.content instanceof Notebook) { + cellId = widget.content.activeCell?.model.id; + } + + let { start, end, ...selectionObj } = editor.getSelection(); + const startOffset = editor.getOffsetAt(start); + const endOffset = editor.getOffsetAt(end); + const text = editor.model.sharedModel + .getSource() + .substring(startOffset, endOffset); + + // ensure start <= end + // required for editor.model.sharedModel.updateSource() + if (startOffset > endOffset) { + [start, end] = [end, start]; + } + + return { + ...selectionObj, + start, + end, + text, + widgetId: widget.id, + ...(cellId && { + cellId + }) + }; +} + +export type Selection = CodeEditor.ITextSelection & { + /** + * The text within the selection as a string. + */ + text: string; + /** + * The ID of the document widget in which the selection was made. + */ + widgetId: string; + /** + * The ID of the cell in which the selection was made, if the original widget + * was a notebook. + */ + cellId?: string; +}; export class SelectionWatcher { constructor(shell: JupyterFrontEnd.IShell) { @@ -11,9 +97,11 @@ export class SelectionWatcher { throw 'Shell is not an instance of LabShell. Jupyter AI does not currently support custom shells.'; } - shell.currentChanged.connect((sender, args) => { + this._shell = shell; + this._shell.currentChanged.connect((sender, args) => { this._mainAreaWidget = args.newValue; }); + setInterval(this._poll.bind(this), 200); } @@ -21,24 +109,51 @@ export class SelectionWatcher { return this._selectionChanged; } - replaceSelection(value: string) { - if (!(this._mainAreaWidget instanceof DocumentWidget)) { + replaceSelection(selection: Selection) { + // unfortunately shell.currentWidget doesn't update synchronously after + // shell.activateById(), which is why we have to get a reference to the + // widget manually. + const widget = find( + this._shell.widgets(), + widget => widget.id === selection.widgetId + ); + if (!(widget instanceof DocumentWidget)) { return; } - const editor = getEditor(this._mainAreaWidget.content); - editor?.replaceSelection?.(value); - } + // activate the widget if not already active + this._shell.activateById(selection.widgetId); - protected _poll() { - if (!(this._mainAreaWidget instanceof DocumentWidget)) { + // activate notebook cell if specified + if (widget.content instanceof Notebook && selection.cellId) { + const cellIndex = getCellIndex(widget.content, selection.cellId); + if (cellIndex !== -1) { + widget.content.activeCellIndex = cellIndex; + } + } + + // get editor instance + const editor = getEditor(widget); + if (!editor) { return; } + editor.model.sharedModel.updateSource( + editor.getOffsetAt(selection.start), + editor.getOffsetAt(selection.end), + selection.text + ); + const newPosition = editor.getPositionAt( + editor.getOffsetAt(selection.start) + selection.text.length + ); + editor.setSelection({ start: newPosition, end: newPosition }); + } + + protected _poll() { const prevSelection = this._selection; - const currSelection = getTextSelection(this._mainAreaWidget.content); + const currSelection = getTextSelection(this._mainAreaWidget); - if (prevSelection === currSelection) { + if (prevSelection?.text === currSelection?.text) { return; } @@ -46,7 +161,8 @@ export class SelectionWatcher { this._selectionChanged.emit(currSelection); } + protected _shell: LabShell; protected _mainAreaWidget: Widget | null = null; - protected _selection = ''; - protected _selectionChanged = new Signal(this); + protected _selection: Selection | null = null; + protected _selectionChanged = new Signal(this); } diff --git a/packages/jupyter-ai/src/utils.ts b/packages/jupyter-ai/src/utils.ts index 0df6b3070..070c5d403 100644 --- a/packages/jupyter-ai/src/utils.ts +++ b/packages/jupyter-ai/src/utils.ts @@ -36,3 +36,13 @@ export function getEditor(widget: Widget): CodeEditor.IEditor | undefined { return editor; } + +/** + * Gets the index of the cell associated with `cellId`. + */ +export function getCellIndex(notebook: Notebook, cellId: string): number { + const idx = notebook.model?.sharedModel.cells.findIndex( + cell => cell.getId() === cellId + ); + return idx === undefined ? -1 : idx; +}