From d120e3385466403a50b56c7b8bfb4bfcdc03d193 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Fri, 3 May 2024 09:11:59 +0100 Subject: [PATCH] fix: multi-byte character positions in LSP (fixes #482) (#712) --- .version | 2 +- cmd/templ/lspcmd/lsp_test.go | 160 ++++++++++++++++--- cmd/templ/lspcmd/proxy/server.go | 13 +- cmd/templ/lspcmd/testdata/templates.templ | 2 + cmd/templ/lspcmd/testdata/templates_templ.go | 2 + generator/rangewriter.go | 10 +- generator/rangewriter_test.go | 4 +- parser/v2/sourcemap.go | 16 +- parser/v2/sourcemap_test.go | 44 +++-- 9 files changed, 202 insertions(+), 51 deletions(-) diff --git a/.version b/.version index f279e70ea..fe4758d9e 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.2.668 +0.2.670 \ No newline at end of file diff --git a/cmd/templ/lspcmd/lsp_test.go b/cmd/templ/lspcmd/lsp_test.go index 7d26c0589..0803183d3 100644 --- a/cmd/templ/lspcmd/lsp_test.go +++ b/cmd/templ/lspcmd/lsp_test.go @@ -12,6 +12,7 @@ import ( "sync" "testing" "time" + "unicode/utf8" "github.com/a-h/protocol" "github.com/a-h/templ/cmd/templ/generatecmd/modcheck" @@ -53,6 +54,12 @@ func createTestProject(moduleRoot string) (dir string, err error) { return dir, nil } +func mustReplaceLine(file string, line int, replacement string) string { + lines := strings.Split(file, "\n") + lines[line-1] = replacement + return strings.Join(lines, "\n") +} + func TestCompletion(t *testing.T) { if testing.Short() { return @@ -137,7 +144,7 @@ func TestCompletion(t *testing.T) { for i, test := range tests { t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { // Edit the file. - updated := strings.ReplaceAll(string(templFile), `
{ fmt.Sprintf("%d", count) }
`, test.replacement) + updated := mustReplaceLine(string(templFile), test.line, test.replacement) err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{ TextDocument: protocol.VersionedTextDocumentIdentifier{ TextDocumentIdentifier: protocol.TextDocumentIdentifier{ @@ -173,7 +180,7 @@ func TestCompletion(t *testing.T) { // Positions are zero indexed. Position: protocol.Position{ Line: uint32(test.line - 1), - Character: uint32(len(test.replacement) - 1), + Character: uint32(len(test.cursor) - 1), }, }, }) @@ -227,30 +234,143 @@ func TestHover(t *testing.T) { return } log.Info("Calling hover") - hr, err := server.Hover(ctx, &protocol.HoverParams{ - TextDocumentPositionParams: protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: uri.URI("file://" + appDir + "/templates.templ"), + + // Edit the file. + // Replace: + //
{ fmt.Sprintf("%d", count) }
+ // With various tests: + //
{ f + tests := []struct { + line int + replacement string + cursor string + assert func(t *testing.T, hr *protocol.Hover) (msg string, ok bool) + }{ + { + line: 13, + replacement: `
{ fmt.Sprintf("%d", count) }
`, + cursor: ` ^`, + assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) { + expectedHover := protocol.Hover{ + Contents: protocol.MarkupContent{ + Kind: "markdown", + Value: "```go\npackage fmt\n```\n\n[`fmt` on pkg.go.dev](https://pkg.go.dev/fmt)", + }, + } + if diff := lspdiff.Hover(expectedHover, *actual); diff != "" { + return fmt.Sprintf("unexpected hover: %v\n\n: markdown: %#v", diff, actual.Contents.Value), false + } + return "", true + }, + }, + { + line: 13, + replacement: `
{ fmt.Sprintf("%d", count) }
`, + cursor: ` ^`, + assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) { + expectedHover := protocol.Hover{ + Contents: protocol.MarkupContent{ + Kind: "markdown", + Value: "```go\nfunc fmt.Sprintf(format string, a ...any) string\n```\n\nSprintf formats according to a format specifier and returns the resulting string.\n\n\n[`fmt.Sprintf` on pkg.go.dev](https://pkg.go.dev/fmt#Sprintf)", + }, + } + if diff := lspdiff.Hover(expectedHover, *actual); diff != "" { + return fmt.Sprintf("unexpected hover: %v", diff), false + } + return "", true }, - Position: protocol.Position{ - Line: 12, - Character: 34, + }, + { + line: 19, + replacement: `var nihao = "你好"`, + cursor: ` ^`, + assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) { + // There's nothing to hover, just want to make sure it doesn't panic. + return "", true }, }, - }) - if err != nil { - t.Errorf("failed to get hover: %v", err) - } - expectedHover := protocol.Hover{ - Contents: protocol.MarkupContent{ - Kind: "markdown", - Value: "```go\nfunc fmt.Sprintf(format string, a ...any) string\n```\n\nSprintf formats according to a format specifier and returns the resulting string.\n\n\n[`fmt.Sprintf` on pkg.go.dev](https://pkg.go.dev/fmt#Sprintf)", + { + line: 19, + replacement: `var nihao = "你好"`, + cursor: ` ^`, // Your text editor might not render this well, but it's the hao. + assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) { + // There's nothing to hover, just want to make sure it doesn't panic. + return "", true + }, }, } - if diff := lspdiff.Hover(expectedHover, *hr); diff != "" { - t.Errorf("unexpected hover: %v", diff) - return + + for i, test := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + // Put the file back to the initial point. + err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{ + TextDocument: protocol.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: protocol.TextDocumentIdentifier{ + URI: uri.URI("file://" + appDir + "/templates.templ"), + }, + Version: int32(i + 2), + }, + ContentChanges: []protocol.TextDocumentContentChangeEvent{ + { + Range: nil, + Text: string(templFile), + }, + }, + }) + if err != nil { + t.Errorf("failed to change file: %v", err) + return + } + + // Give CI/CD pipeline executors some time because they're often quite slow. + var ok bool + var msg string + for i := 0; i < 3; i++ { + lspCharIndex, err := runeIndexToUTF8ByteIndex(test.replacement, len(test.cursor)-1) + if err != nil { + t.Error(err) + } + actual, err := server.Hover(ctx, &protocol.HoverParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: uri.URI("file://" + appDir + "/templates.templ"), + }, + // Positions are zero indexed. + Position: protocol.Position{ + Line: uint32(test.line - 1), + Character: lspCharIndex, + }, + }, + }) + if err != nil { + t.Errorf("failed to hover: %v", err) + return + } + msg, ok = test.assert(t, actual) + if !ok { + break + } + time.Sleep(time.Millisecond * 500) + } + if !ok { + t.Error(msg) + } + }) + } +} + +func runeIndexToUTF8ByteIndex(s string, runeIndex int) (lspChar uint32, err error) { + for i, r := range []rune(s) { + if i == runeIndex { + break + } + l := utf8.RuneLen(r) + if l < 0 { + return 0, fmt.Errorf("invalid rune in string at index %d", runeIndex) + } + lspChar += uint32(l) } + return lspChar, nil } func NewTestClient(log *zap.Logger) TestClient { diff --git a/cmd/templ/lspcmd/proxy/server.go b/cmd/templ/lspcmd/proxy/server.go index 662b689e2..f993e2a07 100644 --- a/cmd/templ/lspcmd/proxy/server.go +++ b/cmd/templ/lspcmd/proxy/server.go @@ -75,6 +75,7 @@ func (p *Server) updatePosition(templURI lsp.DocumentURI, current lsp.Position) zap.String("toGo", fmt.Sprintf("%d:%d", to.Line, to.Col))) updated.Line = to.Line updated.Character = to.Col + return true, goURI, updated } @@ -761,6 +762,7 @@ func (p *Server) Hover(ctx context.Context, params *lsp.HoverParams) (result *ls if !ok { return nil, nil } + // Call gopls. result, err = p.Target.Hover(ctx, params) if err != nil { return @@ -769,7 +771,7 @@ func (p *Server) Hover(ctx context.Context, params *lsp.HoverParams) (result *ls if result != nil && result.Range != nil { p.Log.Info("hover: result returned") r := p.convertGoRangeToTemplRange(templURI, *result.Range) - p.Log.Info("hover: setting range", zap.Any("range", r)) + p.Log.Info("hover: setting range") result.Range = &r } return @@ -882,11 +884,10 @@ func (p *Server) References(ctx context.Context, params *lsp.ReferenceParams) (r defer p.Log.Info("client -> server: References end") templURI := params.TextDocument.URI // Rewrite the request. - var isTemplURI bool - isTemplURI, params.TextDocument.URI = convertTemplToGoURI(params.TextDocument.URI) - if !isTemplURI { - err = fmt.Errorf("not a templ file") - return + var ok bool + ok, params.TextDocument.URI, params.Position = p.updatePosition(params.TextDocument.URI, params.Position) + if !ok { + return nil, nil } // Call gopls. result, err = p.Target.References(ctx, params) diff --git a/cmd/templ/lspcmd/testdata/templates.templ b/cmd/templ/lspcmd/testdata/templates.templ index c6c039405..966ca9065 100644 --- a/cmd/templ/lspcmd/testdata/templates.templ +++ b/cmd/templ/lspcmd/testdata/templates.templ @@ -15,3 +15,5 @@ templ Page(count int) { } + +var nihao = "你好" diff --git a/cmd/templ/lspcmd/testdata/templates_templ.go b/cmd/templ/lspcmd/testdata/templates_templ.go index f499db39a..a6e2a0c2a 100644 --- a/cmd/templ/lspcmd/testdata/templates_templ.go +++ b/cmd/templ/lspcmd/testdata/templates_templ.go @@ -47,3 +47,5 @@ func Page(count int) templ.Component { return templ_7745c5c3_Err }) } + +var nihao = "你好" diff --git a/generator/rangewriter.go b/generator/rangewriter.go index 9df933e05..d028251c9 100644 --- a/generator/rangewriter.go +++ b/generator/rangewriter.go @@ -4,6 +4,7 @@ import ( "io" "strconv" "strings" + "unicode/utf8" "github.com/a-h/templ/parser/v2" ) @@ -126,15 +127,16 @@ func (rw *RangeWriter) write(s string) (r parser.Range, err error) { Line: rw.Current.Line, Col: rw.Current.Col, } - var n int + utf8Bytes := make([]byte, 4) for _, c := range s { - rw.Current.Col++ + rlen := utf8.EncodeRune(utf8Bytes, c) + rw.Current.Col += uint32(rlen) if c == '\n' { rw.Current.Line++ rw.Current.Col = 0 } - n, err = io.WriteString(rw.w, string(c)) - rw.Current.Index += int64(n) + _, err = rw.w.Write(utf8Bytes[:rlen]) + rw.Current.Index += int64(rlen) if err != nil { return r, err } diff --git a/generator/rangewriter_test.go b/generator/rangewriter_test.go index 350568042..05ee8aea8 100644 --- a/generator/rangewriter_test.go +++ b/generator/rangewriter_test.go @@ -32,11 +32,11 @@ func TestRangeWriter(t *testing.T) { t.Error(diff) } }) - t.Run("multi-byte characters count as a single column position", func(t *testing.T) { + t.Run("multi-byte characters count as 3, because that's their UTF8 representation", func(t *testing.T) { if _, err := rw.Write("\n你"); err != nil { t.Fatalf("failed to write: %v", err) } - if diff := cmp.Diff(parser.NewPosition(9, 2, 1), rw.Current); diff != "" { + if diff := cmp.Diff(parser.NewPosition(9, 2, 3), rw.Current); diff != "" { t.Error(diff) } }) diff --git a/parser/v2/sourcemap.go b/parser/v2/sourcemap.go index 95f4026df..d342e31e6 100644 --- a/parser/v2/sourcemap.go +++ b/parser/v2/sourcemap.go @@ -2,6 +2,7 @@ package parser import ( "strings" + "unicode/utf8" ) // NewSourceMap creates a new lookup to map templ source code to items in the @@ -36,7 +37,7 @@ func (sm *SourceMap) Add(src Expression, tgt Range) (updatedFrom Position) { } // Process the cols. - for colIndex := 0; colIndex < len(line); colIndex++ { + for _, r := range line { if _, ok := sm.SourceLinesToTarget[srcLine]; !ok { sm.SourceLinesToTarget[srcLine] = make(map[uint32]Position) } @@ -47,10 +48,15 @@ func (sm *SourceMap) Add(src Expression, tgt Range) (updatedFrom Position) { } sm.TargetLinesToSource[tgtLine][tgtCol] = NewPosition(srcIndex, srcLine, srcCol) - srcCol++ - tgtCol++ - srcIndex++ - tgtIndex++ + // Ignore invalid runes. + rlen := utf8.RuneLen(r) + if rlen < 0 { + rlen = 1 + } + srcCol += uint32(rlen) + tgtCol += uint32(rlen) + srcIndex += int64(rlen) + tgtIndex += int64(rlen) } // LSPs include the newline char as a col. diff --git a/parser/v2/sourcemap_test.go b/parser/v2/sourcemap_test.go index 1b17b377e..362f06f76 100644 --- a/parser/v2/sourcemap_test.go +++ b/parser/v2/sourcemap_test.go @@ -10,19 +10,19 @@ import ( // Test data. // -// | 0 1 2 3 4 5 6 7 8 9 -// -// - - - - - - - - - - - -// 0 | -// 1 | a b c d e f g h i -// 2 | j k l m n o -// 3 | p q r s t u v -// 4 | -// 5 | w x y -// 6 | z -// 7 | m u l t i -// 8 | l i n e -// 9 | m a t c h +// | - | 0 1 2 3 4 5 6 7 8 9 +// | - | - - - - - - - - - +// | 0 | +// | 1 | a b c d e f g h i +// | 2 | j k l m n o +// | 3 | p q r s t u v +// | 4 | +// | 5 | w x y +// | 6 | z +// | 7 | m u l t i +// | 8 | l i n e +// | 9 | m a t c h +// | 10 | 生 日 快 乐 func pos(index, line, col int) parse.Position { return parse.Position{ Index: index, @@ -113,6 +113,24 @@ func TestSourceMapPosition(t *testing.T) { source: NewPosition(11, 2, 0), // m (atch) target: NewPosition(12, 3, 0), }, + { + name: "unicode characters are indexed correctly (sheng)", + setup: func(sm *SourceMap) { + sm.Add(NewExpression("生日快乐", pos(0, 10, 0), pos(12, 10, 4)), + Range{From: NewPosition(1, 11, 1), To: NewPosition(13, 11, 5)}) + }, + source: NewPosition(0, 10, 0), // 生 + target: NewPosition(1, 11, 1), + }, + { + name: "unicode characters are indexed correctly (ri)", + setup: func(sm *SourceMap) { + sm.Add(NewExpression("生日快乐", pos(0, 10, 0), pos(12, 10, 4)), + Range{From: NewPosition(1, 11, 1), To: NewPosition(13, 11, 5)}) + }, + source: NewPosition(3, 10, 3), // 日 + target: NewPosition(4, 11, 4), + }, } for _, tt := range tests { tt := tt