diff --git a/languageServer/src/TextDocumentWrapper.ts b/languageServer/src/TextDocumentWrapper.ts index 641d013340..90a42d7ca4 100644 --- a/languageServer/src/TextDocumentWrapper.ts +++ b/languageServer/src/TextDocumentWrapper.ts @@ -1,34 +1,20 @@ -import { - DocumentSymbol, - Position, - SymbolKind, - Range, - Location -} from 'vscode-languageserver/node' +import { DocumentSymbol, Position, Location } from 'vscode-languageserver/node' import { TextDocument } from 'vscode-languageserver-textdocument' -import { - isNode, - isScalar, - parseDocument, - visit, - Node, - Scalar, - Pair, - isPair -} from 'yaml' -import { has } from 'lodash' -import { findNodeAtLocation, parse, parseTree } from 'jsonc-parser' -import * as RegExes from './regexes' +import { parseDocument } from 'yaml' import { ITextDocumentWrapper } from './ITextDocumentWrapper' +import { LanguageHelper } from './languageHelpers/baseLanguageHelper' +import { createLanguageHelper } from './languageHelpers' export class TextDocumentWrapper implements ITextDocumentWrapper { uri: string private textDocument: TextDocument + private languageHelper: LanguageHelper constructor(textDocument: TextDocument) { this.textDocument = textDocument this.uri = this.textDocument.uri + this.languageHelper = createLanguageHelper(this.textDocument) } public offsetAt(position: Position) { @@ -48,248 +34,10 @@ export class TextDocumentWrapper implements ITextDocumentWrapper { } public findLocationsFor(symbol: DocumentSymbol): Location[] { - const propertyPath = symbol.kind === SymbolKind.Property && symbol.detail - const itIsHere = propertyPath && this.hasProperty(propertyPath) - - if (itIsHere) { - const location = this.getPropertyLocation(propertyPath) - - return location ? [location] : [] - } - - const parts = symbol.name.split(/\s/g) - const txt = this.getText() - - const acc: Location[] = [] - for (const str of parts) { - const index = txt.indexOf(str) - if (index <= 0) { - continue - } - const pos = this.positionAt(index) - const range = this.symbolAt(pos)?.range - if (!range) { - continue - } - acc.push(Location.create(this.uri, range as Range)) - } - return acc + return this.languageHelper.findLocationsFor(symbol) } public symbolAt(position: Position): DocumentSymbol | undefined { - return this.symbolScopeAt(position).pop() - } - - private getSymbolsFromPropertyPath(pathSegment: string, startIndex: number) { - const templateSymbols: DocumentSymbol[] = [] - const symbols = pathSegment.matchAll(RegExes.alphadecimalWords) - - const jsonPath: string[] = [] // Safe to assume, based on https://dvc.org/doc/user-guide/project-structure/dvcyaml-files#vars - - for (const templateSymbol of symbols) { - const symbolName = templateSymbol[0] - const symbolJsonPath = [...jsonPath, symbolName] - const symbolStart = (templateSymbol.index ?? 0) + startIndex - const symbolEnd = symbolStart + templateSymbol[0].length - const symbolRange = Range.create( - this.positionAt(symbolStart), - this.positionAt(symbolEnd) - ) - - templateSymbols.push( - DocumentSymbol.create( - templateSymbol[0], - symbolJsonPath.join('.'), - SymbolKind.Property, - symbolRange, - symbolRange - ) - ) - - jsonPath.push(symbolName) - } - - return templateSymbols - } - - private extractPropertyPathSymbolsFrom(text: string, startIndex: number) { - const symbols: DocumentSymbol[] = [] - const pathLikeSegments = text.matchAll(RegExes.propertyPathLike) - - for (const path of pathLikeSegments) { - const matchIndex = path.index ?? 0 - - symbols.push( - ...this.getSymbolsFromPropertyPath(path[0], startIndex + matchIndex) - ) - } - - return symbols - } - - private getTemplateExpressionSymbolsInsideScalar( - scalarValue: string, - nodeOffset: number - ) { - const templateSymbols: DocumentSymbol[] = [] - - const matches = scalarValue.matchAll(RegExes.variableTemplates) - - for (const match of matches) { - const expression = match[1] - const matchOffset = match.index || 0 - const expressionOffset: number = nodeOffset + matchOffset + 2 // To account for the '${' - - templateSymbols.push( - ...this.extractPropertyPathSymbolsFrom(expression, expressionOffset) - ) - } - - return templateSymbols - } - - private yamlScalarNodeToDocumentSymbols( - node: Scalar, - [nodeStart, valueEnd, nodeEnd]: [number, number, number] - ) { - const nodeValue = `${node.value}` - - let symbolKind: SymbolKind = SymbolKind.String - - if (/\.[A-Za-z]+$/.test(nodeValue)) { - symbolKind = SymbolKind.File - } - - const children: DocumentSymbol[] = [] - - const variableTemplateSymbols = [ - ...this.getTemplateExpressionSymbolsInsideScalar(nodeValue, nodeStart) - ] - - if (variableTemplateSymbols.length > 0) { - children.push(...variableTemplateSymbols) - } else { - const propertyPathSymbols = this.extractPropertyPathSymbolsFrom( - nodeValue, - nodeStart - ) - children.push(...propertyPathSymbols) - } - - const symbolsSoFar: DocumentSymbol[] = [ - DocumentSymbol.create( - nodeValue, - undefined, - symbolKind, - Range.create(this.positionAt(nodeStart), this.positionAt(nodeEnd)), - Range.create(this.positionAt(nodeStart), this.positionAt(valueEnd)) - ), - ...children - ] - - return symbolsSoFar - } - - private yamlNodeToDocumentSymbols( - node: Node | Pair, - range: [number, number, number] - ): DocumentSymbol[] { - if (isScalar(node)) { - return this.yamlScalarNodeToDocumentSymbols(node, range) - } - - if (isPair(node)) { - return this.yamlNodeToDocumentSymbols(node.value as Node | Pair, range) - } - - return [] - } - - private symbolScopeAt(position: Position): DocumentSymbol[] { - const cursorOffset: number = this.offsetAt(position) - - const symbolsFound: Array = [] - - if (this.uri.endsWith('yaml')) { - visit(this.getYamlDocument(), (_, node) => { - if (isNode(node) && node.range) { - const range = node.range - const nodeStart = range[0] - const nodeEnd = range[2] - const isCursorInsideNode = - cursorOffset >= nodeStart && cursorOffset <= nodeEnd - - if (isCursorInsideNode) { - symbolsFound.push(...this.yamlNodeToDocumentSymbols(node, range)) - } - } - }) - } - - const symbolStack = (symbolsFound.filter(Boolean) as DocumentSymbol[]).sort( - (a, b) => { - const offA = this.offsetAt(a.range.end) - this.offsetAt(a.range.start) - const offB = this.offsetAt(b.range.end) - this.offsetAt(b.range.start) - - return offB - offA // We want the tighter fits for last, so we can just pop them - } - ) - - return [...symbolStack] - } - - private hasProperty(path: string) { - const parsedObj = this.toJSON() - - return has(parsedObj, path) - } - - private getPropertyLocation(path: string) { - const pathArray = path.split('.') - - if (this.uri.endsWith('yaml')) { - const node = this.getYamlDocument().getIn(pathArray, true) - - if (isNode(node) && node.range) { - const [nodeStart, , nodeEnd] = node.range - const start = this.positionAt(nodeStart) - const end = this.positionAt(nodeEnd) - const range = Range.create(start, end) - return Location.create(this.uri, range) - } - - return null - } - - if (this.uri.endsWith('json')) { - const rootNode = parseTree(this.getText()) - const node = rootNode && findNodeAtLocation(rootNode, pathArray) - - if (!node) { - return null - } - const nodeSrcIndex = node.offset - const nodeSrcLength = node.length - const nodeEnd = nodeSrcIndex + nodeSrcLength - const start = this.positionAt(nodeSrcIndex) - const end = this.positionAt(nodeEnd) - const range = Range.create(start, end) - return Location.create(this.uri, range) - } - - return null - } - - private toJSON() { - if (this.uri.endsWith('yaml')) { - return this.getYamlDocument().toJS() - } - - if (this.uri.endsWith('json')) { - const src = this.getText() - return parse(src) - } - - return null + return this.languageHelper.findSymbolAtPosition(position) } } diff --git a/languageServer/src/languageHelpers/baseLanguageHelper.ts b/languageServer/src/languageHelpers/baseLanguageHelper.ts new file mode 100644 index 0000000000..ca44604979 --- /dev/null +++ b/languageServer/src/languageHelpers/baseLanguageHelper.ts @@ -0,0 +1,105 @@ +import { has } from 'lodash' +import { TextDocument } from 'vscode-languageserver-textdocument' +import { + DocumentSymbol, + Position, + Location, + SymbolKind +} from 'vscode-languageserver/node' + +export interface LanguageHelper { + findSymbolAtPosition(position: Position): DocumentSymbol | undefined + findLocationsFor(symbol: DocumentSymbol): Location[] +} + +export abstract class BaseLanguageHelper implements LanguageHelper { + protected textDocument: TextDocument + protected rootNode?: RootNode + + constructor(textDocument: TextDocument) { + this.textDocument = textDocument + this.rootNode = this.parse(this.getText()) + } + + public findSymbolAtPosition(position: Position): DocumentSymbol | undefined { + const cursorOffset: number = this.offsetAt(position) + const symbolsAroundOffset = this.findEnclosingSymbols(cursorOffset) + + const symbolStack = symbolsAroundOffset.sort((a, b) => { + const offA = this.offsetAt(a.range.end) - this.offsetAt(a.range.start) + const offB = this.offsetAt(b.range.end) - this.offsetAt(b.range.start) + + return offB - offA // We want the tighter fits for last, so we can just pop them + }) + + return [...symbolStack].pop() + } + + public findLocationsFor(symbol: DocumentSymbol): Location[] { + if (symbol.kind === SymbolKind.Property) { + return this.findLocationsForPropertySymbol(symbol) + } + + return this.findLocationsForNormalSymbol(symbol) + } + + protected getText() { + return this.textDocument.getText() + } + + protected offsetAt(position: Position) { + return this.textDocument.offsetAt(position) + } + + protected positionAt(offset: number) { + return this.textDocument.positionAt(offset) + } + + private findLocationsForPropertySymbol(symbol: DocumentSymbol) { + const propertyPath = symbol.detail + const itIsHere = propertyPath && this.hasProperty(propertyPath) + + if (itIsHere) { + const pathArray = propertyPath.split('.') + const location = this.getPropertyLocation(pathArray) + + return location ? [location] : [] + } + + return this.findLocationsForNormalSymbol(symbol) + } + + private findLocationsForNormalSymbol(symbol: DocumentSymbol) { + const parts = symbol.name.split(/\s/g) + const txt = this.getText() + + const acc: Location[] = [] + for (const str of parts) { + const index = txt.indexOf(str) + if (index <= 0) { + continue + } + const pos = this.positionAt(index) + const range = this.findSymbolAtPosition(pos)?.range + + if (range) { + acc.push(Location.create(this.textDocument.uri, range)) + } + } + return acc + } + + private hasProperty(path: string) { + const parsedObj = this.toJSON() + + return has(parsedObj, path) + } + + protected abstract parse(source: string): RootNode | undefined + protected abstract findEnclosingSymbols(offset: number): DocumentSymbol[] + protected abstract getPropertyLocation( + pathArray: Array + ): Location | null + + protected abstract toJSON(): unknown +} diff --git a/languageServer/src/languageHelpers/index.ts b/languageServer/src/languageHelpers/index.ts new file mode 100644 index 0000000000..b70bd67477 --- /dev/null +++ b/languageServer/src/languageHelpers/index.ts @@ -0,0 +1,18 @@ +import { TextDocument } from 'vscode-languageserver-textdocument' +import { JsonHelper } from './jsonHelper' +import { PlainTextHelper } from './plainTextHelper' +import { YamlHelper } from './yamlHelper' + +export const createLanguageHelper = (textDocument: TextDocument) => { + const language = textDocument.languageId + + if (language === 'yaml') { + return new YamlHelper(textDocument) + } + + if (language === 'json') { + return new JsonHelper(textDocument) + } + + return new PlainTextHelper(textDocument) +} diff --git a/languageServer/src/languageHelpers/jsonHelper.ts b/languageServer/src/languageHelpers/jsonHelper.ts new file mode 100644 index 0000000000..c0e01621cd --- /dev/null +++ b/languageServer/src/languageHelpers/jsonHelper.ts @@ -0,0 +1,34 @@ +import { findNodeAtLocation, getNodeValue, Node, parseTree } from 'jsonc-parser' +import { DocumentSymbol, Location, Range } from 'vscode-languageserver' +import { BaseLanguageHelper } from './baseLanguageHelper' + +export class JsonHelper extends BaseLanguageHelper { + protected parse(source: string) { + return parseTree(source) + } + + protected findEnclosingSymbols(): DocumentSymbol[] { + return [] + } + + protected getPropertyLocation( + pathArray: Array + ): Location | null { + const node = this.rootNode && findNodeAtLocation(this.rootNode, pathArray) + + if (!node) { + return null + } + const nodeSrcIndex = node.offset + const nodeSrcLength = node.length + const nodeEnd = nodeSrcIndex + nodeSrcLength + const start = this.positionAt(nodeSrcIndex) + const end = this.positionAt(nodeEnd) + const range = Range.create(start, end) + return Location.create(this.textDocument.uri, range) + } + + protected toJSON(): unknown { + return this.rootNode ? getNodeValue(this.rootNode) : undefined + } +} diff --git a/languageServer/src/languageHelpers/plainTextHelper.ts b/languageServer/src/languageHelpers/plainTextHelper.ts new file mode 100644 index 0000000000..0a216ca7c5 --- /dev/null +++ b/languageServer/src/languageHelpers/plainTextHelper.ts @@ -0,0 +1,19 @@ +import { BaseLanguageHelper } from './baseLanguageHelper' + +export class PlainTextHelper extends BaseLanguageHelper { + protected parse(source: string): string | undefined { + return source + } + + protected findEnclosingSymbols() { + return [] + } + + protected getPropertyLocation() { + return null + } + + protected toJSON(): unknown { + return this.rootNode + } +} diff --git a/languageServer/src/languageHelpers/regexes.ts b/languageServer/src/languageHelpers/regexes.ts new file mode 100644 index 0000000000..6a5ea5aa5c --- /dev/null +++ b/languageServer/src/languageHelpers/regexes.ts @@ -0,0 +1,4 @@ +export const variableTemplates = /\${([^}]+)}/g +export const filePaths = /[\d/A-Za-z]+\.[A-Za-z]+/g +export const alphadecimalWords = /[\dA-Za-z]+/g +export const propertyPathLike = /[\d.A-Z[\]a-z]+/g diff --git a/languageServer/src/languageHelpers/yamlHelper.ts b/languageServer/src/languageHelpers/yamlHelper.ts new file mode 100644 index 0000000000..cc2124215d --- /dev/null +++ b/languageServer/src/languageHelpers/yamlHelper.ts @@ -0,0 +1,191 @@ +import { + DocumentSymbol, + SymbolKind, + Range, + Location +} from 'vscode-languageserver/node' +import { + Document, + isNode, + isPair, + isScalar, + Pair, + parseDocument, + Scalar, + Node, + visit +} from 'yaml' +import { BaseLanguageHelper } from './baseLanguageHelper' +import * as RegExes from './regexes' + +export class YamlHelper extends BaseLanguageHelper { + protected rootNode!: Document + + protected findEnclosingSymbols(offset: number): DocumentSymbol[] { + const symbolsFound: Array = [] + + visit(this.rootNode, (_, node) => { + if (isNode(node) && node.range) { + const range = node.range + const nodeStart = range[0] + const nodeEnd = range[2] + const nodeContainsTheOffset = offset >= nodeStart && offset <= nodeEnd + + if (nodeContainsTheOffset) { + symbolsFound.push(...this.yamlNodeToDocumentSymbols(node, range)) + } + } + }) + + return symbolsFound + } + + protected parse(source: string) { + return parseDocument(source) + } + + protected getPropertyLocation( + pathArray: Array + ): Location | null { + const node = this.rootNode.getIn(pathArray, true) + + if (isNode(node) && node.range) { + const [nodeStart, , nodeEnd] = node.range + const start = this.positionAt(nodeStart) + const end = this.positionAt(nodeEnd) + const range = Range.create(start, end) + return Location.create(this.textDocument.uri, range) + } + + return null + } + + protected toJSON(): unknown { + return this.rootNode.toJS() + } + + private yamlNodeToDocumentSymbols( + node: Node | Pair, + range: [number, number, number] + ): DocumentSymbol[] { + if (isScalar(node)) { + return this.yamlScalarNodeToDocumentSymbols(node, range) + } + + if (isPair(node)) { + return this.yamlNodeToDocumentSymbols(node.value as Node | Pair, range) + } + + return [] + } + + private yamlScalarNodeToDocumentSymbols( + node: Scalar, + [nodeStart, valueEnd, nodeEnd]: [number, number, number] + ) { + const nodeValue = `${node.value}` + + let symbolKind: SymbolKind = SymbolKind.String + + if (/\.[A-Za-z]+$/.test(nodeValue)) { + symbolKind = SymbolKind.File + } + + const children: DocumentSymbol[] = [] + + const variableTemplateSymbols = [ + ...this.getTemplateExpressionSymbolsInsideScalar(nodeValue, nodeStart) + ] + + if (variableTemplateSymbols.length > 0) { + children.push(...variableTemplateSymbols) + } else { + const propertyPathSymbols = this.extractPropertyPathSymbolsFrom( + nodeValue, + nodeStart + ) + children.push(...propertyPathSymbols) + } + + const symbolsSoFar: DocumentSymbol[] = [ + DocumentSymbol.create( + nodeValue, + undefined, + symbolKind, + Range.create(this.positionAt(nodeStart), this.positionAt(nodeEnd)), + Range.create(this.positionAt(nodeStart), this.positionAt(valueEnd)) + ), + ...children + ] + + return symbolsSoFar + } + + private extractPropertyPathSymbolsFrom(text: string, startIndex: number) { + const symbols: DocumentSymbol[] = [] + const pathLikeSegments = text.matchAll(RegExes.propertyPathLike) + + for (const path of pathLikeSegments) { + const matchIndex = path.index ?? 0 + + symbols.push( + ...this.getSymbolsFromPropertyPath(path[0], startIndex + matchIndex) + ) + } + + return symbols + } + + private getSymbolsFromPropertyPath(pathSegment: string, startIndex: number) { + const templateSymbols: DocumentSymbol[] = [] + const symbols = pathSegment.matchAll(RegExes.alphadecimalWords) + + const jsonPath: string[] = [] // Safe to assume, based on https://dvc.org/doc/user-guide/project-structure/dvcyaml-files#vars + + for (const templateSymbol of symbols) { + const symbolName = templateSymbol[0] + const symbolJsonPath = [...jsonPath, symbolName] + const symbolStart = (templateSymbol.index ?? 0) + startIndex + const symbolEnd = symbolStart + templateSymbol[0].length + const symbolRange = Range.create( + this.positionAt(symbolStart), + this.positionAt(symbolEnd) + ) + + templateSymbols.push( + DocumentSymbol.create( + templateSymbol[0], + symbolJsonPath.join('.'), + SymbolKind.Property, + symbolRange, + symbolRange + ) + ) + + jsonPath.push(symbolName) + } + + return templateSymbols + } + + private getTemplateExpressionSymbolsInsideScalar( + scalarValue: string, + nodeOffset: number + ) { + const templateSymbols: DocumentSymbol[] = [] + + const matches = scalarValue.matchAll(RegExes.variableTemplates) + + for (const match of matches) { + const expression = match[1] + const matchOffset = match.index || 0 + const expressionOffset: number = nodeOffset + matchOffset + 2 // To account for the '${' + + templateSymbols.push( + ...this.extractPropertyPathSymbolsFrom(expression, expressionOffset) + ) + } + + return templateSymbols + } +} diff --git a/languageServer/src/test/definitions.test.ts b/languageServer/src/test/definitions.test.ts index f6021183e7..c8294d8a0c 100644 --- a/languageServer/src/test/definitions.test.ts +++ b/languageServer/src/test/definitions.test.ts @@ -1,9 +1,5 @@ import { Position, Range } from 'vscode-languageserver/node' -import { - foreach_dvc_yaml, - params_dvc_yaml, - vars_dvc_yaml -} from './fixtures/examples/valid' +import { foreach_dvc_yaml, params_dvc_yaml } from './fixtures/examples/valid' import { params } from './fixtures/params' import { requestDefinitions } from './utils/requestDefinitions' import { openTheseFilesAndNotifyServer } from './utils/openTheseFilesAndNotifyServer' @@ -11,7 +7,6 @@ import { disposeTestConnections, setupTestConnections } from './utils/setup-test-connections' -import { sendTheseFilesToServer } from './utils/sendTheseFilesToServer' describe('textDocument/definitions', () => { beforeEach(() => { @@ -53,7 +48,7 @@ describe('textDocument/definitions', () => { expect(response).toBeTruthy() expect(response).toStrictEqual({ - range: Range.create(Position.create(4, 0), Position.create(4, 3)), + range: Range.create(Position.create(3, 5), Position.create(4, 0)), uri: 'file:///params.yaml' }) }) @@ -74,26 +69,4 @@ describe('textDocument/definitions', () => { uri: 'file:///dvc.yaml' }) }) - - it('should provide a single location that points to the top of the file path symbol', async () => { - const [dvcYaml] = await sendTheseFilesToServer([ - { - languageId: 'yaml', - mockContents: vars_dvc_yaml, - mockPath: 'dvc.yaml' - }, - { - languageId: 'json', - mockContents: '', - mockPath: 'params.json' - } - ]) - const response = await requestDefinitions(dvcYaml, 'params.json') - - expect(response).toBeTruthy() - expect(response).toStrictEqual({ - range: Range.create(Position.create(0, 0), Position.create(0, 0)), - uri: 'file:///params.json' - }) - }) }) diff --git a/languageServer/src/test/fixtures/params/index.ts b/languageServer/src/test/fixtures/params/index.ts index 069c7cbe27..94dd2eb7bf 100644 --- a/languageServer/src/test/fixtures/params/index.ts +++ b/languageServer/src/test/fixtures/params/index.ts @@ -4,4 +4,4 @@ weight_decay: 0 epochs: 15 auc: 0.9 loss: 0.2 -` +`.trim()