Skip to content

Commit

Permalink
make selection more robust (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq authored Apr 12, 2023
1 parent 83efdeb commit b828199
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 31 deletions.
13 changes: 3 additions & 10 deletions packages/jupyter-ai/src/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { AiService } from './handler';
import { OpenTaskDialog } from './components/open-task-dialog';
import { ClosableDialog } from './widgets/closable-dialog';
import { InsertionContext, insertOutput } from './inserter';
import { getTextSelection, getEditor } from './utils';
import { getTextSelection, getEditor, getCellIndex } from './utils';

/**
* Creates a placeholder markdown cell either above/below the currently active
Expand Down Expand Up @@ -53,7 +53,7 @@ function insertPlaceholderCell(
* Replaces a cell with a markdown cell containing a string.
*/
function replaceWithMarkdown(notebook: Notebook, cellId: string, body: string) {
const cellIdx = findIndex(notebook, cellId);
const cellIdx = getCellIndex(notebook, cellId);
if (cellIdx === -1) {
return;
}
Expand All @@ -64,15 +64,8 @@ function replaceWithMarkdown(notebook: Notebook, cellId: string, body: string) {
NotebookActions.run(notebook);
}

function findIndex(notebook: Notebook, id: string): number {
const idx = notebook.model?.sharedModel.cells.findIndex(
cell => cell.getId() === id
);
return idx === undefined ? -1 : idx;
}

function deleteCell(notebook: Notebook, id: string): void {
const idx = findIndex(notebook, id);
const idx = getCellIndex(notebook, id);
if (idx !== -1) {
notebook.model?.sharedModel.deleteCell(idx);
}
Expand Down
11 changes: 8 additions & 3 deletions packages/jupyter-ai/src/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ function ChatBody(): JSX.Element {
setLoading(false);
}

if (replaceSelection) {
replaceSelectionFn(response.output);
if (replaceSelection && selection) {
const { cellId, ...selectionProps } = selection;
replaceSelectionFn({
...selectionProps,
...(cellId && { cellId }),
text: response.output
});
}
setMessageGroups(messageGroups => [
...messageGroups,
Expand Down Expand Up @@ -147,7 +152,7 @@ function ChatBody(): JSX.Element {
value={input}
onChange={setInput}
onSend={onSend}
hasSelection={!!selection}
hasSelection={!!selection?.text}
includeSelection={includeSelection}
toggleIncludeSelection={() =>
setIncludeSelection(includeSelection => !includeSelection)
Expand Down
10 changes: 5 additions & 5 deletions packages/jupyter-ai/src/contexts/selection-context.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import React, { useCallback, useContext, useEffect, useState } from 'react';
import { SelectionWatcher } from '../selection-watcher';
import { Selection, SelectionWatcher } from '../selection-watcher';

const SelectionContext = React.createContext<
[string, (value: string) => unknown]
[Selection | null, (value: Selection) => unknown]
>([
'',
null,
() => {
/* noop */
}
Expand All @@ -23,7 +23,7 @@ export function SelectionContextProvider({
selectionWatcher,
children
}: SelectionContextProviderProps) {
const [selection, setSelection] = useState('');
const [selection, setSelection] = useState<Selection | null>(null);

/**
* Effect: subscribe to SelectionWatcher
Expand All @@ -35,7 +35,7 @@ export function SelectionContextProvider({
}, []);

const replaceSelection = useCallback(
(value: string) => {
(value: Selection) => {
selectionWatcher.replaceSelection(value);
},
[selectionWatcher]
Expand Down
142 changes: 129 additions & 13 deletions packages/jupyter-ai/src/selection-watcher.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,168 @@
import { JupyterFrontEnd, LabShell } from '@jupyterlab/application';
import { DocumentWidget } from '@jupyterlab/docregistry';
import { CodeEditor } from '@jupyterlab/codeeditor';
import { CodeMirrorEditor } from '@jupyterlab/codemirror';
import { FileEditor } from '@jupyterlab/fileeditor';
import { Notebook } from '@jupyterlab/notebook';

import { find } from '@lumino/algorithm';
import { Widget } from '@lumino/widgets';
import { Signal } from '@lumino/signaling';

import { getEditor, getTextSelection } from './utils';
import { getCellIndex } from './utils';

/**
* Gets the editor instance used by a document widget. Returns `null` if unable.
*/
function getEditor(widget: Widget | null) {
if (!(widget instanceof DocumentWidget)) {
return null;
}

let editor: CodeEditor.IEditor | undefined;
const { content } = widget;

if (content instanceof FileEditor) {
editor = content.editor;
} else if (content instanceof Notebook) {
editor = content.activeCell?.editor;
}

if (!(editor instanceof CodeMirrorEditor)) {
return undefined;
}

return editor;
}

/**
* Gets a Selection object from a document widget. Returns `null` if unable.
*/
function getTextSelection(widget: Widget | null): Selection | null {
const editor = getEditor(widget);
// widget type check is redundant but hints the type to TypeScript
if (!editor || !(widget instanceof DocumentWidget)) {
return null;
}

let cellId: string | undefined = undefined;
if (widget.content instanceof Notebook) {
cellId = widget.content.activeCell?.model.id;
}

let { start, end, ...selectionObj } = editor.getSelection();
const startOffset = editor.getOffsetAt(start);
const endOffset = editor.getOffsetAt(end);
const text = editor.model.sharedModel
.getSource()
.substring(startOffset, endOffset);

// ensure start <= end
// required for editor.model.sharedModel.updateSource()
if (startOffset > endOffset) {
[start, end] = [end, start];
}

return {
...selectionObj,
start,
end,
text,
widgetId: widget.id,
...(cellId && {
cellId
})
};
}

export type Selection = CodeEditor.ITextSelection & {
/**
* The text within the selection as a string.
*/
text: string;
/**
* The ID of the document widget in which the selection was made.
*/
widgetId: string;
/**
* The ID of the cell in which the selection was made, if the original widget
* was a notebook.
*/
cellId?: string;
};

export class SelectionWatcher {
constructor(shell: JupyterFrontEnd.IShell) {
if (!(shell instanceof LabShell)) {
throw 'Shell is not an instance of LabShell. Jupyter AI does not currently support custom shells.';
}

shell.currentChanged.connect((sender, args) => {
this._shell = shell;
this._shell.currentChanged.connect((sender, args) => {
this._mainAreaWidget = args.newValue;
});

setInterval(this._poll.bind(this), 200);
}

get selectionChanged() {
return this._selectionChanged;
}

replaceSelection(value: string) {
if (!(this._mainAreaWidget instanceof DocumentWidget)) {
replaceSelection(selection: Selection) {
// unfortunately shell.currentWidget doesn't update synchronously after
// shell.activateById(), which is why we have to get a reference to the
// widget manually.
const widget = find(
this._shell.widgets(),
widget => widget.id === selection.widgetId
);
if (!(widget instanceof DocumentWidget)) {
return;
}

const editor = getEditor(this._mainAreaWidget.content);
editor?.replaceSelection?.(value);
}
// activate the widget if not already active
this._shell.activateById(selection.widgetId);

protected _poll() {
if (!(this._mainAreaWidget instanceof DocumentWidget)) {
// activate notebook cell if specified
if (widget.content instanceof Notebook && selection.cellId) {
const cellIndex = getCellIndex(widget.content, selection.cellId);
if (cellIndex !== -1) {
widget.content.activeCellIndex = cellIndex;
}
}

// get editor instance
const editor = getEditor(widget);
if (!editor) {
return;
}

editor.model.sharedModel.updateSource(
editor.getOffsetAt(selection.start),
editor.getOffsetAt(selection.end),
selection.text
);
const newPosition = editor.getPositionAt(
editor.getOffsetAt(selection.start) + selection.text.length
);
editor.setSelection({ start: newPosition, end: newPosition });
}

protected _poll() {
const prevSelection = this._selection;
const currSelection = getTextSelection(this._mainAreaWidget.content);
const currSelection = getTextSelection(this._mainAreaWidget);

if (prevSelection === currSelection) {
if (prevSelection?.text === currSelection?.text) {
return;
}

this._selection = currSelection;
this._selectionChanged.emit(currSelection);
}

protected _shell: LabShell;
protected _mainAreaWidget: Widget | null = null;
protected _selection = '';
protected _selectionChanged = new Signal<this, string>(this);
protected _selection: Selection | null = null;
protected _selectionChanged = new Signal<this, Selection | null>(this);
}
10 changes: 10 additions & 0 deletions packages/jupyter-ai/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@ export function getEditor(widget: Widget): CodeEditor.IEditor | undefined {

return editor;
}

/**
* Gets the index of the cell associated with `cellId`.
*/
export function getCellIndex(notebook: Notebook, cellId: string): number {
const idx = notebook.model?.sharedModel.cells.findIndex(
cell => cell.getId() === cellId
);
return idx === undefined ? -1 : idx;
}

0 comments on commit b828199

Please sign in to comment.