Skip to content

Commit

Permalink
fix: semantic tokens performance (#3932)
Browse files Browse the repository at this point in the history
While implementing #3925, I noticed that the performance of the
`textDocument/semanticTokens/full` request is *extremely* bad due to a
quadratic implementation. Specifically, on my machine, computing the
full semantic tokens for `Lean/Elab/Do.lean` took a full 5s. In
practice, this means that while elaborating the file, one core is
entirely busy with computing the semantic tokens for the file.

This PR fixes this performance bug by re-implementing the semantic token
handling, reducing the latency for `Lean/Elab/Do.lean` from 5s to 60ms.
As a result, the overly cautious refresh latency of 5s in #3925 can
easily be reduced to 2s again.

Since the previous semantic tokens implementation used a very brittle
hack to identify projections, this PR also changes the projection
notation elaboration to augment the `InfoTree` syntax for the field of a
projection with a special syntax node of kind
`Lean.Parser.Term.identProjKind`. With this syntax kind, projection
fields can now easily be identified in the `InfoTree`.
  • Loading branch information
mhuisi authored Apr 18, 2024
1 parent 11ff004 commit faa4d16
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/Lean/Data/Lsp/LanguageFeatures.lean
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ inductive SemanticTokenType where
| decorator
-- Extensions
| leanSorryLike
deriving ToJson, FromJson
deriving ToJson, FromJson, BEq, Hashable

-- must be in the same order as the constructors
def SemanticTokenType.names : Array String :=
Expand Down
17 changes: 14 additions & 3 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,17 @@ private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Ar
argIdx := argIdx + 1
throwError "invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, it must be explicit or implicit with a unique name"

/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
private def addProjTermInfo
(stx : Syntax)
(e : Expr)
(expectedType? : Option Expr := none)
(lctx? : Option LocalContext := none)
(elaborator : Name := Name.anonymous)
(isBinder force : Bool := false)
: TermElabM Expr :=
addTermInfo (Syntax.node .none Parser.Term.identProjKind #[stx]) e expectedType? lctx? elaborator isBinder force

private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit ellipsis : Bool)
(f : Expr) (lvals : List LVal) : TermElabM Expr :=
let rec loop : Expr → List LVal → TermElabM Expr
Expand All @@ -1214,7 +1225,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
if isPrivateNameFromImportedModule (← getEnv) info.projFn then
throwError "field '{fieldName}' from structure '{structName}' is private"
let projFn ← mkConst info.projFn
let projFn ← addTermInfo lval.getRef projFn
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f }
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
Expand All @@ -1226,7 +1237,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
| LValResolution.const baseStructName structName constName =>
let f ← if baseStructName != structName then mkBaseProjections baseStructName structName f else pure f
let projFn ← mkConst constName
let projFn ← addTermInfo lval.getRef projFn
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let projFnType ← inferType projFn
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFnType
Expand All @@ -1235,7 +1246,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.localRec baseName fullName fvar =>
let fvar ← addTermInfo lval.getRef fvar
let fvar ← addProjTermInfo lval.getRef fvar
if lvals.isEmpty then
let fvarType ← inferType fvar
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvarType
Expand Down
11 changes: 11 additions & 0 deletions src/Lean/Parser/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,17 @@ is short for accessing the `i`-th field (1-indexed) of `e` if it is of a structu
@[builtin_term_parser] def arrow := trailing_parser
checkPrec 25 >> unicodeSymbol " → " " -> " >> termParser 25

