From c811a5ae8ba4a512eb74ece2f9f6cb332c3bcdd4 Mon Sep 17 00:00:00 2001 From: Kaz Wesley Date: Mon, 19 Feb 2024 15:57:42 -0800 Subject: [PATCH] Enable the Code Editor, with new apply-text-edits algo. (#9055) - Fix the UI problems with our CodeMirror integration (Fixed view stability; Fixed a focus bug; Fixed errors caused by diagnostics range exceptions; Fixed linter invalidation--see https://discuss.codemirror.net/t/problem-trying-to-force-linting/5823; Implemented edit-coalescing for performance). - Introduce an algorithm for applying text edits to an AST. Compared to the GUI1 approach, the new algorithm supports deeper identity-stability for expressions (which is important for subexpression metadata and Y.Js sync), as well as reordered-subtree identification. - Enable the code editor. --- app/gui2/mock/index.ts | 2 +- app/gui2/shared/ast/debug.ts | 22 + app/gui2/shared/ast/ffi.ts | 11 + app/gui2/shared/ast/index.ts | 10 +- app/gui2/shared/ast/mutableModule.ts | 141 ++++-- app/gui2/shared/ast/parse.ts | 412 +++++++++++++--- app/gui2/shared/ast/sourceDocument.ts | 51 +- app/gui2/shared/ast/token.ts | 5 + app/gui2/shared/ast/tree.ts | 461 ++++++++++-------- .../util/data/__tests__/iterable.test.ts | 8 + .../shared/util/data/__tests__/text.test.ts | 137 ++++++ app/gui2/shared/util/data/iterable.ts | 98 ++++ app/gui2/shared/util/data/text.ts | 149 ++++++ app/gui2/shared/yjsModel.ts | 26 +- app/gui2/src/components/CodeEditor.vue | 187 +++++-- .../src/components/CodeEditor/codemirror.ts | 26 +- app/gui2/src/components/GraphEditor.vue | 8 +- .../src/components/GraphEditor/GraphNode.vue | 4 +- .../components/GraphEditor/NodeWidgetTree.vue | 12 +- app/gui2/src/stores/graph/graphDatabase.ts | 4 +- app/gui2/src/stores/graph/index.ts | 51 +- app/gui2/src/stores/project/index.ts | 6 +- .../src/util/ast/__tests__/abstract.test.ts | 283 +++++++++-- app/gui2/src/util/ast/__tests__/match.test.ts | 2 +- .../util/ast/__tests__/sourceDocument.test.ts | 34 ++ app/gui2/src/util/ast/__tests__/testCase.ts | 74 +++ app/gui2/src/util/ast/abstract.ts | 4 +- app/gui2/src/util/ast/extended.ts | 12 +- app/gui2/src/util/ast/index.ts | 12 +- app/gui2/src/util/ast/match.ts | 2 +- app/gui2/src/util/ast/prefixes.ts | 2 +- app/gui2/src/util/data/iterable.ts | 63 +-- app/gui2/ydoc-server/edits.ts | 2 +- app/gui2/ydoc-server/languageServerSession.ts | 43 +- .../dashboard/src/layouts/dashboard/Chat.tsx | 1 - 35 files changed, 1831 insertions(+), 534 deletions(-) create mode 100644 app/gui2/shared/ast/debug.ts create mode 100644 app/gui2/shared/util/data/__tests__/iterable.test.ts create mode 100644 app/gui2/shared/util/data/__tests__/text.test.ts create mode 100644 app/gui2/shared/util/data/iterable.ts create mode 100644 app/gui2/shared/util/data/text.ts create mode 100644 app/gui2/src/util/ast/__tests__/sourceDocument.test.ts create mode 100644 app/gui2/src/util/ast/__tests__/testCase.ts diff --git a/app/gui2/mock/index.ts b/app/gui2/mock/index.ts index ea6492549e19..87f2acf092e4 100644 --- a/app/gui2/mock/index.ts +++ b/app/gui2/mock/index.ts @@ -67,7 +67,7 @@ export function projectStore() { const mod = projectStore.projectModel.createNewModule('Main.enso') mod.doc.ydoc.emit('load', []) const syncModule = new Ast.MutableModule(mod.doc.ydoc) - mod.transact(() => { + syncModule.transact(() => { const root = Ast.parseBlock('main =\n', syncModule) syncModule.replaceRoot(root) }) diff --git a/app/gui2/shared/ast/debug.ts b/app/gui2/shared/ast/debug.ts new file mode 100644 index 000000000000..a63d26207e22 --- /dev/null +++ b/app/gui2/shared/ast/debug.ts @@ -0,0 +1,22 @@ +import { Ast } from './tree' + +/// Returns a GraphViz graph illustrating parent/child relationships in the given subtree. +export function graphParentPointers(ast: Ast) { + const sanitize = (id: string) => id.replace('ast:', '').replace(/[^A-Za-z0-9]/g, '') + const parentToChild = new Array<{ parent: string; child: string }>() + const childToParent = new Array<{ child: string; parent: string }>() + ast.visitRecursiveAst((ast) => { + for (const child of ast.children()) { + if (child instanceof Ast) + parentToChild.push({ child: sanitize(child.id), parent: sanitize(ast.id) }) + } + const parent = ast.parentId + if (parent) childToParent.push({ child: sanitize(ast.id), parent: sanitize(parent) }) + }) + let result = 'digraph parentPointers {\n' + for (const { parent, child } of parentToChild) result += `${parent} -> ${child};\n` + for (const { child, parent } of childToParent) + result += `${child} -> ${parent} [weight=0; color=red; style=dotted];\n` + result += '}\n' + return result +} diff --git a/app/gui2/shared/ast/ffi.ts b/app/gui2/shared/ast/ffi.ts index 0f8aaaf99002..67bb6358dab1 100644 --- a/app/gui2/shared/ast/ffi.ts +++ b/app/gui2/shared/ast/ffi.ts @@ -1,6 +1,16 @@ +import { createXXHash128 } from 'hash-wasm' import init, { is_ident_or_operator, parse, parse_doc_to_json } from '../../rust-ffi/pkg/rust_ffi' +import { assertDefined } from '../util/assert' import { isNode } from '../util/detect' +let xxHasher128: Awaited> | undefined +export function xxHash128(input: string) { + assertDefined(xxHasher128, 'Module should have been loaded with `initializeFFI`.') + xxHasher128.init() + xxHasher128.update(input) + return xxHasher128.digest() +} + export async function initializeFFI(path?: string | undefined) { if (isNode) { const fs = await import('node:fs/promises') @@ -9,6 +19,7 @@ export async function initializeFFI(path?: string | undefined) { } else { await init() } + xxHasher128 = await createXXHash128() } // TODO[ao]: We cannot to that, because the ffi is used by cjs modules. diff --git a/app/gui2/shared/ast/index.ts b/app/gui2/shared/ast/index.ts index 5833a292a872..3b14f82e08d4 100644 --- a/app/gui2/shared/ast/index.ts +++ b/app/gui2/shared/ast/index.ts @@ -40,7 +40,7 @@ export function parentId(ast: Ast): AstId | undefined { export function subtrees(module: Module, ids: Iterable) { const subtrees = new Set() for (const id of ids) { - let ast = module.get(id) + let ast = module.tryGet(id) while (ast != null && !subtrees.has(ast.id)) { subtrees.add(ast.id) ast = ast.parent() @@ -50,10 +50,10 @@ export function subtrees(module: Module, ids: Iterable) { } /** Returns the IDs of the ASTs that are not descendants of any others in the given set. */ -export function subtreeRoots(module: Module, ids: Set) { - const roots = new Array() +export function subtreeRoots(module: Module, ids: Set): Set { + const roots = new Set() for (const id of ids) { - const astInModule = module.get(id) + const astInModule = module.tryGet(id) if (!astInModule) continue let ast = astInModule.parent() let hasParentInSet @@ -64,7 +64,7 @@ export function subtreeRoots(module: Module, ids: Set) { } ast = ast.parent() } - if (!hasParentInSet) roots.push(id) + if (!hasParentInSet) roots.add(id) } return roots } diff --git a/app/gui2/shared/ast/mutableModule.ts b/app/gui2/shared/ast/mutableModule.ts index 447c666b4594..8b1126e7a824 100644 --- a/app/gui2/shared/ast/mutableModule.ts +++ b/app/gui2/shared/ast/mutableModule.ts @@ -1,16 +1,25 @@ import * as random from 'lib0/random' import * as Y from 'yjs' -import type { AstId, Owned, SyncTokenId } from '.' -import { Token, asOwned, isTokenId, newExternalId } from '.' -import { assert } from '../util/assert' -import type { ExternalId } from '../yjsModel' +import { + Token, + asOwned, + isTokenId, + newExternalId, + subtreeRoots, + type AstId, + type Owned, + type SyncTokenId, +} from '.' +import { assert, assertDefined } from '../util/assert' +import type { SourceRangeEdit } from '../util/data/text' +import { defaultLocalOrigin, tryAsOrigin, type ExternalId, type Origin } from '../yjsModel' import type { AstFields, FixedMap, Mutable } from './tree' import { Ast, - Invalid, MutableAst, MutableInvalid, Wildcard, + composeFieldData, invalidFields, materializeMutable, setAll, @@ -19,13 +28,13 @@ import { export interface Module { edit(): MutableModule root(): Ast | undefined - get(id: AstId): Ast | undefined - get(id: AstId | undefined): Ast | undefined + tryGet(id: AstId | undefined): Ast | undefined ///////////////////////////////// - checkedGet(id: AstId): Ast - checkedGet(id: AstId | undefined): Ast | undefined + /** Return the specified AST. Throws an exception if no AST with the provided ID was found. */ + get(id: AstId): Ast + get(id: AstId | undefined): Ast | undefined getToken(token: SyncTokenId): Token getToken(token: SyncTokenId | undefined): Token | undefined getAny(node: AstId | SyncTokenId): Ast | Token @@ -33,10 +42,12 @@ export interface Module { } export interface ModuleUpdate { - nodesAdded: AstId[] - nodesDeleted: AstId[] - fieldsUpdated: { id: AstId; fields: (readonly [string, unknown])[] }[] + nodesAdded: Set + nodesDeleted: Set + nodesUpdated: Set + updateRoots: Set metadataUpdated: { id: AstId; changes: Map }[] + origin: Origin | undefined } type YNode = FixedMap @@ -45,7 +56,7 @@ type YNodes = Y.Map export class MutableModule implements Module { private readonly nodes: YNodes - get ydoc() { + private get ydoc() { const ydoc = this.nodes.doc assert(ydoc != null) return ydoc @@ -53,7 +64,7 @@ export class MutableModule implements Module { /** Return this module's copy of `ast`, if this module was created by cloning `ast`'s module. */ getVersion(ast: T): Mutable { - const instance = this.checkedGet(ast.id) + const instance = this.get(ast.id) return instance as Mutable } @@ -63,6 +74,14 @@ export class MutableModule implements Module { return new MutableModule(doc) } + applyEdit(edit: MutableModule, origin: Origin = defaultLocalOrigin) { + Y.applyUpdateV2(this.ydoc, Y.encodeStateAsUpdateV2(edit.ydoc), origin) + } + + transact(f: () => T, origin: Origin = defaultLocalOrigin): T { + return this.ydoc.transact(f, origin) + } + root(): MutableAst | undefined { return this.rootPointer()?.expression } @@ -93,6 +112,22 @@ export class MutableModule implements Module { this.gc() } + syncToCode(code: string) { + const root = this.root() + if (root) { + root.syncToCode(code) + } else { + this.replaceRoot(Ast.parse(code, this)) + } + } + + /** Update the module according to changes to its corresponding source code. */ + applyTextEdits(textEdits: SourceRangeEdit[], metadataSource?: Module) { + const root = this.root() + assertDefined(root) + root.applyTextEdits(textEdits, metadataSource) + } + private gc() { const live = new Set() const active = new Array() @@ -129,7 +164,9 @@ export class MutableModule implements Module { } observe(observer: (update: ModuleUpdate) => void) { - const handle = (events: Y.YEvent[]) => observer(this.observeEvents(events)) + const handle = (events: Y.YEvent[], transaction: Y.Transaction) => { + observer(this.observeEvents(events, tryAsOrigin(transaction.origin))) + } // Attach the observer first, so that if an update hook causes changes in reaction to the initial state update, we // won't miss them. this.nodes.observeDeep(handle) @@ -142,15 +179,15 @@ export class MutableModule implements Module { } getStateAsUpdate(): ModuleUpdate { - const updateBuilder = new UpdateBuilder(this.nodes) + const updateBuilder = new UpdateBuilder(this, this.nodes, undefined) for (const id of this.nodes.keys()) updateBuilder.addNode(id as AstId) - return updateBuilder + return updateBuilder.finish() } - applyUpdate(update: Uint8Array, origin?: string): ModuleUpdate | undefined { + applyUpdate(update: Uint8Array, origin: Origin): ModuleUpdate | undefined { let summary: ModuleUpdate | undefined const observer = (events: Y.YEvent[]) => { - summary = this.observeEvents(events) + summary = this.observeEvents(events, origin) } this.nodes.observeDeep(observer) Y.applyUpdate(this.ydoc, update, origin) @@ -158,8 +195,8 @@ export class MutableModule implements Module { return summary } - private observeEvents(events: Y.YEvent[]): ModuleUpdate { - const updateBuilder = new UpdateBuilder(this.nodes) + private observeEvents(events: Y.YEvent[], origin: Origin | undefined): ModuleUpdate { + const updateBuilder = new UpdateBuilder(this, this.nodes, origin) for (const event of events) { if (event.target === this.nodes) { // Updates to the node map. @@ -201,25 +238,23 @@ export class MutableModule implements Module { updateBuilder.updateMetadata(id, changes) } } - return updateBuilder + return updateBuilder.finish() } clear() { this.nodes.clear() } - checkedGet(id: AstId): Mutable - checkedGet(id: AstId | undefined): Mutable | undefined - checkedGet(id: AstId | undefined): Mutable | undefined { + get(id: AstId): Mutable + get(id: AstId | undefined): Mutable | undefined + get(id: AstId | undefined): Mutable | undefined { if (!id) return undefined - const ast = this.get(id) + const ast = this.tryGet(id) assert(ast !== undefined, 'id in module') return ast } - get(id: AstId): Mutable | undefined - get(id: AstId | undefined): Mutable | undefined - get(id: AstId | undefined): Mutable | undefined { + tryGet(id: AstId | undefined): Mutable | undefined { if (!id) return undefined const nodeData = this.nodes.get(id) if (!nodeData) return undefined @@ -228,19 +263,19 @@ export class MutableModule implements Module { } replace(id: AstId, value: Owned): Owned | undefined { - return this.get(id)?.replace(value) + return this.tryGet(id)?.replace(value) } replaceValue(id: AstId, value: Owned): Owned | undefined { - return this.get(id)?.replaceValue(value) + return this.tryGet(id)?.replaceValue(value) } take(id: AstId): Owned { - return this.replace(id, Wildcard.new(this)) || asOwned(this.checkedGet(id)) + return this.replace(id, Wildcard.new(this)) || asOwned(this.get(id)) } updateValue(id: AstId, f: (x: Owned) => Owned): T | undefined { - return this.get(id)?.updateValue(f) + return this.tryGet(id)?.updateValue(f) } ///////////////////////////////////////////// @@ -250,7 +285,7 @@ export class MutableModule implements Module { } private rootPointer(): MutableRootPointer | undefined { - const rootPointer = this.get(ROOT_ID) + const rootPointer = this.tryGet(ROOT_ID) if (rootPointer) return rootPointer as MutableRootPointer } @@ -269,8 +304,9 @@ export class MutableModule implements Module { parent: undefined, metadata: metadataFields, }) - this.nodes.set(id, fields) - return fields + const fieldObject = composeFieldData(fields, {}) + this.nodes.set(id, fieldObject) + return fieldObject } /** @internal */ @@ -283,7 +319,7 @@ export class MutableModule implements Module { } getAny(node: AstId | SyncTokenId): MutableAst | Token { - return isTokenId(node) ? this.getToken(node) : this.checkedGet(node) + return isTokenId(node) ? this.getToken(node) : this.get(node) } /** @internal Copy a node into the module, if it is bound to a different module. */ @@ -306,8 +342,6 @@ export class MutableModule implements Module { } type MutableRootPointer = MutableInvalid & { get expression(): MutableAst | undefined } -/** @internal */ -export interface RootPointer extends Invalid {} function newAstId(type: string): AstId { return `ast:${type}#${random.uint53()}` as AstId @@ -318,20 +352,24 @@ export function isAstId(value: string): value is AstId { } export const ROOT_ID = `Root` as AstId -class UpdateBuilder implements ModuleUpdate { - readonly nodesAdded: AstId[] = [] - readonly nodesDeleted: AstId[] = [] - readonly fieldsUpdated: { id: AstId; fields: (readonly [string, unknown])[] }[] = [] +class UpdateBuilder { + readonly nodesAdded = new Set() + readonly nodesDeleted = new Set() + readonly nodesUpdated = new Set() readonly metadataUpdated: { id: AstId; changes: Map }[] = [] + readonly origin: Origin | undefined + private readonly module: Module private readonly nodes: YNodes - constructor(nodes: YNodes) { + constructor(module: Module, nodes: YNodes, origin: Origin | undefined) { + this.module = module this.nodes = nodes + this.origin = origin } addNode(id: AstId) { - this.nodesAdded.push(id) + this.nodesAdded.add(id) this.updateAllFields(id) } @@ -340,7 +378,7 @@ class UpdateBuilder implements ModuleUpdate { } updateFields(id: AstId, changes: Iterable) { - const fields = new Array() + let fieldsChanged = false let metadataChanges = undefined for (const entry of changes) { const [key, value] = entry @@ -349,10 +387,10 @@ class UpdateBuilder implements ModuleUpdate { metadataChanges = new Map(value.entries()) } else { assert(!(value instanceof Y.AbstractType)) - fields.push(entry) + fieldsChanged = true } } - if (fields.length !== 0) this.fieldsUpdated.push({ id, fields }) + if (fieldsChanged) this.nodesUpdated.add(id) if (metadataChanges) this.metadataUpdated.push({ id, changes: metadataChanges }) } @@ -363,6 +401,11 @@ class UpdateBuilder implements ModuleUpdate { } deleteNode(id: AstId) { - this.nodesDeleted.push(id) + this.nodesDeleted.add(id) + } + + finish(): ModuleUpdate { + const updateRoots = subtreeRoots(this.module, new Set(this.nodesUpdated.keys())) + return { ...this, updateRoots } } } diff --git a/app/gui2/shared/ast/parse.ts b/app/gui2/shared/ast/parse.ts index 0f4715e7933c..d1b5020d8085 100644 --- a/app/gui2/shared/ast/parse.ts +++ b/app/gui2/shared/ast/parse.ts @@ -1,18 +1,36 @@ import * as map from 'lib0/map' +import type { AstId, Module, NodeChild, Owned } from '.' import { Token, asOwned, isTokenId, parentId, + rewriteRefs, subtreeRoots, - type AstId, - type NodeChild, - type Owned, + syncFields, + syncNodeMetadata, } from '.' import { assert, assertDefined, assertEqual } from '../util/assert' -import type { SourceRange, SourceRangeKey } from '../yjsModel' -import { IdMap, isUuid, sourceRangeFromKey, sourceRangeKey } from '../yjsModel' -import { parse_tree } from './ffi' +import { tryGetSoleValue, zip } from '../util/data/iterable' +import type { SourceRangeEdit, SpanTree } from '../util/data/text' +import { + applyTextEdits, + applyTextEditsToSpans, + enclosingSpans, + textChangeToEdits, + trimEnd, +} from '../util/data/text' +import { + IdMap, + isUuid, + rangeLength, + sourceRangeFromKey, + sourceRangeKey, + type SourceRange, + type SourceRangeKey, +} from '../yjsModel' +import { graphParentPointers } from './debug' +import { parse_tree, xxHash128 } from './ffi' import * as RawAst from './generated/ast' import { MutableModule } from './mutableModule' import type { LazyObject } from './parserSupport' @@ -28,6 +46,8 @@ import { Ident, Import, Invalid, + MutableAssignment, + MutableAst, MutableBodyBlock, MutableIdent, NegationApp, @@ -39,11 +59,16 @@ import { Wildcard, } from './tree' -export function parseEnso(code: string): RawAst.Tree { +/** Return the raw parser output for the given code. */ +export function parseEnso(code: string): RawAst.Tree.BodyBlock { const blob = parse_tree(code) - return RawAst.Tree.read(new DataView(blob.buffer), blob.byteLength - 4) + const tree = RawAst.Tree.read(new DataView(blob.buffer), blob.byteLength - 4) + // The root of the parser output is always a body block. + assert(tree.type === RawAst.Tree.Type.BodyBlock) + return tree } +/** Print the AST and re-parse it, copying `externalId`s (but not other metadata) from the original. */ export function normalize(rootIn: Ast): Ast { const printed = print(rootIn) const idMap = spanMapToIdMap(printed.info) @@ -55,27 +80,54 @@ export function normalize(rootIn: Ast): Ast { return parsed } +/** Produce `Ast` types from `RawAst` parser output. */ +export function abstract( + module: MutableModule, + tree: RawAst.Tree.BodyBlock, + code: string, + substitutor?: (key: NodeKey) => Owned | undefined, +): { root: Owned; spans: SpanMap; toRaw: Map } export function abstract( module: MutableModule, tree: RawAst.Tree, code: string, + substitutor?: (key: NodeKey) => Owned | undefined, +): { root: Owned; spans: SpanMap; toRaw: Map } +export function abstract( + module: MutableModule, + tree: RawAst.Tree, + code: string, + substitutor?: (key: NodeKey) => Owned | undefined, ): { root: Owned; spans: SpanMap; toRaw: Map } { - const abstractor = new Abstractor(module, code) + const abstractor = new Abstractor(module, code, substitutor) const root = abstractor.abstractTree(tree).node const spans = { tokens: abstractor.tokens, nodes: abstractor.nodes } - return { root, spans, toRaw: abstractor.toRaw } + return { root: root as Owned, spans, toRaw: abstractor.toRaw } } +/** Produces `Ast` types from `RawAst` parser output. */ class Abstractor { private readonly module: MutableModule private readonly code: string + private readonly substitutor: ((key: NodeKey) => Owned | undefined) | undefined readonly nodes: NodeSpanMap readonly tokens: TokenSpanMap readonly toRaw: Map - constructor(module: MutableModule, code: string) { + /** + * @param module - Where to allocate the new nodes. + * @param code - Source code that will be used to resolve references in any passed `RawAst` objects. + * @param substitutor - A function that can inject subtrees for some spans, instead of the abstractor producing them. + * This can be used for incremental abstraction. + */ + constructor( + module: MutableModule, + code: string, + substitutor?: (key: NodeKey) => Owned | undefined, + ) { this.module = module this.code = code + this.substitutor = substitutor this.nodes = new Map() this.tokens = new Map() this.toRaw = new Map() @@ -88,6 +140,8 @@ class Abstractor { const codeStart = whitespaceEnd const codeEnd = codeStart + tree.childrenLengthInCodeParsed const spanKey = nodeKey(codeStart, codeEnd - codeStart) + const substitute = this.substitutor?.(spanKey) + if (substitute) return { node: substitute, whitespace } let node: Owned switch (tree.type) { case RawAst.Tree.Type.BodyBlock: { @@ -153,10 +207,11 @@ class Abstractor { ? [this.abstractToken(tree.opr.value)] : Array.from(tree.opr.error.payload.operators, this.abstractToken.bind(this)) const rhs = tree.rhs ? this.abstractTree(tree.rhs) : undefined - if (opr.length === 1 && opr[0]?.node.code() === '.' && rhs?.node instanceof MutableIdent) { + const soleOpr = tryGetSoleValue(opr) + if (soleOpr?.node.code() === '.' && rhs?.node instanceof MutableIdent) { // Propagate type. const rhs_ = { ...rhs, node: rhs.node } - node = PropertyAccess.concrete(this.module, lhs, opr[0], rhs_) + node = PropertyAccess.concrete(this.module, lhs, soleOpr, rhs_) } else { node = OprApp.concrete(this.module, lhs, opr, rhs) } @@ -259,7 +314,7 @@ class Abstractor { return { whitespace, node } } - private abstractChildren(tree: LazyObject) { + private abstractChildren(tree: LazyObject): NodeChild[] { const children: NodeChild[] = [] const visitor = (child: LazyObject) => { if (RawAst.Tree.isInstance(child)) { @@ -276,29 +331,38 @@ class Abstractor { } declare const nodeKeyBrand: unique symbol +/** A source-range key for an `Ast`. */ export type NodeKey = SourceRangeKey & { [nodeKeyBrand]: never } declare const tokenKeyBrand: unique symbol +/** A source-range key for a `Token`. */ export type TokenKey = SourceRangeKey & { [tokenKeyBrand]: never } +/** Create a source-range key for an `Ast`. */ export function nodeKey(start: number, length: number): NodeKey { return sourceRangeKey([start, start + length]) as NodeKey } +/** Create a source-range key for a `Token`. */ export function tokenKey(start: number, length: number): TokenKey { return sourceRangeKey([start, start + length]) as TokenKey } +/** Maps from source ranges to `Ast`s. */ export type NodeSpanMap = Map +/** Maps from source ranges to `Token`s. */ export type TokenSpanMap = Map +/** Maps from source ranges to `Ast`s and `Token`s. */ export interface SpanMap { nodes: NodeSpanMap tokens: TokenSpanMap } +/** Code with an associated mapping to `Ast` types. */ interface PrintedSource { info: SpanMap code: string } +/** Generate an `IdMap` from a `SpanMap`. */ export function spanMapToIdMap(spans: SpanMap): IdMap { const idMap = new IdMap() for (const [key, token] of spans.tokens.entries()) { @@ -314,6 +378,7 @@ export function spanMapToIdMap(spans: SpanMap): IdMap { return idMap } +/** Given a `SpanMap`, return a function that can look up source ranges by AST ID. */ export function spanMapToSpanGetter(spans: SpanMap): (id: AstId) => SourceRange | undefined { const reverseMap = new Map() for (const [key, asts] of spans.nodes) { @@ -344,7 +409,7 @@ export function printAst( ): string { let code = '' for (const child of ast.concreteChildren(verbatim)) { - if (!isTokenId(child.node) && ast.module.checkedGet(child.node) === undefined) continue + if (!isTokenId(child.node) && ast.module.get(child.node) === undefined) continue if (child.whitespace != null) { code += child.whitespace } else if (code.length != 0) { @@ -357,13 +422,16 @@ export function printAst( info.tokens.set(span, token) code += token.code() } else { - const childNode = ast.module.checkedGet(child.node) - assert(childNode != null) + const childNode = ast.module.get(child.node) code += childNode.printSubtree(info, offset + code.length, parentIndent, verbatim) // Extra structural validation. assertEqual(childNode.id, child.node) if (parentId(childNode) !== ast.id) { - console.error(`Inconsistent parent pointer (expected ${ast.id})`, childNode) + console.error( + `Inconsistent parent pointer (expected ${ast.id})`, + childNode, + graphParentPointers(ast.module.root()!), + ) } assertEqual(parentId(childNode), ast.id) } @@ -404,7 +472,7 @@ export function printBlock( } const validIndent = (line.expression.whitespace?.length ?? 0) > (parentIndent?.length ?? 0) code += validIndent ? line.expression.whitespace : blockIndent - const lineNode = block.module.checkedGet(line.expression.node) + const lineNode = block.module.get(line.expression.node) assertEqual(lineNode.id, line.expression.node) assertEqual(parentId(lineNode), block.id) code += lineNode.printSubtree(info, offset + code.length, blockIndent, verbatim) @@ -416,7 +484,7 @@ export function printBlock( } /** Parse the input as a block. */ -export function parseBlock(code: string, inModule?: MutableModule) { +export function parseBlock(code: string, inModule?: MutableModule): Owned { return parseBlockWithSpans(code, inModule).root } @@ -424,58 +492,55 @@ export function parseBlock(code: string, inModule?: MutableModule) { export function parse(code: string, module?: MutableModule): Owned { const module_ = module ?? MutableModule.Transient() const ast = parseBlock(code, module_) - const [expr] = ast.statements() - if (!expr) return ast - const parent = parentId(expr) + const soleStatement = tryGetSoleValue(ast.statements()) + if (!soleStatement) return ast + const parent = parentId(soleStatement) if (parent) module_.delete(parent) - expr.fields.set('parent', undefined) - return asOwned(expr) + soleStatement.fields.set('parent', undefined) + return asOwned(soleStatement) } +/** Parse a block, and return it along with a mapping from source locations to parsed objects. */ export function parseBlockWithSpans( code: string, inModule?: MutableModule, ): { root: Owned; spans: SpanMap } { const tree = parseEnso(code) const module = inModule ?? MutableModule.Transient() - return fromRaw(tree, code, module) -} - -function fromRaw( - tree: RawAst.Tree, - code: string, - inModule?: MutableModule, -): { - root: Owned - spans: SpanMap - toRaw: Map -} { - const module = inModule ?? MutableModule.Transient() - const ast = abstract(module, tree, code) - const spans = ast.spans - // The root of the tree produced by the parser is always a `BodyBlock`. - const root = ast.root as Owned - return { root, spans, toRaw: ast.toRaw } + return abstract(module, tree, code) } +/** Parse the input, and apply the given `IdMap`. Return the parsed tree, the updated `IdMap`, the span map, and a + * mapping to the `RawAst` representation. + */ export function parseExtended(code: string, idMap?: IdMap | undefined, inModule?: MutableModule) { const rawRoot = parseEnso(code) const module = inModule ?? MutableModule.Transient() - const { root, spans, toRaw, idMapUpdates } = module.ydoc.transact(() => { - const { root, spans, toRaw } = fromRaw(rawRoot, code, module) + const { root, spans, toRaw } = module.transact(() => { + const { root, spans, toRaw } = abstract(module, rawRoot, code) root.module.replaceRoot(root) - const idMapUpdates = idMap ? setExternalIds(root.module, spans, idMap) : 0 - return { root, spans, toRaw, idMapUpdates } - }, 'local') + if (idMap) setExternalIds(root.module, spans, idMap) + return { root, spans, toRaw } + }) const getSpan = spanMapToSpanGetter(spans) const idMapOut = spanMapToIdMap(spans) - return { root, idMap: idMapOut, getSpan, toRaw, idMapUpdates } + return { root, idMap: idMapOut, getSpan, toRaw } } -export function setExternalIds(edit: MutableModule, spans: SpanMap, ids: IdMap) { +/** Return the number of `Ast`s in the tree, including the provided root. */ +export function astCount(ast: Ast): number { + let count = 0 + ast.visitRecursiveAst((_subtree) => { + count += 1 + }) + return count +} + +/** Apply an `IdMap` to a module, using the given `SpanMap`. + * @returns The number of IDs that were assigned from the map. + */ +export function setExternalIds(edit: MutableModule, spans: SpanMap, ids: IdMap): number { let astsMatched = 0 - let asts = 0 - edit.root()?.visitRecursiveAst((_ast) => (asts += 1)) for (const [key, externalId] of ids.entries()) { const asts = spans.nodes.get(key as NodeKey) if (asts) { @@ -486,9 +551,12 @@ export function setExternalIds(edit: MutableModule, spans: SpanMap, ids: IdMap) } } } - return edit.root() ? asts - astsMatched : 0 + return astsMatched } +/** Try to find all the spans in `expected` in `encountered`. If any are missing, use the provided `code` to determine + * whether the lost spans are single-line or multi-line. + */ function checkSpans(expected: NodeSpanMap, encountered: NodeSpanMap, code: string) { const lost = new Array() for (const [key, asts] of expected) { @@ -573,7 +641,7 @@ function resync( const parentsOfBadSubtrees = new Set() const badAstIds = new Set(Array.from(badAsts, (ast) => ast.id)) for (const id of subtreeRoots(edit, badAstIds)) { - const parent = edit.checkedGet(id)?.parentId + const parent = edit.get(id)?.parentId if (parent) parentsOfBadSubtrees.add(parent) } @@ -587,11 +655,11 @@ function resync( assertEqual(spanOfBadParent.length, parentsOfBadSubtrees.size) for (const [id, span] of spanOfBadParent) { - const parent = edit.checkedGet(id) + const parent = edit.get(id) const goodAst = goodSpans.get(span)?.[0] // The parent of the root of a bad subtree must be a good AST. assertDefined(goodAst) - parent.replaceValue(edit.copy(goodAst)) + parent.syncToCode(goodAst.code()) } console.warn( @@ -599,3 +667,235 @@ function resync( parentsOfBadSubtrees, ) } + +/** @internal Recursion helper for {@link syntaxHash}. */ +function hashSubtreeSyntax(ast: Ast, hashesOut: Map): SyntaxHash { + let content = '' + content += ast.typeName + ':' + for (const child of ast.concreteChildren()) { + content += child.whitespace ?? '?' + if (isTokenId(child.node)) { + content += 'Token:' + hashString(ast.module.getToken(child.node).code()) + } else { + content += hashSubtreeSyntax(ast.module.get(child.node), hashesOut) + } + } + const astHash = hashString(content) + map.setIfUndefined(hashesOut, astHash, (): Ast[] => []).unshift(ast) + return astHash +} + +declare const brandHash: unique symbol +/** See {@link syntaxHash}. */ +type SyntaxHash = string & { [brandHash]: never } +/** Applies the syntax-data hashing function to the input, and brands the result as a `SyntaxHash`. */ +function hashString(input: string): SyntaxHash { + return xxHash128(input) as SyntaxHash +} + +/** Calculates `SyntaxHash`es for the given node and all its children. + * + * Each `SyntaxHash` summarizes the syntactic content of an AST. If two ASTs have the same code and were parsed the + * same way (i.e. one was not parsed in a context that resulted in a different interpretation), they will have the same + * hash. Note that the hash is invariant to metadata, including `externalId` assignments. + */ +function syntaxHash(root: Ast) { + const hashes = new Map() + const rootHash = hashSubtreeSyntax(root, hashes) + return { root: rootHash, hashes } +} + +/** If the input is a block containing a single expression, return the expression; otherwise return the input. */ +function rawBlockToInline(tree: RawAst.Tree.Tree) { + if (tree.type !== RawAst.Tree.Type.BodyBlock) return tree + return tryGetSoleValue(tree.statements)?.expression ?? tree +} + +/** Update `ast` to match the given source code, while modifying it as little as possible. */ +export function syncToCode(ast: MutableAst, code: string, metadataSource?: Module) { + const codeBefore = ast.code() + const textEdits = textChangeToEdits(codeBefore, code) + applyTextEditsToAst(ast, textEdits, metadataSource ?? ast.module) +} + +/** Find nodes in the input `ast` that should be treated as equivalents of nodes in `parsedRoot`. */ +function calculateCorrespondence( + ast: Ast, + astSpans: NodeSpanMap, + parsedRoot: Ast, + parsedSpans: NodeSpanMap, + textEdits: SourceRangeEdit[], + codeAfter: string, +): Map { + const newSpans = new Map() + for (const [key, asts] of parsedSpans) { + for (const ast of asts) newSpans.set(ast.id, sourceRangeFromKey(key)) + } + + // Retained-code matching: For each new tree, check for some old tree of the same type such that the new tree is the + // smallest node to contain all characters of the old tree's code that were not deleted in the edit. + // + // If the new node's span exactly matches the retained code, add the match to `toSync`. If the new node's span + // contains additional code, add the match to `candidates`. + const toSync = new Map() + const candidates = new Map() + const allSpansBefore = Array.from(astSpans.keys(), sourceRangeFromKey) + const spansBeforeAndAfter = applyTextEditsToSpans(textEdits, allSpansBefore).map( + ([before, after]) => [before, trimEnd(after, codeAfter)] satisfies [any, any], + ) + const partAfterToAstBefore = new Map() + for (const [spanBefore, partAfter] of spansBeforeAndAfter) { + const astBefore = astSpans.get(sourceRangeKey(spanBefore) as NodeKey)?.[0]! + partAfterToAstBefore.set(sourceRangeKey(partAfter), astBefore) + } + const matchingPartsAfter = spansBeforeAndAfter.map(([_before, after]) => after) + const parsedSpanTree = new AstWithSpans(parsedRoot, (id) => newSpans.get(id)!) + const astsMatchingPartsAfter = enclosingSpans(parsedSpanTree, matchingPartsAfter) + for (const [astAfter, partsAfter] of astsMatchingPartsAfter) { + for (const partAfter of partsAfter) { + const astBefore = partAfterToAstBefore.get(sourceRangeKey(partAfter))! + if (astBefore.typeName() === astAfter.typeName()) { + ;(rangeLength(newSpans.get(astAfter.id)!) === rangeLength(partAfter) + ? toSync + : candidates + ).set(astBefore.id, astAfter) + break + } + } + } + + // Index the matched nodes. + const oldIdsMatched = new Set() + const newIdsMatched = new Set() + for (const [oldId, newAst] of toSync) { + oldIdsMatched.add(oldId) + newIdsMatched.add(newAst.id) + } + + // Movement matching: For each new tree that hasn't been matched, match it with any identical unmatched old tree. + const newHashes = syntaxHash(parsedRoot).hashes + const oldHashes = syntaxHash(ast).hashes + for (const [hash, newAsts] of newHashes) { + const unmatchedNewAsts = newAsts.filter((ast) => !newIdsMatched.has(ast.id)) + const unmatchedOldAsts = oldHashes.get(hash)?.filter((ast) => !oldIdsMatched.has(ast.id)) ?? [] + for (const [unmatchedNew, unmatchedOld] of zip(unmatchedNewAsts, unmatchedOldAsts)) { + toSync.set(unmatchedOld.id, unmatchedNew) + // Update the matched-IDs indices. + oldIdsMatched.add(unmatchedOld.id) + newIdsMatched.add(unmatchedNew.id) + } + } + + // Apply any non-optimal span matches from `candidates`, if the nodes involved were not matched during + // movement-matching. + for (const [beforeId, after] of candidates) { + if (oldIdsMatched.has(beforeId) || newIdsMatched.has(after.id)) continue + toSync.set(beforeId, after) + } + + return toSync +} + +/** Update `ast` according to changes to its corresponding source code. */ +export function applyTextEditsToAst( + ast: MutableAst, + textEdits: SourceRangeEdit[], + metadataSource: Module, +) { + const printed = print(ast) + const code = applyTextEdits(printed.code, textEdits) + const rawParsedBlock = parseEnso(code) + const rawParsed = + ast instanceof MutableBodyBlock ? rawParsedBlock : rawBlockToInline(rawParsedBlock) + const parsed = abstract(ast.module, rawParsed, code) + const toSync = calculateCorrespondence( + ast, + printed.info.nodes, + parsed.root, + parsed.spans.nodes, + textEdits, + code, + ) + syncTree(ast, parsed.root, toSync, ast.module, metadataSource) +} + +/** Replace `target` with `newContent`, reusing nodes according to the correspondence in `toSync`. */ +function syncTree( + target: Ast, + newContent: Owned, + toSync: Map, + edit: MutableModule, + metadataSource: Module, +) { + const newIdToEquivalent = new Map() + for (const [beforeId, after] of toSync) newIdToEquivalent.set(after.id, beforeId) + const childReplacerFor = (parentId: AstId) => (id: AstId) => { + const original = newIdToEquivalent.get(id) + if (original) { + const replacement = edit.get(original) + if (replacement.parentId !== parentId) replacement.fields.set('parent', parentId) + return original + } else { + const child = edit.get(id) + if (child.parentId !== parentId) child.fields.set('parent', parentId) + } + } + const parentId = target.fields.get('parent') + assertDefined(parentId) + const parent = edit.get(parentId) + const targetSyncEquivalent = toSync.get(target.id) + const syncRoot = targetSyncEquivalent?.id === newContent.id ? targetSyncEquivalent : undefined + if (!syncRoot) { + parent.replaceChild(target.id, newContent) + newContent.fields.set('metadata', target.fields.get('metadata').clone()) + } + const newRoot = syncRoot ? target : newContent + newRoot.visitRecursiveAst((ast) => { + const syncFieldsFrom = toSync.get(ast.id) + const editAst = edit.getVersion(ast) + if (syncFieldsFrom) { + const originalAssignmentExpression = + ast instanceof Assignment + ? metadataSource.get(ast.fields.get('expression').node) + : undefined + syncFields(edit.getVersion(ast), syncFieldsFrom, childReplacerFor(ast.id)) + if (editAst instanceof MutableAssignment && originalAssignmentExpression) { + if (editAst.expression.externalId !== originalAssignmentExpression.externalId) + editAst.expression.setExternalId(originalAssignmentExpression.externalId) + syncNodeMetadata( + editAst.expression.mutableNodeMetadata(), + originalAssignmentExpression.nodeMetadata, + ) + } + } else { + rewriteRefs(editAst, childReplacerFor(ast.id)) + } + return true + }) + return newRoot +} + +/** Provides a `SpanTree` view of an `Ast`, given span information. */ +class AstWithSpans implements SpanTree { + private readonly ast: Ast + private readonly getSpan: (astId: AstId) => SourceRange + + constructor(ast: Ast, getSpan: (astId: AstId) => SourceRange) { + this.ast = ast + this.getSpan = getSpan + } + + id(): Ast { + return this.ast + } + + span(): SourceRange { + return this.getSpan(this.ast.id) + } + + *children(): IterableIterator> { + for (const child of this.ast.children()) { + if (child instanceof Ast) yield new AstWithSpans(child, this.getSpan) + } + } +} diff --git a/app/gui2/shared/ast/sourceDocument.ts b/app/gui2/shared/ast/sourceDocument.ts index 660569564eec..4dacefc903af 100644 --- a/app/gui2/shared/ast/sourceDocument.ts +++ b/app/gui2/shared/ast/sourceDocument.ts @@ -1,5 +1,9 @@ import { print, type AstId, type Module, type ModuleUpdate } from '.' -import { rangeEquals, sourceRangeFromKey, type SourceRange } from '../yjsModel' +import { assertDefined } from '../util/assert' +import type { SourceRangeEdit } from '../util/data/text' +import { offsetEdit, textChangeToEdits } from '../util/data/text' +import type { Origin, SourceRange } from '../yjsModel' +import { rangeEquals, sourceRangeFromKey } from '../yjsModel' /** Provides a view of the text representation of a module, * and information about the correspondence between the text and the ASTs, @@ -8,10 +12,12 @@ import { rangeEquals, sourceRangeFromKey, type SourceRange } from '../yjsModel' export class SourceDocument { private text_: string private readonly spans: Map + private readonly observers: SourceDocumentObserver[] private constructor(text: string, spans: Map) { this.text_ = text this.spans = spans + this.observers = [] } static Empty() { @@ -19,23 +25,43 @@ export class SourceDocument { } clear() { - if (this.text_ !== '') this.text_ = '' if (this.spans.size !== 0) this.spans.clear() + if (this.text_ !== '') { + const range: SourceRange = [0, this.text_.length] + this.text_ = '' + this.notifyObservers([{ range, insert: '' }], undefined) + } } applyUpdate(module: Module, update: ModuleUpdate) { for (const id of update.nodesDeleted) this.spans.delete(id) const root = module.root() if (!root) return + const subtreeTextEdits = new Array() const printed = print(root) for (const [key, nodes] of printed.info.nodes) { const range = sourceRangeFromKey(key) for (const node of nodes) { const oldSpan = this.spans.get(node.id) if (!oldSpan || !rangeEquals(range, oldSpan)) this.spans.set(node.id, range) + if (update.updateRoots.has(node.id) && node.id !== root.id) { + assertDefined(oldSpan) + const oldCode = this.text_.slice(oldSpan[0], oldSpan[1]) + const newCode = printed.code.slice(range[0], range[1]) + const subedits = textChangeToEdits(oldCode, newCode).map((textEdit) => + offsetEdit(textEdit, oldSpan[0]), + ) + subtreeTextEdits.push(...subedits) + } } } - if (printed.code !== this.text_) this.text_ = printed.code + if (printed.code !== this.text_) { + const textEdits = update.updateRoots.has(root.id) + ? [{ range: [0, this.text_.length] satisfies SourceRange, insert: printed.code }] + : subtreeTextEdits + this.text_ = printed.code + this.notifyObservers(textEdits, update.origin) + } } get text(): string { @@ -45,4 +71,23 @@ export class SourceDocument { getSpan(id: AstId): SourceRange | undefined { return this.spans.get(id) } + + observe(observer: SourceDocumentObserver) { + this.observers.push(observer) + if (this.text_.length) observer([{ range: [0, 0], insert: this.text_ }], undefined) + } + + unobserve(observer: SourceDocumentObserver) { + const index = this.observers.indexOf(observer) + if (index !== undefined) this.observers.splice(index, 1) + } + + private notifyObservers(textEdits: SourceRangeEdit[], origin: Origin | undefined) { + for (const o of this.observers) o(textEdits, origin) + } } + +export type SourceDocumentObserver = ( + textEdits: SourceRangeEdit[], + origin: Origin | undefined, +) => void diff --git a/app/gui2/shared/ast/token.ts b/app/gui2/shared/ast/token.ts index d4397fccb34e..e307f7a81cd4 100644 --- a/app/gui2/shared/ast/token.ts +++ b/app/gui2/shared/ast/token.ts @@ -23,6 +23,7 @@ export interface SyncTokenId { code_: string tokenType_: RawAst.Token.Type | undefined } + export class Token implements SyncTokenId { readonly id: TokenId code_: string @@ -47,6 +48,10 @@ export class Token implements SyncTokenId { return new this(code, type, id) } + static equal(a: SyncTokenId, b: SyncTokenId): boolean { + return a.tokenType_ === b.tokenType_ && a.code_ === b.code_ + } + code(): string { return this.code_ } diff --git a/app/gui2/shared/ast/tree.ts b/app/gui2/shared/ast/tree.ts index 76641b3fd2d4..f9849f1615af 100644 --- a/app/gui2/shared/ast/tree.ts +++ b/app/gui2/shared/ast/tree.ts @@ -11,6 +11,7 @@ import type { } from '.' import { MutableModule, + ROOT_ID, Token, asOwned, isIdentifier, @@ -18,17 +19,23 @@ import { isTokenId, newExternalId, parentId, - parse, - parseBlock, - print, - printAst, - printBlock, } from '.' import { assert, assertDefined, assertEqual, bail } from '../util/assert' import type { Result } from '../util/data/result' import { Err, Ok } from '../util/data/result' +import type { SourceRangeEdit } from '../util/data/text' import type { ExternalId, VisualizationMetadata } from '../yjsModel' +import { visMetadataEquals } from '../yjsModel' import * as RawAst from './generated/ast' +import { + applyTextEditsToAst, + parse, + parseBlock, + print, + printAst, + printBlock, + syncToCode, +} from './parse' declare const brandAstId: unique symbol export type AstId = string & { [brandAstId]: never } @@ -47,12 +54,22 @@ export function asNodeMetadata(map: Map): NodeMetadata { return map as unknown as NodeMetadata } /** @internal */ -export interface AstFields { +interface RawAstFields { id: AstId type: string parent: AstId | undefined metadata: FixedMap } +export interface AstFields extends RawAstFields, LegalFieldContent {} +function allKeys(keys: Record): (keyof T)[] { + return Object.keys(keys) as any +} +const astFieldKeys = allKeys({ + id: null, + type: null, + parent: null, + metadata: null, +}) export abstract class Ast { readonly module: Module /** @internal */ @@ -104,8 +121,8 @@ export abstract class Ast { } } - visitRecursiveAst(visit: (ast: Ast) => void): void { - visit(this) + visitRecursiveAst(visit: (ast: Ast) => void | boolean): void { + if (visit(this) === false) return for (const child of this.children()) { if (!isToken(child)) child.visitRecursiveAst(visit) } @@ -126,7 +143,7 @@ export abstract class Ast { if (isTokenId(child.node)) { yield this.module.getToken(child.node) } else { - const node = this.module.checkedGet(child.node) + const node = this.module.get(child.node) if (node) yield node } } @@ -134,11 +151,11 @@ export abstract class Ast { get parentId(): AstId | undefined { const parentId = this.fields.get('parent') - if (parentId !== 'ROOT_ID') return parentId + if (parentId !== ROOT_ID) return parentId } parent(): Ast | undefined { - return this.module.checkedGet(this.parentId) + return this.module.get(this.parentId) } static parseBlock(source: string, inModule?: MutableModule) { @@ -185,7 +202,7 @@ export abstract class MutableAst extends Ast { replace(replacement: Owned): Owned { const parentId = this.fields.get('parent') if (parentId) { - const parent = this.module.checkedGet(parentId) + const parent = this.module.get(parentId) parent.replaceChild(this.id, replacement) this.fields.set('parent', undefined) } @@ -230,7 +247,7 @@ export abstract class MutableAst extends Ast { takeIfParented(): Owned { const parent = parentId(this) if (parent) { - const parentAst = this.module.checkedGet(parent) + const parentAst = this.module.get(parent) const placeholder = Wildcard.new(this.module) parentAst.replaceChild(this.id, placeholder) this.fields.set('parent', undefined) @@ -277,7 +294,17 @@ export abstract class MutableAst extends Ast { mutableParent(): MutableAst | undefined { const parentId = this.fields.get('parent') if (parentId === 'ROOT_ID') return - return this.module.checkedGet(parentId) + return this.module.get(parentId) + } + + /** Modify this tree to represent the given code, while minimizing changes from the current set of `Ast`s. */ + syncToCode(code: string, metadataSource?: Module) { + syncToCode(this, code, metadataSource) + } + + /** Update the AST according to changes to its corresponding source code. */ + applyTextEdits(textEdits: SourceRangeEdit[], metadataSource?: Module) { + applyTextEditsToAst(this, textEdits, metadataSource ?? this.module) } /////////////////// @@ -287,7 +314,7 @@ export abstract class MutableAst extends Ast { if (module === this.module) return for (const child of this.concreteChildren()) { if (!isTokenId(child.node)) { - const childInForeignModule = module.checkedGet(child.node) + const childInForeignModule = module.get(child.node) assert(childInForeignModule !== undefined) const importedChild = this.module.copy(childInForeignModule) importedChild.fields.set('parent', undefined) @@ -297,7 +324,11 @@ export abstract class MutableAst extends Ast { } /** @internal */ - abstract replaceChild(target: AstId, replacement: Owned): void + replaceChild(target: AstId, replacement: Owned) { + const replacementId = this.claimChild(replacement) + const changes = rewriteRefs(this, (id) => (id === target ? replacementId : undefined)) + assertEqual(changes, 1) + } protected claimChild(child: Owned): AstId protected claimChild(child: Owned | undefined): AstId | undefined @@ -306,6 +337,130 @@ export abstract class MutableAst extends Ast { } } +/** Values that may be found in fields of `Ast` subtypes. */ +type FieldData = + | NodeChild + | NodeChild + | NodeChild + | FieldData[] + | undefined + | StructuralField +/** Objects that do not directly contain `AstId`s or `SyncTokenId`s, but may have `NodeChild` fields. */ +type StructuralField = + | RawMultiSegmentAppSegment + | RawBlockLine + | RawOpenCloseTokens + | RawNameSpecification +/** Type whose fields are all suitable for storage as `Ast` fields. */ +interface FieldObject { + [field: string]: FieldData +} +/** Returns the fields of an `Ast` subtype that are not part of `AstFields`. */ +function* fieldDataEntries(map: FixedMapView) { + for (const entry of map.entries()) { + // All fields that are not from `AstFields` are `FieldData`. + if (!astFieldKeys.includes(entry[0] as any)) yield entry as [string, FieldData] + } +} + +/** Apply the given function to each `AstId` in the fields of `ast`. For each value that it returns an output, that + * output will be substituted for the input ID. + */ +export function rewriteRefs(ast: MutableAst, f: (id: AstId) => AstId | undefined) { + let fieldsChanged = 0 + for (const [key, value] of fieldDataEntries(ast.fields)) { + const newValue = rewriteFieldRefs(value, f) + if (newValue !== undefined) { + ast.fields.set(key as any, newValue) + fieldsChanged += 1 + } + } + return fieldsChanged +} + +/** Copy all fields except the `Ast` base fields from `ast2` to `ast1`. A reference-rewriting function will be applied + * to `AstId`s in copied fields; see {@link rewriteRefs}. + */ +export function syncFields(ast1: MutableAst, ast2: Ast, f: (id: AstId) => AstId | undefined) { + for (const [key, value] of fieldDataEntries(ast2.fields)) { + const changedValue = rewriteFieldRefs(value, f) + const newValue = changedValue ?? value + if (!fieldEqual(ast1.fields.get(key as any), newValue)) ast1.fields.set(key as any, newValue) + } +} + +export function syncNodeMetadata(target: MutableNodeMetadata, source: NodeMetadata) { + const oldPos = target.get('position') + const newPos = source.get('position') + if (oldPos?.x !== newPos?.x || oldPos?.y !== newPos?.y) target.set('position', newPos) + const newVis = source.get('visualization') + if (!visMetadataEquals(target.get('visualization'), newVis)) target.set('visualization', newVis) +} + +function rewriteFieldRefs(field: FieldData, f: (id: AstId) => AstId | undefined): FieldData { + if (field === undefined) return field + if ('node' in field) { + const child = field.node + if (isTokenId(child)) return + const newValue = f(child) + if (newValue !== undefined) { + field.node = newValue + return field + } + } else if (Array.isArray(field)) { + let fieldChanged = false + field.forEach((subfield, i) => { + const newValue = rewriteFieldRefs(subfield, f) + if (newValue !== undefined) { + field[i] = newValue + fieldChanged = true + } + }) + if (fieldChanged) return field + } else { + const fieldObject = field satisfies StructuralField + let fieldChanged = false + for (const [key, value] of Object.entries(fieldObject)) { + const newValue = rewriteFieldRefs(value, f) + if (newValue !== undefined) { + // This update is safe because `newValue` was obtained by reading `fieldObject[key]` and modifying it in a + // type-preserving way. + ;(fieldObject as any)[key] = newValue + fieldChanged = true + } + } + if (fieldChanged) return fieldObject + } +} + +function fieldEqual(field1: FieldData, field2: FieldData): boolean { + if (field1 === undefined) return field2 === undefined + if (field2 === undefined) return false + if ('node' in field1 && 'node' in field2) { + if (field1['whitespace'] !== field2['whitespace']) return false + if (isTokenId(field1.node) && isTokenId(field2.node)) + return Token.equal(field1.node, field2.node) + else return field1.node === field2.node + } else if ('node' in field1 || 'node' in field2) { + return false + } else if (Array.isArray(field1) && Array.isArray(field2)) { + return ( + field1.length === field2.length && field1.every((value1, i) => fieldEqual(value1, field2[i])) + ) + } else if (Array.isArray(field1) || Array.isArray(field2)) { + return false + } else { + const fieldObject1 = field1 satisfies StructuralField + const fieldObject2 = field2 satisfies StructuralField + const keys = new Set() + for (const key of Object.keys(fieldObject1)) keys.add(key) + for (const key of Object.keys(fieldObject2)) keys.add(key) + for (const key of keys) + if (!fieldEqual((fieldObject1 as any)[key], (fieldObject2 as any)[key])) return false + return true + } +} + function applyMixins(derivedCtor: any, constructors: any[]) { constructors.forEach((baseCtor) => { Object.getOwnPropertyNames(baseCtor.prototype).forEach((name) => { @@ -320,10 +475,18 @@ function applyMixins(derivedCtor: any, constructors: any[]) { interface AppFields { function: NodeChild - parens: { open: NodeChild; close: NodeChild } | undefined - nameSpecification: { name: NodeChild; equals: NodeChild } | undefined + parens: RawOpenCloseTokens | undefined + nameSpecification: RawNameSpecification | undefined argument: NodeChild } +interface RawOpenCloseTokens { + open: NodeChild + close: NodeChild +} +interface RawNameSpecification { + name: NodeChild + equals: NodeChild +} export class App extends Ast { declare fields: FixedMap constructor(module: Module, fields: FixedMapView) { @@ -344,7 +507,7 @@ export class App extends Ast { ) { const base = module.baseObject('App') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { function: concreteChild(module, func, id_), parens, nameSpecification, @@ -361,7 +524,7 @@ export class App extends Ast { ) { return App.concrete( module, - unspaced(func), + autospaced(func), undefined, nameSpecification(argumentName), autospaced(argument), @@ -369,26 +532,27 @@ export class App extends Ast { } get function(): Ast { - return this.module.checkedGet(this.fields.get('function').node) + return this.module.get(this.fields.get('function').node) } get argumentName(): Token | undefined { return this.module.getToken(this.fields.get('nameSpecification')?.name.node) } get argument(): Ast { - return this.module.checkedGet(this.fields.get('argument').node) + return this.module.get(this.fields.get('argument').node) } *concreteChildren(verbatim?: boolean): IterableIterator { const { function: function_, parens, nameSpecification, argument } = getAll(this.fields) - yield function_ - if (parens) yield parens.open - const spacedEquals = !!parens && !!nameSpecification?.equals.whitespace + yield ensureUnspaced(function_, verbatim) + const useParens = !!(parens && (nameSpecification || verbatim)) + const spacedEquals = useParens && !!nameSpecification?.equals.whitespace + if (useParens) yield ensureSpaced(parens.open, verbatim) if (nameSpecification) { - yield ensureSpacedIf(nameSpecification.name, !parens, verbatim) + yield ensureSpacedIf(nameSpecification.name, !useParens, verbatim) yield ensureSpacedOnlyIf(nameSpecification.equals, spacedEquals, verbatim) } yield ensureSpacedOnlyIf(argument, !nameSpecification || spacedEquals, verbatim) - if (parens) yield parens.close + if (useParens) yield preferUnspaced(parens.close) } printSubtree( @@ -424,6 +588,9 @@ function ensureUnspaced(child: NodeChild, verbatim: boolean | undefined): if (verbatim && child.whitespace != null) return child return child.whitespace === '' ? child : { whitespace: '', ...child } } +function preferUnspaced(child: NodeChild): NodeChild { + return child.whitespace === undefined ? { whitespace: '', ...child } : child +} export class MutableApp extends App implements MutableAst { declare readonly module: MutableModule declare readonly fields: FixedMap @@ -437,14 +604,6 @@ export class MutableApp extends App implements MutableAst { setArgument(value: Owned) { setNode(this.fields, 'argument', this.claimChild(value)) } - - replaceChild(target: AstId, replacement: Owned) { - if (this.fields.get('function').node === target) { - this.setFunction(replacement) - } else if (this.fields.get('argument').node === target) { - this.setArgument(replacement) - } - } } export interface MutableApp extends App, MutableAst { get function(): MutableAst @@ -474,7 +633,7 @@ export class UnaryOprApp extends Ast { ) { const base = module.baseObject('UnaryOprApp') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { operator, argument: concreteChild(module, argument, id_), }) @@ -489,7 +648,7 @@ export class UnaryOprApp extends Ast { return this.module.getToken(this.fields.get('operator').node) } get argument(): Ast | undefined { - return this.module.checkedGet(this.fields.get('argument')?.node) + return this.module.get(this.fields.get('argument')?.node) } *concreteChildren(_verbatim?: boolean): IterableIterator { @@ -508,12 +667,6 @@ export class MutableUnaryOprApp extends UnaryOprApp implements MutableAst { setArgument(argument: Owned | undefined) { setNode(this.fields, 'argument', this.claimChild(argument)) } - - replaceChild(target: AstId, replacement: Owned) { - if (this.fields.get('argument')?.node === target) { - this.setArgument(replacement) - } - } } export interface MutableUnaryOprApp extends UnaryOprApp, MutableAst { get argument(): MutableAst | undefined @@ -538,7 +691,7 @@ export class NegationApp extends Ast { static concrete(module: MutableModule, operator: NodeChild, argument: NodeChild) { const base = module.baseObject('NegationApp') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { operator, argument: concreteChild(module, argument, id_), }) @@ -553,7 +706,7 @@ export class NegationApp extends Ast { return this.module.getToken(this.fields.get('operator').node) } get argument(): Ast { - return this.module.checkedGet(this.fields.get('argument').node) + return this.module.get(this.fields.get('argument').node) } *concreteChildren(_verbatim?: boolean): IterableIterator { @@ -569,12 +722,6 @@ export class MutableNegationApp extends NegationApp implements MutableAst { setArgument(value: Owned) { setNode(this.fields, 'argument', this.claimChild(value)) } - - replaceChild(target: AstId, replacement: Owned) { - if (this.fields.get('argument')?.node === target) { - this.setArgument(replacement) - } - } } export interface MutableNegationApp extends NegationApp, MutableAst { get argument(): MutableAst @@ -605,7 +752,7 @@ export class OprApp extends Ast { ) { const base = module.baseObject('OprApp') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { lhs: concreteChild(module, lhs, id_), operators, rhs: concreteChild(module, rhs, id_), @@ -623,7 +770,7 @@ export class OprApp extends Ast { } get lhs(): Ast | undefined { - return this.module.checkedGet(this.fields.get('lhs')?.node) + return this.module.get(this.fields.get('lhs')?.node) } get operator(): Result[]> { const operators = this.fields.get('operators') @@ -635,7 +782,7 @@ export class OprApp extends Ast { return opr ? Ok(opr.node) : Err(operators_) } get rhs(): Ast | undefined { - return this.module.checkedGet(this.fields.get('rhs')?.node) + return this.module.get(this.fields.get('rhs')?.node) } *concreteChildren(_verbatim?: boolean): IterableIterator { @@ -658,14 +805,6 @@ export class MutableOprApp extends OprApp implements MutableAst { setRhs(value: Owned) { setNode(this.fields, 'rhs', this.claimChild(value)) } - - replaceChild(target: AstId, replacement: Owned) { - if (this.fields.get('lhs')?.node === target) { - this.setLhs(replacement) - } else if (this.fields.get('rhs')?.node === target) { - this.setRhs(replacement) - } - } } export interface MutableOprApp extends OprApp, MutableAst { get lhs(): MutableAst | undefined @@ -736,7 +875,7 @@ export class PropertyAccess extends Ast { ) { const base = module.baseObject('PropertyAccess') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { lhs: concreteChild(module, lhs, id_), operator, rhs: concreteChild(module, rhs, id_), @@ -745,13 +884,13 @@ export class PropertyAccess extends Ast { } get lhs(): Ast | undefined { - return this.module.checkedGet(this.fields.get('lhs')?.node) + return this.module.get(this.fields.get('lhs')?.node) } get operator(): Token { return this.module.getToken(this.fields.get('operator').node) } get rhs(): IdentifierOrOperatorIdentifierToken { - const ast = this.module.checkedGet(this.fields.get('rhs').node) + const ast = this.module.get(this.fields.get('rhs').node) assert(ast instanceof Ident) return ast.token as IdentifierOrOperatorIdentifierToken } @@ -775,15 +914,6 @@ export class MutablePropertyAccess extends PropertyAccess implements MutableAst const old = this.fields.get('rhs') this.fields.set('rhs', old ? { ...old, node } : unspaced(node)) } - - replaceChild(target: AstId, replacement: Owned) { - if (this.fields.get('lhs')?.node === target) { - this.setLhs(replacement) - } else if (this.fields.get('rhs')?.node === target) { - assert(replacement instanceof MutableIdent) - this.setRhs(replacement.token) - } - } } export interface MutablePropertyAccess extends PropertyAccess, MutableAst { get lhs(): MutableAst | undefined @@ -802,7 +932,7 @@ export class Generic extends Ast { static concrete(module: MutableModule, children: NodeChild[]) { const base = module.baseObject('Generic') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { children: children.map((child) => concreteChild(module, child, id_)), }) return asOwned(new MutableGeneric(module, fields)) @@ -815,14 +945,6 @@ export class Generic extends Ast { export class MutableGeneric extends Generic implements MutableAst { declare readonly module: MutableModule declare readonly fields: FixedMap - - replaceChild(target: AstId, replacement: Owned) { - const replacement_ = autospaced(this.claimChild(replacement)) - this.fields.set( - 'children', - this.fields.get('children').map((child) => (child.node === target ? replacement_ : child)), - ) - } } export interface MutableGeneric extends Generic, MutableAst {} applyMixins(MutableGeneric, [MutableAst]) @@ -899,22 +1021,22 @@ export class Import extends Ast { } get polyglot(): Ast | undefined { - return this.module.checkedGet(this.fields.get('polyglot')?.body?.node) + return this.module.get(this.fields.get('polyglot')?.body?.node) } get from(): Ast | undefined { - return this.module.checkedGet(this.fields.get('from')?.body?.node) + return this.module.get(this.fields.get('from')?.body?.node) } get import_(): Ast | undefined { - return this.module.checkedGet(this.fields.get('import').body?.node) + return this.module.get(this.fields.get('import').body?.node) } get all(): Token | undefined { return this.module.getToken(this.fields.get('all')?.node) } get as(): Ast | undefined { - return this.module.checkedGet(this.fields.get('as')?.body?.node) + return this.module.get(this.fields.get('as')?.body?.node) } get hiding(): Ast | undefined { - return this.module.checkedGet(this.fields.get('hiding')?.body?.node) + return this.module.get(this.fields.get('hiding')?.body?.node) } static concrete( @@ -928,7 +1050,7 @@ export class Import extends Ast { ) { const base = module.baseObject('Import') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { polyglot: multiSegmentAppSegmentToRaw(module, polyglot, id_), from: multiSegmentAppSegmentToRaw(module, from, id_), import: multiSegmentAppSegmentToRaw(module, import_, id_), @@ -1023,21 +1145,6 @@ export class MutableImport extends Import implements MutableAst { setHiding(value: Owned | undefined) { this.fields.set('hiding', this.toRaw(multiSegmentAppSegment('hiding', value))) } - - replaceChild(target: AstId, replacement: Owned) { - const { polyglot, from, import: import_, as, hiding } = getAll(this.fields) - polyglot?.body?.node === target - ? this.setPolyglot(replacement) - : from?.body?.node === target - ? this.setFrom(replacement) - : import_.body?.node === target - ? this.setImport(replacement) - : as?.body?.node === target - ? this.setAs(replacement) - : hiding?.body?.node === target - ? this.setHiding(replacement) - : bail(`Failed to find child ${target} in node ${this.externalId}.`) - } } export interface MutableImport extends Import, MutableAst { get polyglot(): MutableAst | undefined @@ -1092,7 +1199,7 @@ export class TextLiteral extends Ast { ) { const base = module.baseObject('TextLiteral') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { open, newline, elements: elements.map((elem) => concreteChild(module, elem, id_)), @@ -1119,14 +1226,6 @@ export class TextLiteral extends Ast { export class MutableTextLiteral extends TextLiteral implements MutableAst { declare readonly module: MutableModule declare readonly fields: FixedMap - - replaceChild(target: AstId, replacement: Owned) { - const replacement_ = autospaced(this.claimChild(replacement)) - this.fields.set( - 'elements', - this.fields.get('elements').map((child) => (child.node === target ? replacement_ : child)), - ) - } } export interface MutableTextLiteral extends TextLiteral, MutableAst {} applyMixins(MutableTextLiteral, [MutableAst]) @@ -1157,7 +1256,7 @@ export class Documented extends Ast { ) { const base = module.baseObject('Documented') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { open, elements: elements.map((elem) => concreteChild(module, elem, id_)), newlines, @@ -1167,7 +1266,7 @@ export class Documented extends Ast { } get expression(): Ast | undefined { - return this.module.checkedGet(this.fields.get('expression')?.node) + return this.module.get(this.fields.get('expression')?.node) } *concreteChildren(_verbatim?: boolean): IterableIterator { @@ -1185,18 +1284,6 @@ export class MutableDocumented extends Documented implements MutableAst { setExpression(value: Owned | undefined) { this.fields.set('expression', unspaced(this.claimChild(value))) } - - replaceChild(target: AstId, replacement: Owned) { - if (this.fields.get('expression')?.node === target) { - this.setExpression(replacement) - } else { - const replacement_ = unspaced(this.claimChild(replacement)) - this.fields.set( - 'elements', - this.fields.get('elements').map((child) => (child.node === target ? replacement_ : child)), - ) - } - } } export interface MutableDocumented extends Documented, MutableAst { get expression(): MutableAst | undefined @@ -1218,7 +1305,7 @@ export class Invalid extends Ast { } get expression(): Ast { - return this.module.checkedGet(this.fields.get('expression').node) + return this.module.get(this.fields.get('expression').node) } *concreteChildren(_verbatim?: boolean): IterableIterator { @@ -1240,7 +1327,7 @@ export function invalidFields( expression: NodeChild, ): FixedMap { const id_ = base.get('id') - return setAll(base, { expression: concreteChild(module, expression, id_) }) + return composeFieldData(base, { expression: concreteChild(module, expression, id_) }) } export class MutableInvalid extends Invalid implements MutableAst { declare readonly module: MutableModule @@ -1250,11 +1337,6 @@ export class MutableInvalid extends Invalid implements MutableAst { private setExpression(value: Owned) { this.fields.set('expression', unspaced(this.claimChild(value))) } - - replaceChild(target: AstId, replacement: Owned) { - assertEqual(this.fields.get('expression').node, target) - this.setExpression(replacement) - } } export interface MutableInvalid extends Invalid, MutableAst { /** The `expression` getter is intentionally not narrowed to provide mutable access: @@ -1286,7 +1368,11 @@ export class Group extends Ast { ) { const base = module.baseObject('Group') const id_ = base.get('id') - const fields = setAll(base, { open, expression: concreteChild(module, expression, id_), close }) + const fields = composeFieldData(base, { + open, + expression: concreteChild(module, expression, id_), + close, + }) return asOwned(new MutableGroup(module, fields)) } @@ -1297,7 +1383,7 @@ export class Group extends Ast { } get expression(): Ast | undefined { - return this.module.checkedGet(this.fields.get('expression')?.node) + return this.module.get(this.fields.get('expression')?.node) } *concreteChildren(_verbatim?: boolean): IterableIterator { @@ -1314,11 +1400,6 @@ export class MutableGroup extends Group implements MutableAst { setExpression(value: Owned | undefined) { this.fields.set('expression', unspaced(this.claimChild(value))) } - - replaceChild(target: AstId, replacement: Owned) { - assertEqual(this.fields.get('expression')?.node, target) - this.setExpression(replacement) - } } export interface MutableGroup extends Group, MutableAst { get expression(): MutableAst | undefined @@ -1344,7 +1425,7 @@ export class NumericLiteral extends Ast { static concrete(module: MutableModule, tokens: NodeChild[]) { const base = module.baseObject('NumericLiteral') - const fields = setAll(base, { tokens }) + const fields = composeFieldData(base, { tokens }) return asOwned(new MutableNumericLiteral(module, fields)) } @@ -1355,8 +1436,6 @@ export class NumericLiteral extends Ast { export class MutableNumericLiteral extends NumericLiteral implements MutableAst { declare readonly module: MutableModule declare readonly fields: FixedMap - - replaceChild(_target: AstId, _replacement: Owned) {} } export interface MutableNumericLiteral extends NumericLiteral, MutableAst {} applyMixins(MutableNumericLiteral, [MutableAst]) @@ -1398,10 +1477,10 @@ export class Function extends Ast { } get name(): Ast { - return this.module.checkedGet(this.fields.get('name').node) + return this.module.get(this.fields.get('name').node) } get body(): Ast | undefined { - return this.module.checkedGet(this.fields.get('body')?.node) + return this.module.get(this.fields.get('body')?.node) } get argumentDefinitions(): ArgumentDefinition[] { return this.fields.get('argumentDefinitions').map((raw) => @@ -1421,7 +1500,7 @@ export class Function extends Ast { ) { const base = module.baseObject('Function') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { name: concreteChild(module, name, id_), argumentDefinitions: argumentDefinitionsToRaw(module, argumentDefinitions, id_), equals, @@ -1478,7 +1557,11 @@ export class Function extends Ast { for (const def of argumentDefinitions) yield* def yield { whitespace: equals.whitespace ?? ' ', node: this.module.getToken(equals.node) } if (body) - yield ensureSpacedOnlyIf(body, !(this.module.get(body.node) instanceof BodyBlock), verbatim) + yield ensureSpacedOnlyIf( + body, + !(this.module.tryGet(body.node) instanceof BodyBlock), + verbatim, + ) } } export class MutableFunction extends Function implements MutableAst { @@ -1503,23 +1586,6 @@ export class MutableFunction extends Function implements MutableAst { if (oldBody) newBody.push(oldBody.take()) return newBody } - - replaceChild(target: AstId, replacement: Owned) { - const { name, argumentDefinitions, body } = getAll(this.fields) - if (name.node === target) { - this.setName(replacement) - } else if (body?.node === target) { - this.setBody(replacement) - } else { - const replacement_ = this.claimChild(replacement) - const replaceChild = (child: NodeChild) => - child.node === target ? { ...child, node: replacement_ } : child - this.fields.set( - 'argumentDefinitions', - argumentDefinitions.map((def) => def.map(replaceChild)), - ) - } - } } export interface MutableFunction extends Function, MutableAst { get name(): MutableAst @@ -1551,7 +1617,7 @@ export class Assignment extends Ast { ) { const base = module.baseObject('Assignment') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { pattern: concreteChild(module, pattern, id_), equals, expression: concreteChild(module, expression, id_), @@ -1569,10 +1635,10 @@ export class Assignment extends Ast { } get pattern(): Ast { - return this.module.checkedGet(this.fields.get('pattern').node) + return this.module.get(this.fields.get('pattern').node) } get expression(): Ast { - return this.module.checkedGet(this.fields.get('expression').node) + return this.module.get(this.fields.get('expression').node) } *concreteChildren(verbatim?: boolean): IterableIterator { @@ -1592,15 +1658,6 @@ export class MutableAssignment extends Assignment implements MutableAst { setExpression(value: Owned) { setNode(this.fields, 'expression', this.claimChild(value)) } - - replaceChild(target: AstId, replacement: Owned) { - const { pattern, expression } = getAll(this.fields) - if (pattern.node === target) { - this.setPattern(replacement) - } else if (expression.node === target) { - this.setExpression(replacement) - } - } } export interface MutableAssignment extends Assignment, MutableAst { get pattern(): MutableAst @@ -1625,7 +1682,7 @@ export class BodyBlock extends Ast { static concrete(module: MutableModule, lines: OwnedBlockLine[]) { const base = module.baseObject('BodyBlock') const id_ = base.get('id') - const fields = setAll(base, { + const fields = composeFieldData(base, { lines: lines.map((line) => lineToRaw(line, module, id_)), }) return asOwned(new MutableBodyBlock(module, fields)) @@ -1702,19 +1759,10 @@ export class MutableBodyBlock extends BodyBlock implements MutableAst { const oldLines = this.fields.get('lines') const filteredLines = oldLines.filter((line) => { if (!line.expression) return true - return keep(this.module.checkedGet(line.expression.node)) + return keep(this.module.get(line.expression.node)) }) this.fields.set('lines', filteredLines) } - - replaceChild(target: AstId, replacement: Owned) { - const replacement_ = this.claimChild(replacement) - const updateLine = (line: RawBlockLine) => - line.expression?.node === target - ? { ...line, expression: { ...line.expression, node: replacement_ } } - : line - this.fields.set('lines', this.fields.get('lines').map(updateLine)) - } } export interface MutableBodyBlock extends BodyBlock, MutableAst { statements(): IterableIterator @@ -1730,12 +1778,12 @@ interface Line { expression: NodeChild | undefined } -type RawBlockLine = RawLine +interface RawBlockLine extends RawLine {} export type BlockLine = Line export type OwnedBlockLine = Line function lineFromRaw(raw: RawBlockLine, module: Module): BlockLine { - const expression = raw.expression ? module.checkedGet(raw.expression.node) : undefined + const expression = raw.expression ? module.get(raw.expression.node) : undefined return { newline: { ...raw.newline, node: module.getToken(raw.newline.node) }, expression: expression @@ -1748,9 +1796,7 @@ function lineFromRaw(raw: RawBlockLine, module: Module): BlockLine { } function ownedLineFromRaw(raw: RawBlockLine, module: MutableModule): OwnedBlockLine { - const expression = raw.expression - ? module.checkedGet(raw.expression.node).takeIfParented() - : undefined + const expression = raw.expression ? module.get(raw.expression.node).takeIfParented() : undefined return { newline: { ...raw.newline, node: module.getToken(raw.newline.node) }, expression: expression @@ -1794,7 +1840,7 @@ export class Ident extends Ast { static concrete(module: MutableModule, token: NodeChild) { const base = module.baseObject('Ident') - const fields = setAll(base, { token }) + const fields = composeFieldData(base, { token }) return asOwned(new MutableIdent(module, fields)) } @@ -1823,8 +1869,6 @@ export class MutableIdent extends Ident implements MutableAst { this.fields.set('token', unspaced(toIdent(ident))) } - replaceChild(_target: AstId, _replacement: Owned) {} - code(): Identifier { return this.token.code() } @@ -1852,7 +1896,7 @@ export class Wildcard extends Ast { static concrete(module: MutableModule, token: NodeChild) { const base = module.baseObject('Wildcard') - const fields = setAll(base, { token }) + const fields = composeFieldData(base, { token }) return asOwned(new MutableWildcard(module, fields)) } @@ -1869,8 +1913,6 @@ export class Wildcard extends Ast { export class MutableWildcard extends Wildcard implements MutableAst { declare readonly module: MutableModule declare readonly fields: FixedMap - - replaceChild(_target: AstId, _replacement: Owned) {} } export interface MutableWildcard extends Wildcard, MutableAst {} applyMixins(MutableWildcard, [MutableAst]) @@ -2010,19 +2052,36 @@ function getAll(map: FixedMapView): Fields { return Object.fromEntries(map.entries()) as Fields } +declare const brandLegalFieldContent: unique symbol +/** Used to add a constraint to all `AstFields`s subtypes ensuring that they were produced by `composeFieldData`, which + * enforces a requirement that the provided fields extend `FieldObject`. + */ +interface LegalFieldContent { + [brandLegalFieldContent]: never +} + /** Modifies the input `map`. Returns the same object with an extended type. */ -export function setAll( +export function setAll>( map: FixedMap, fields: Fields2, ): FixedMap { const map_ = map as FixedMap for (const [k, v] of Object.entries(fields)) { const k_ = k as string & (keyof Fields1 | keyof Fields2) - map_.set(k_, v) + map_.set(k_, v as any) } return map_ } +/** Modifies the input `map`. Returns the same object with an extended type. The added fields are required to have only + * types extending `FieldData`; the returned object is branded as `LegalFieldContent`. */ +export function composeFieldData( + map: FixedMap, + fields: Fields2, +): FixedMap { + return setAll(map, fields) as FixedMap +} + function claimChild( module: MutableModule, child: Owned, diff --git a/app/gui2/shared/util/data/__tests__/iterable.test.ts b/app/gui2/shared/util/data/__tests__/iterable.test.ts new file mode 100644 index 000000000000..3ee69da0453e --- /dev/null +++ b/app/gui2/shared/util/data/__tests__/iterable.test.ts @@ -0,0 +1,8 @@ +import { tryGetSoleValue } from 'shared/util/data/iterable' +import { expect, test } from 'vitest' + +test('tryGetSoleValue', () => { + expect(tryGetSoleValue([])).toBeUndefined() + expect(tryGetSoleValue([1])).toEqual(1) + expect(tryGetSoleValue([1, 2])).toBeUndefined() +}) diff --git a/app/gui2/shared/util/data/__tests__/text.test.ts b/app/gui2/shared/util/data/__tests__/text.test.ts new file mode 100644 index 000000000000..2e0b2fb4771d --- /dev/null +++ b/app/gui2/shared/util/data/__tests__/text.test.ts @@ -0,0 +1,137 @@ +import { expect, test } from 'vitest' +import { applyTextEditsToSpans, textChangeToEdits, trimEnd } from '../text' + +/** Tests that: + * - When the code in `a[0]` is edited to become the code in `b[0]`, + * `applyTextEditsToSpans` followed by `trimEnd` transforms the spans in `a.slice(1)` into the spans in `b.slice(1)`. + * - The same holds when editing from `b` to `a`. + */ +function checkCorrespondence(a: string[], b: string[]) { + checkCorrespondenceForward(a, b) + checkCorrespondenceForward(b, a) +} + +/** Performs the same check as {@link checkCorrespondence}, for correspondences that are not expected to be reversible. + */ +function checkCorrespondenceForward(before: string[], after: string[]) { + const leadingSpacesAndLength = (input: string): [number, number] => [ + input.lastIndexOf(' ') + 1, + input.length, + ] + const spacesAndHyphens = ([spaces, length]: readonly [number, number]) => { + return ' '.repeat(spaces) + '-'.repeat(length - spaces) + } + const edits = textChangeToEdits(before[0]!, after[0]!) + const spansAfter = applyTextEditsToSpans(edits, before.slice(1).map(leadingSpacesAndLength)).map( + ([_spanBefore, spanAfter]) => trimEnd(spanAfter, after[0]!), + ) + expect([after[0]!, ...spansAfter.map(spacesAndHyphens)]).toEqual(after) +} + +test('applyTextEditsToSpans: Add and remove argument names.', () => { + checkCorrespondence( + [ + 'func arg1 arg2', // prettier-ignore + '----', + ' ----', + '---------', + ' ----', + '--------------', + ], + [ + 'func name1=arg1 name2=arg2', + '----', + ' ----', + '---------------', + ' ----', + '--------------------------', + ], + ) +}) + +test('applyTextEditsToSpans: Lengthen and shorten argument names.', () => { + checkCorrespondence( + [ + 'func name1=arg1 name2=arg2', + '----', + ' ----', + '---------------', + ' ----', + '--------------------------', + ], + [ + 'func longName1=arg1 longName2=arg2', + '----', + ' ----', + '-------------------', + ' ----', + '----------------------------------', + ], + ) +}) + +test('applyTextEditsToSpans: Add and remove inner application.', () => { + checkCorrespondence( + [ + 'func bbb2', // prettier-ignore + '----', + ' ----', + '---------', + ], + [ + 'func aaa1 bbb2', // prettier-ignore + '----', + ' ----', + '--------------', + ], + ) +}) + +test('applyTextEditsToSpans: Add and remove outer application.', () => { + checkCorrespondence( + [ + 'func arg1', // prettier-ignore + '----', + ' ----', + '---------', + ], + [ + 'func arg1 arg2', // prettier-ignore + '----', + ' ----', + '---------', + ], + ) +}) + +test('applyTextEditsToSpans: Distinguishing repeated subexpressions.', () => { + checkCorrespondence( + [ + 'foo (2 + 2) bar () (2 + 2)', // prettier-ignore + ' -----', + ' -------', + ' -----', + ' -------', + ], + [ + 'foo (2 + 2) bar (2 + 2) (2 + 2)', // prettier-ignore + ' -----', + ' -------', + ' -----', + ' -------', + ], + ) +}) + +test('applyTextEditsToSpans: Space after line content.', () => { + checkCorrespondenceForward( + [ + 'value = 1 +', // prettier-ignore + '-----------', + ], + [ + 'value = 1 ', // prettier-ignore + '---------', + ], + ) +}) diff --git a/app/gui2/shared/util/data/iterable.ts b/app/gui2/shared/util/data/iterable.ts new file mode 100644 index 000000000000..96d36d9406a4 --- /dev/null +++ b/app/gui2/shared/util/data/iterable.ts @@ -0,0 +1,98 @@ +/** @file Functions for manipulating {@link Iterable}s. */ + +export function* empty(): Generator {} + +export function* range(start: number, stop: number, step = start <= stop ? 1 : -1) { + if ((step > 0 && start > stop) || (step < 0 && start < stop)) { + throw new Error( + "The range's step is in the wrong direction - please use Infinity or -Infinity as the endpoint for an infinite range.", + ) + } + if (start <= stop) { + while (start < stop) { + yield start + start += step + } + } else { + while (start > stop) { + yield start + start += step + } + } +} + +export function* map(iter: Iterable, map: (value: T) => U) { + for (const value of iter) { + yield map(value) + } +} + +export function* chain(...iters: Iterable[]) { + for (const iter of iters) { + yield* iter + } +} + +export function* zip(left: Iterable, right: Iterable): Generator<[T, U]> { + const leftIterator = left[Symbol.iterator]() + const rightIterator = right[Symbol.iterator]() + while (true) { + const leftResult = leftIterator.next() + const rightResult = rightIterator.next() + if (leftResult.done || rightResult.done) break + yield [leftResult.value, rightResult.value] + } +} + +export function* zipLongest( + left: Iterable, + right: Iterable, +): Generator<[T | undefined, U | undefined]> { + const leftIterator = left[Symbol.iterator]() + const rightIterator = right[Symbol.iterator]() + while (true) { + const leftResult = leftIterator.next() + const rightResult = rightIterator.next() + if (leftResult.done && rightResult.done) break + yield [ + leftResult.done ? undefined : leftResult.value, + rightResult.done ? undefined : rightResult.value, + ] + } +} + +export function tryGetSoleValue(iter: Iterable): T | undefined { + const iterator = iter[Symbol.iterator]() + const result = iterator.next() + if (result.done) return + const excessResult = iterator.next() + if (!excessResult.done) return + return result.value +} + +/** Utility to simplify consuming an iterator a part at a time. */ +export class Resumable { + private readonly iterator: Iterator + private current: IteratorResult + constructor(iterable: Iterable) { + this.iterator = iterable[Symbol.iterator]() + this.current = this.iterator.next() + } + + /** The given function peeks at the current value. If the function returns `true`, the current value will be advanced + * and the function called again; if it returns `false`, the peeked value remains current and `advanceWhile` returns. + */ + advanceWhile(f: (value: T) => boolean) { + while (!this.current.done && f(this.current.value)) { + this.current = this.iterator.next() + } + } + + /** Apply the given function to all values remaining in the iterator. */ + forEach(f: (value: T) => void) { + while (!this.current.done) { + f(this.current.value) + this.current = this.iterator.next() + } + } +} diff --git a/app/gui2/shared/util/data/text.ts b/app/gui2/shared/util/data/text.ts new file mode 100644 index 000000000000..9f4027df94b4 --- /dev/null +++ b/app/gui2/shared/util/data/text.ts @@ -0,0 +1,149 @@ +import diff from 'fast-diff' +import { rangeEncloses, rangeLength, type SourceRange } from '../../yjsModel' +import { Resumable } from './iterable' + +export type SourceRangeEdit = { range: SourceRange; insert: string } + +/** Given text and a set of `TextEdit`s, return the result of applying the edits to the text. */ +export function applyTextEdits(oldText: string, textEdits: SourceRangeEdit[]) { + textEdits.sort((a, b) => a.range[0] - b.range[0]) + let start = 0 + let newText = '' + for (const textEdit of textEdits) { + newText += oldText.slice(start, textEdit.range[0]) + newText += textEdit.insert + start = textEdit.range[1] + } + newText += oldText.slice(start) + return newText +} + +/** Given text before and after a change, return one possible set of {@link SourceRangeEdit}s describing the change. */ +export function textChangeToEdits(before: string, after: string): SourceRangeEdit[] { + const textEdits: SourceRangeEdit[] = [] + let nextEdit: undefined | SourceRangeEdit + let pos = 0 + // Sequences fast-diff emits: + // EQUAL, INSERT + // EQUAL, DELETE + // DELETE, EQUAL + // DELETE, INSERT + // INSERT, EQUAL + for (const [op, text] of diff(before, after)) { + switch (op) { + case diff.INSERT: + if (!nextEdit) nextEdit = { range: [pos, pos], insert: '' } + nextEdit.insert = text + break + case diff.EQUAL: + if (nextEdit) { + textEdits.push(nextEdit) + nextEdit = undefined + } + pos += text.length + break + case diff.DELETE: { + if (nextEdit) textEdits.push(nextEdit) + const endPos = pos + text.length + nextEdit = { range: [pos, endPos], insert: '' } + pos = endPos + break + } + } + } + if (nextEdit) textEdits.push(nextEdit) + return textEdits +} + +/** Translate a `TextEdit` by the specified offset. */ +export function offsetEdit(textEdit: SourceRangeEdit, offset: number): SourceRangeEdit { + return { ...textEdit, range: [textEdit.range[0] + offset, textEdit.range[1] + offset] } +} + +/** Given: + * @param textEdits - A change described by a set of text edits. + * @param spansBefore - A collection of spans in the text before the edit. + * @returns - A sequence of: Each span from `spansBefore` paired with the smallest span of the text after the edit that + * contains all text that was in the original span and has not been deleted. */ +export function applyTextEditsToSpans(textEdits: SourceRangeEdit[], spansBefore: SourceRange[]) { + // Gather start and end points. + const numerically = (a: number, b: number) => a - b + const starts = new Resumable(spansBefore.map(([start, _end]) => start).sort(numerically)) + const ends = new Resumable(spansBefore.map(([_start, end]) => end).sort(numerically)) + + // Construct translations from old locations to new locations for all start and end points. + const startMap = new Map() + const endMap = new Map() + let offset = 0 + for (const { range, insert } of textEdits) { + starts.advanceWhile((start) => { + if (start < range[0]) { + startMap.set(start, start + offset) + return true + } else if (start <= range[1]) { + startMap.set(start, range[0] + offset + insert.length) + return true + } + return false + }) + ends.advanceWhile((end) => { + if (end <= range[0]) { + endMap.set(end, end + offset) + return true + } else if (end <= range[1]) { + endMap.set(end, range[0] + offset) + return true + } + return false + }) + offset += insert.length - rangeLength(range) + } + starts.forEach((start) => startMap.set(start, start + offset)) + ends.forEach((end) => endMap.set(end, end + offset)) + + // Apply the translations to the map. + const spansBeforeAndAfter = new Array() + for (const spanBefore of spansBefore) { + const startAfter = startMap.get(spanBefore[0])! + const endAfter = endMap.get(spanBefore[1])! + if (endAfter > startAfter) spansBeforeAndAfter.push([spanBefore, [startAfter, endAfter]]) + } + return spansBeforeAndAfter +} + +export interface SpanTree { + id(): NodeId + span(): SourceRange + children(): IterableIterator> +} + +/** Given a span tree and some ranges, for each range find the smallest node that fully encloses it. + * Return nodes paired with the ranges that are most closely enclosed by them. + */ +export function enclosingSpans( + tree: SpanTree, + ranges: SourceRange[], + resultsOut?: [NodeId, SourceRange[]][], +) { + const results = resultsOut ?? [] + for (const child of tree.children()) { + const childSpan = child.span() + const childRanges: SourceRange[] = [] + ranges = ranges.filter((range) => { + if (rangeEncloses(childSpan, range)) { + childRanges.push(range) + return false + } + return true + }) + if (childRanges.length) enclosingSpans(child, childRanges, results) + } + if (ranges.length) results.push([tree.id(), ranges]) + return results +} + +/** Return the given range with any trailing spaces stripped. */ +export function trimEnd(range: SourceRange, text: string): SourceRange { + const trimmedLength = text.slice(range[0], range[1]).search(/ +$/) + return trimmedLength === -1 ? range : [range[0], range[0] + trimmedLength] +} diff --git a/app/gui2/shared/yjsModel.ts b/app/gui2/shared/yjsModel.ts index bd333ae84dc5..609962081a81 100644 --- a/app/gui2/shared/yjsModel.ts +++ b/app/gui2/shared/yjsModel.ts @@ -120,15 +120,25 @@ export class DistributedModule { this.undoManager = new Y.UndoManager([this.doc.nodes]) } - transact(fn: () => T): T { - return this.doc.ydoc.transact(fn, 'local') - } - dispose(): void { this.doc.ydoc.destroy() } } +export const localOrigins = ['local', 'local:CodeEditor'] as const +export type LocalOrigin = (typeof localOrigins)[number] +export type Origin = LocalOrigin | 'remote' +/** Locally-originated changes not otherwise specified. */ +export const defaultLocalOrigin: LocalOrigin = 'local' +export function isLocalOrigin(origin: string): origin is LocalOrigin { + const localOriginNames: readonly string[] = localOrigins + return localOriginNames.includes(origin) +} +export function tryAsOrigin(origin: string): Origin | undefined { + if (isLocalOrigin(origin)) return origin + if (origin === 'remote') return origin +} + export type SourceRange = readonly [start: number, end: number] declare const brandSourceRangeKey: unique symbol export type SourceRangeKey = string & { [brandSourceRangeKey]: never } @@ -230,6 +240,14 @@ export function rangeEquals(a: SourceRange, b: SourceRange): boolean { return a[0] == b[0] && a[1] == b[1] } +export function rangeIncludes(a: SourceRange, b: number): boolean { + return a[0] <= b && a[1] >= b +} + +export function rangeLength(a: SourceRange): number { + return a[1] - a[0] +} + export function rangeEncloses(a: SourceRange, b: SourceRange): boolean { return a[0] <= b[0] && a[1] >= b[1] } diff --git a/app/gui2/src/components/CodeEditor.vue b/app/gui2/src/components/CodeEditor.vue index 3d101f23f572..239510734be0 100644 --- a/app/gui2/src/components/CodeEditor.vue +++ b/app/gui2/src/components/CodeEditor.vue @@ -1,8 +1,8 @@