/--
Syntax kind for syntax nodes representing the field of a projection in the `InfoTree`.
Specifically, the `InfoTree` node for a projection `s.f` contains a child `InfoTree` node
with syntax ``(Syntax.node .none identProjKind #[`f])``.
This is necessary because projection syntax cannot always be detected purely syntactically
(`s.f` may refer to either the identifier `s.f` or a projection `s.f` depending on
the available context).
-/
def identProjKind := `Lean.Parser.Term.identProj

def isIdent (stx : Syntax) : Bool :=
-- antiquotations should also be allowed where an identifier is expected
stx.isAntiquot || stx.isIdent
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Server/FileWorker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def runRefreshTask : WorkerM (Task (Except IO.Error Unit)) := do
IO.sleep 1000
continue
sendServerRequest ctx "workspace/semanticTokens/refresh" (none : Option Nat)
IO.sleep 5000
IO.sleep 2000

def initAndRunWorker (i o e : FS.Stream) (opts : Options) : IO UInt32 := do
let i ← maybeTee "fwIn.txt" false i
Expand Down
204 changes: 125 additions & 79 deletions src/Lean/Server/FileWorker/RequestHandling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ where
return toDocumentSymbols text stxs (syms.push sym) stack
toDocumentSymbols text stxs syms stack

/--
`SyntaxNodeKind`s for which the syntax node and its children receive no semantic highlighting.
-/
def noHighlightKinds : Array SyntaxNodeKind := #[
-- usually have special highlighting by the client
``Lean.Parser.Term.sorry,
Expand All @@ -429,25 +432,121 @@ def noHighlightKinds : Array SyntaxNodeKind := #[
``Lean.Parser.Command.docComment,
``Lean.Parser.Command.moduleDoc]

structure SemanticTokensContext where
beginPos : String.Pos
endPos? : Option String.Pos
text : FileMap
snap : Snapshot

structure SemanticTokensState where
data : Array Nat
lastLspPos : Lsp.Position

-- TODO: make extensible, or don't
/-- Keywords for which a specific semantic token is provided. -/
def keywordSemanticTokenMap : RBMap String SemanticTokenType compare :=
RBMap.empty
|>.insert "sorry" .leanSorryLike
|>.insert "admit" .leanSorryLike
|>.insert "stop" .leanSorryLike
|>.insert "#exit" .leanSorryLike

partial def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
/-- Semantic token information for a given `Syntax`. -/
structure LeanSemanticToken where
/-- Syntax of the semantic token. -/
stx : Syntax
/-- Type of the semantic token. -/
type : SemanticTokenType

/-- Semantic token information with absolute LSP positions. -/
structure AbsoluteLspSemanticToken where
/-- Start position of the semantic token. -/
pos : Lsp.Position
/-- End position of the semantic token. -/
tailPos : Lsp.Position
/-- Start position of the semantic token. -/
type : SemanticTokenType
deriving BEq, Hashable, FromJson, ToJson

/--
Given a set of `LeanSemanticToken`, computes the `AbsoluteLspSemanticToken` with absolute
LSP position information for each token.
-/
def computeAbsoluteLspSemanticTokens
(text : FileMap)
(beginPos : String.Pos)
(endPos? : Option String.Pos)
(tokens : Array LeanSemanticToken)
: Array AbsoluteLspSemanticToken :=
tokens.filterMap fun ⟨stx, type⟩ => do
let (pos, tailPos) := (← stx.getPos?, ← stx.getTailPos?)
guard <| beginPos <= pos && endPos?.all (pos < ·)
let (lspPos, lspTailPos) := (text.utf8PosToLspPos pos, text.utf8PosToLspPos tailPos)
return ⟨lspPos, lspTailPos, type⟩

/-- Filters all duplicate semantic tokens with the same `pos`, `tailPos` and `type`. -/
def filterDuplicateSemanticTokens (tokens : Array AbsoluteLspSemanticToken) : Array AbsoluteLspSemanticToken :=
tokens.groupByKey id |>.toArray.map (·.1)

/--
Given a set of `AbsoluteLspSemanticToken`, computes the LSP `SemanticTokens` data with
token-relative positioning.
See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_semanticTokens.
-/
def computeDeltaLspSemanticTokens (tokens : Array AbsoluteLspSemanticToken) : SemanticTokens := Id.run do
let tokens := tokens.qsort fun ⟨pos1, tailPos1, _⟩ ⟨pos2, tailPos2, _⟩ =>
pos1 < pos2 || pos1 == pos2 && tailPos1 <= tailPos2
let mut data : Array Nat := Array.mkEmpty (5*tokens.size)
let mut lastPos : Lsp.Position := ⟨0, 0
for ⟨pos, tailPos, type⟩ in tokens do
let deltaLine := pos.line - lastPos.line
let deltaStart := pos.character - (if pos.line == lastPos.line then lastPos.character else 0)
let length := tailPos.character - pos.character
let tokenType := type.toNat
let tokenModifiers := 0
data := data ++ #[deltaLine, deltaStart, length, tokenType, tokenModifiers]
lastPos := pos
return { data }

/--
Collects all semantic tokens that can be deduced purely from `Syntax`
without elaboration information.
-/
partial def collectSyntaxBasedSemanticTokens : (stx : Syntax) → Array LeanSemanticToken
| `($e.$id:ident) =>
let tokens := collectSyntaxBasedSemanticTokens e
tokens.push ⟨id, SemanticTokenType.property⟩
| `($e |>.$field:ident) =>
let tokens := collectSyntaxBasedSemanticTokens e
tokens.push ⟨field, SemanticTokenType.property⟩
| stx => Id.run do
if noHighlightKinds.contains stx.getKind then
return #[]
let mut tokens :=
if stx.isOfKind choiceKind then
collectSyntaxBasedSemanticTokens stx[0]
else
stx.getArgs.map collectSyntaxBasedSemanticTokens |>.flatten
let Syntax.atom _ val := stx
| return tokens
let isRegularKeyword := val.length > 0 && val.front.isAlpha
let isHashKeyword := val.length > 1 && val.front == '#' && (val.get ⟨1⟩).isAlpha
if ! isRegularKeyword && ! isHashKeyword then
return tokens
return tokens.push ⟨stx, keywordSemanticTokenMap.findD val .keyword⟩

/-- Collects all semantic tokens from the given `Elab.InfoTree`. -/
def collectInfoBasedSemanticTokens (i : Elab.InfoTree) : Array LeanSemanticToken :=
List.toArray <| i.deepestNodes fun _ i _ => do
let .ofTermInfo ti := i
| none
let .original .. := ti.stx.getHeadInfo
| none
if let `($_:ident) := ti.stx then
if let Expr.fvar fvarId .. := ti.expr then
if let some localDecl := ti.lctx.find? fvarId then
-- Recall that `isAuxDecl` is an auxiliary declaration used to elaborate a recursive definition.
if localDecl.isAuxDecl then
if ti.isBinder then
return ⟨ti.stx, SemanticTokenType.function⟩
else if ! localDecl.isImplementationDetail then
return ⟨ti.stx, SemanticTokenType.variable⟩
if ti.stx.getKind == Parser.Term.identProjKind then
return ⟨ti.stx, SemanticTokenType.property⟩
none

/-- Computes the semantic tokens in the range [beginPos, endPos?). -/
def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
: RequestM (RequestTask SemanticTokens) := do
let doc ← readDoc
match endPos? with
Expand All @@ -462,78 +561,25 @@ partial def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option Strin
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => run doc snaps
where
run doc snaps : RequestM SemanticTokens :=
StateT.run' (s := { data := #[], lastLspPos := ⟨0, 0⟩ : SemanticTokensState }) do
for s in snaps do
if s.endPos <= beginPos then
continue
ReaderT.run (r := SemanticTokensContext.mk beginPos endPos? doc.meta.text s) <|
go s.stx
return { data := (← get).data }
go (stx : Syntax) := do
match stx with
| `($e.$id:ident) => go e; addToken id SemanticTokenType.property
-- indistinguishable from next pattern
--| `(level|$id:ident) => addToken id SemanticTokenType.variable
| `($id:ident) => highlightId id
| _ =>
if !noHighlightKinds.contains stx.getKind then
highlightKeyword stx
if stx.isOfKind choiceKind then
go stx[0]
else
stx.getArgs.forM go
highlightId (stx : Syntax) : ReaderT SemanticTokensContext (StateT SemanticTokensState RequestM) _ := do
if let some range := stx.getRange? then
let mut lastPos := range.start
for ti in (← read).snap.infoTree.deepestNodes (fun
| _, i@(Elab.Info.ofTermInfo ti), _ => match i.pos? with
| some ipos => if range.contains ipos then some ti else none
| _ => none
| _, _, _ => none) do
let pos := ti.stx.getPos?.get!
-- avoid reporting same position twice; the info node can occur multiple times if
-- e.g. the term is elaborated multiple times
if pos < lastPos then
continue
if let Expr.fvar fvarId .. := ti.expr then
if let some localDecl := ti.lctx.find? fvarId then
-- Recall that `isAuxDecl` is an auxiliary declaration used to elaborate a recursive definition.
if localDecl.isAuxDecl then
if ti.isBinder then
addToken ti.stx SemanticTokenType.function
else
addToken ti.stx SemanticTokenType.variable
else if ti.stx.getPos?.get! > lastPos then
-- any info after the start position: must be projection notation
addToken ti.stx SemanticTokenType.property
lastPos := ti.stx.getPos?.get!
highlightKeyword stx := do
if let Syntax.atom _ val := stx then
if (val.length > 0 && val.front.isAlpha) ||
-- Support for keywords of the form `#<alpha>...`
(val.length > 1 && val.front == '#' && (val.get ⟨1⟩).isAlpha) then
addToken stx (keywordSemanticTokenMap.findD val .keyword)
addToken stx type := do
let ⟨beginPos, endPos?, text, _⟩ ← read
if let (some pos, some tailPos) := (stx.getPos?, stx.getTailPos?) then
if beginPos <= pos && endPos?.all (pos < ·) then
let lspPos := (← get).lastLspPos
let lspPos' := text.utf8PosToLspPos pos
let deltaLine := lspPos'.line - lspPos.line
let deltaStart := lspPos'.character - (if lspPos'.line == lspPos.line then lspPos.character else 0)
let length := (text.utf8PosToLspPos tailPos).character - lspPos'.character
let tokenType := type.toNat
let tokenModifiers := 0
modify fun st => {
data := st.data ++ #[deltaLine, deltaStart, length, tokenType, tokenModifiers]
lastLspPos := lspPos'
}

run doc snaps : RequestM SemanticTokens := do
let mut leanSemanticTokens := #[]
for s in snaps do
if s.endPos <= beginPos then
continue
let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx
let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree
leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens
let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens
let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens
let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens
return semanticTokens

/-- Computes all semantic tokens for the document. -/
def handleSemanticTokensFull (_ : SemanticTokensParams)
: RequestM (RequestTask SemanticTokens) := do
handleSemanticTokens 0 none

/-- Computes the semantic tokens in the range provided by `p`. -/
def handleSemanticTokensRange (p : SemanticTokensRangeParams)
: RequestM (RequestTask SemanticTokens) := do
let doc ← readDoc
Expand Down

0 comments on commit faa4d16

Please sign in to comment.