diff --git a/go.mod b/go.mod index c2da4ee..d451477 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/remoteoss/dexter go 1.26.1 require ( + github.com/google/go-cmp v0.5.6 github.com/mattn/go-sqlite3 v1.14.38 + github.com/phoenixframework/tree-sitter-heex v0.9.0 github.com/spf13/cobra v1.10.2 github.com/tree-sitter/go-tree-sitter v0.25.0 github.com/tree-sitter/tree-sitter-elixir v0.3.5 diff --git a/go.sum b/go.sum index a8ff704..c213a9f 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc= github.com/mattn/go-sqlite3 v1.14.38 h1:tDUzL85kMvOrvpCt8P64SbGgVFtJB11GPi2AdmITgb4= github.com/mattn/go-sqlite3 v1.14.38/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/phoenixframework/tree-sitter-heex v0.9.0 h1:19d/KenCYoturUoMq+fY5LXTwPhe5msaOx9cHGnPUj0= +github.com/phoenixframework/tree-sitter-heex v0.9.0/go.mod h1:ul+VP/WJ7qS+DPlkr15hyBrzYd1D1rvmyEKmw/7lGOQ= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/lsp/documents.go b/internal/lsp/documents.go index a1874b8..be699a4 100644 --- a/internal/lsp/documents.go +++ b/internal/lsp/documents.go @@ -7,10 +7,10 @@ import ( "sync" tree_sitter "github.com/tree-sitter/go-tree-sitter" - tree_sitter_elixir "github.com/tree-sitter/tree-sitter-elixir/bindings/go" "go.lsp.dev/protocol" "github.com/remoteoss/dexter/internal/parser" + "github.com/remoteoss/dexter/internal/treesitter" ) // defaultMaxTransient caps how many disk-loaded buffers may live in the @@ -45,7 +45,7 @@ type cachedDoc struct { // tree for free and only triggers ts_tree_delete if no handler still // holds a reference. 
type refTree struct { - tree *tree_sitter.Tree + tree *treesitter.Tree refs int retired bool } @@ -76,9 +76,9 @@ func (rt *refTree) retireLocked() { // (e.g. Claude Code) can still query references/hover/definition without // causing unbounded memory growth. type DocumentStore struct { - mu sync.RWMutex - docs map[string]*cachedDoc - parser *tree_sitter.Parser + mu sync.RWMutex + docs map[string]*cachedDoc + parsers map[treesitter.Language]*tree_sitter.Parser // LRU bookkeeping for transient (disk-loaded) entries only. The list // holds URIs in access-order, newest at the front. transientIdx maps @@ -89,11 +89,9 @@ type DocumentStore struct { } func NewDocumentStore() *DocumentStore { - p := tree_sitter.NewParser() - _ = p.SetLanguage(tree_sitter.NewLanguage(tree_sitter_elixir.Language())) return &DocumentStore{ docs: make(map[string]*cachedDoc), - parser: p, + parsers: treesitter.AllParsers(), transientList: list.New(), transientIdx: make(map[string]*list.Element), maxTransient: defaultMaxTransient, @@ -149,7 +147,9 @@ func (ds *DocumentStore) CloseAll() { ds.docs = nil ds.transientList = nil ds.transientIdx = nil - ds.parser.Close() + for _, p := range ds.parsers { + p.Close() + } } func (ds *DocumentStore) Get(uri string) (string, bool) { @@ -300,7 +300,7 @@ func (ds *DocumentStore) evictTransientLocked() { // Callers must not close the returned tree directly. // // When ok is false, release is nil and must not be called. 
-func (ds *DocumentStore) GetTree(uri string) (*tree_sitter.Tree, []byte, func(), bool) { +func (ds *DocumentStore) GetTree(uri string) (*treesitter.Tree, []byte, func(), bool) { ds.mu.Lock() defer ds.mu.Unlock() doc, ok := ds.docs[uri] @@ -309,7 +309,7 @@ func (ds *DocumentStore) GetTree(uri string) (*tree_sitter.Tree, []byte, func(), } if doc.tree == nil { doc.src = []byte(doc.text) - doc.tree = &refTree{tree: ds.parser.Parse(doc.src, nil)} + doc.tree = &refTree{tree: treesitter.NewTreeWithParsers(doc.src, ds.parsers)} } rt := doc.tree rt.refs++ diff --git a/internal/lsp/documents_test.go b/internal/lsp/documents_test.go index e9ee910..e94af2a 100644 --- a/internal/lsp/documents_test.go +++ b/internal/lsp/documents_test.go @@ -268,8 +268,8 @@ func TestDocumentStore_GetTree_DiskLoaded(t *testing.T) { if string(src) != contents { t.Fatalf("GetTree src mismatch: got %q want %q", src, contents) } - if tree.RootNode().Kind() != "source" { - t.Fatalf("expected root node kind 'source', got %q", tree.RootNode().Kind()) + if tree.Trunk.RootNode().Kind() != "source" { + t.Fatalf("expected root node kind 'source', got %q", tree.Trunk.RootNode().Kind()) } } @@ -470,7 +470,7 @@ func TestDocumentStore_GetTree_SurvivesEviction(t *testing.T) { // Capture the root node kind so we can re-read it after eviction. // Pre-fix, the eviction below would call ts_tree_delete on this tree // and the second RootNode() call would read freed C memory. - rootKindBefore := tree.RootNode().Kind() + rootKindBefore := tree.Trunk.RootNode().Kind() // Force eviction of this URI while we still hold a ref. ds.SetMaxTransient(0) @@ -480,7 +480,7 @@ func TestDocumentStore_GetTree_SurvivesEviction(t *testing.T) { // Walking the tree after eviction must still work - this is the UAF // the refcounting prevents. 
- rootKindAfter := tree.RootNode().Kind() + rootKindAfter := tree.Trunk.RootNode().Kind() if rootKindAfter != rootKindBefore { t.Fatalf("tree root kind changed across eviction: got %q want %q", rootKindAfter, rootKindBefore) } @@ -537,7 +537,7 @@ func TestDocumentStore_GetTree_ConcurrentEvictionStress(t *testing.T) { if !ok { continue } - root := tree.RootNode() + root := tree.Trunk.RootNode() _ = root.Kind() _ = root.ChildCount() release() diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index 1f33c3d..058c8ac 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/remoteoss/dexter/internal/parser" ) @@ -436,6 +437,45 @@ func TestExpressionAtCursor_ExprBounds(t *testing.T) { } } +func TestExpressionAtCursor_HEEX(t *testing.T) { + tests := []struct { + code string + line, col int + want CursorContext + }{ + // all delimiter styles should be supported + {"~H\"\"\"\n<.foo />\n\"\"\"", 1, 2, CursorContext{FunctionName: "foo", ExprStart: 2, ExprEnd: 5}}, + {"~H'''\n<.foo />\n'''", 1, 2, CursorContext{FunctionName: "foo", ExprStart: 2, ExprEnd: 5}}, + {"~H\"<.foo />\"", 0, 5, CursorContext{FunctionName: "foo", ExprStart: 5, ExprEnd: 8}}, + {"~H'<.foo />'", 0, 5, CursorContext{FunctionName: "foo", ExprStart: 5, ExprEnd: 8}}, + {"~H[<.foo />]", 0, 5, CursorContext{FunctionName: "foo", ExprStart: 5, ExprEnd: 8}}, + // newline after delimiter is optional + {"~H\"\"\"<.foo />\"\"\"", 0, 7, CursorContext{FunctionName: "foo", ExprStart: 7, ExprEnd: 10}}, + {"~H[]", 0, 5, CursorContext{ModuleRef: "Foo", ExprStart: 4, ExprEnd: 7}}, + {"~H[]", 0, 9, CursorContext{ModuleRef: "Foo", FunctionName: "bar", ExprStart: 4, ExprEnd: 11}}, + {"~H[<.live_component module={Foo.Bar} />]", 0, 28, CursorContext{ModuleRef: "Foo", ExprStart: 28, ExprEnd: 31}}, + {"~H[<.live_component module={Foo.Bar} />]", 0, 32, CursorContext{ModuleRef: "Foo.Bar", ExprStart: 28, 
ExprEnd: 35}}, + {"~H'''\n<.live_component module={Foo.Bar} />\n'''", 1, 29, CursorContext{ModuleRef: "Foo.Bar", ExprStart: 25, ExprEnd: 32}}, + // interpolated expressions that aren't module/function should be ignored + {"~H[
]", 0, 11, CursorContext{}}, + // HTML tags should be ignored + {"~H[
]", 0, 4, CursorContext{}}, + // custom sigils should be parsed correctly but ignored + {"~x[_]", 0, 3, CursorContext{}}, + {"~X[_]", 0, 3, CursorContext{}}, + {"~XXX[_]", 0, 5, CursorContext{}}, + {"~X12[_]", 0, 5, CursorContext{}}, + } + + for _, tt := range tests { + tokens, source, lineStarts := tokenize(tt.code) + got := ExpressionAtCursor(tokens, source, lineStarts, tt.line, tt.col) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("ExpressionAtCursor(_, %#v, _, %d, %d)\nparse mismatch (-want +got):\n%s", tt.code, tt.line, tt.col, diff) + } + } +} + func TestCursorContext_Expr(t *testing.T) { tests := []struct { mod, fn, want string @@ -583,7 +623,7 @@ end` text := `defmodule MyApp.Web do alias MyApp.Services.{ Accounts, - + def foo do # missing close brace end diff --git a/internal/lsp/server.go b/internal/lsp/server.go index c2fd345..ea47a25 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -27,7 +27,6 @@ import ( "github.com/remoteoss/dexter/internal/parser" "github.com/remoteoss/dexter/internal/stdlib" "github.com/remoteoss/dexter/internal/store" - "github.com/remoteoss/dexter/internal/treesitter" "github.com/remoteoss/dexter/internal/version" ) @@ -661,7 +660,8 @@ func (s *Server) Definition(ctx context.Context, params *protocol.DefinitionPara // The first occurrence in scope is the definition (pattern/assignment). 
if tree, src, release, ok := s.docs.GetTree(docURI); ok { defer release() - if occs := treesitter.FindVariableOccurrencesWithTree(tree.RootNode(), src, uint(lineNum), uint(col)); len(occs) > 0 { + + if occs := tree.FindVariableOccurrences(src, uint(lineNum), uint(col)); len(occs) > 0 { s.debugf("Definition: returning variable definition at line %d", occs[0].Line) return []protocol.Location{{ URI: params.TextDocument.URI, @@ -1731,7 +1731,7 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara var varsInScope []string if tree, src, release, ok := s.docs.GetTree(docURI); ok { defer release() - varsInScope = treesitter.FindVariablesInScopeWithTree(tree.RootNode(), src, uint(lineNum), uint(col)) + varsInScope = tree.FindVariablesInScope(src, uint(lineNum), uint(col)) } for _, varName := range varsInScope { if strings.HasPrefix(varName, funcPrefix) && !seen[varName] { @@ -2687,10 +2687,9 @@ func (s *Server) DocumentHighlight(ctx context.Context, params *protocol.Documen return nil, nil } defer release() - root := tree.RootNode() // Try scope-aware variable highlight first - if occs := treesitter.FindVariableOccurrencesWithTree(root, src, uint(lineNum), uint(col)); len(occs) > 0 { + if occs := tree.FindVariableOccurrences(src, uint(lineNum), uint(col)); len(occs) > 0 { var highlights []protocol.DocumentHighlight for _, occ := range occs { highlights = append(highlights, protocol.DocumentHighlight{ @@ -2722,7 +2721,7 @@ func (s *Server) DocumentHighlight(ctx context.Context, params *protocol.Documen } // Reuse the same parsed tree for token occurrences - occs := treesitter.FindTokenOccurrencesWithTree(root, src, token) + occs := tree.FindTokenOccurrences(src, token) if len(occs) == 0 { return nil, nil } @@ -3606,7 +3605,7 @@ func (s *Server) PrepareRename(ctx context.Context, params *protocol.PrepareRena if moduleRef == "" { if tree, src, release, ok := s.docs.GetTree(docURI); ok { defer release() - if occs := 
treesitter.FindVariableOccurrencesWithTree(tree.RootNode(), src, uint(lineNum), uint(col)); len(occs) > 0 { + if occs := tree.FindVariableOccurrences(src, uint(lineNum), uint(col)); len(occs) > 0 { for _, occ := range occs { if occ.Line == uint(lineNum) && uint(col) >= occ.StartCol && uint(col) < occ.EndCol { return &protocol.Range{ @@ -3804,7 +3803,7 @@ func (s *Server) References(ctx context.Context, params *protocol.ReferenceParam // function reference lookup. if tree, src, release, ok := s.docs.GetTree(docURI); ok { defer release() - if occs := treesitter.FindVariableOccurrencesWithTree(tree.RootNode(), src, uint(lineNum), uint(col)); len(occs) > 0 { + if occs := tree.FindVariableOccurrences(src, uint(lineNum), uint(col)); len(occs) > 0 { var locations []protocol.Location for _, occ := range occs { locations = append(locations, protocol.Location{ @@ -4010,8 +4009,8 @@ func (s *Server) Rename(ctx context.Context, params *protocol.RenameParams) (*pr if moduleRef == "" { if tree, src, release, ok := s.docs.GetTree(docURI); ok { defer release() - if occs := treesitter.FindVariableOccurrencesWithTree(tree.RootNode(), src, uint(lineNum), uint(col)); len(occs) > 0 { - if treesitter.NameExistsInScopeOf(tree.RootNode(), src, uint(lineNum), uint(col), params.NewName) { + if occs := tree.FindVariableOccurrences(src, uint(lineNum), uint(col)); len(occs) > 0 { + if tree.NameExistsInScopeOf(src, uint(lineNum), uint(col), params.NewName) { return nil, fmt.Errorf("variable %q already exists in this scope", params.NewName) } changes := make(map[protocol.DocumentURI][]protocol.TextEdit) @@ -5021,6 +5020,8 @@ func (s *Server) getFileLine(filePath string, lineNum int) (string, bool) { return scanner.Text(), true } } + // ignore any scan error + _ = scanner.Err() return "", false } diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 4477684..65df317 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -2207,6 +2207,70 @@ end` } } 
+func TestDefinition_HEEXFunction(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + src := `defmodule TestLive do + use Phoenix.LiveView + + def render(assigns) do + ~H""" + <.foo /> + +
+ """ + end + + defp foo(_), do: ~H"" + defp class, do: "" +end` + + uri := "file://" + filepath.Join(server.projectRoot, "test_live.ex") + indexFile(t, server.store, server.projectRoot, "test_live.ex", src) + server.docs.Set(uri, src) + + // Cursor on "foo" at line 6 col 6 (the `<.foo />` component inside `render`) + locs := definitionAt(t, server, uri, 5, 6) + if len(locs) == 0 { + t.Fatal("expected go-to-definition for function 'foo'") + } + // Should jump to line 11 where `foo` is defined + if locs[0].Range.Start.Line != 11 { + t.Errorf("expected definition on line 9, got line %d", locs[0].Range.Start.Line) + } + + // Cursor on "TestLive.foo" at line 7 col 6 (the `TestLive` module of ``) + locs = definitionAt(t, server, uri, 6, 6) + if len(locs) == 0 { + t.Fatal("expected go-to-definition for module 'TestLive'") + } + // Should jump to line 1 where `TestLive` is defined + if locs[0].Range.Start.Line != 0 { + t.Errorf("expected definition on line 1, got line %d", locs[0].Range.Start.Line) + } + + // Cursor on "TestLive.foo" at line 7 col 15 (the `foo` function of ``) + locs = definitionAt(t, server, uri, 6, 15) + if len(locs) == 0 { + t.Fatal("expected go-to-definition for function 'foo'") + } + // Should jump to line 11 where `foo` is defined + if locs[0].Range.Start.Line != 11 { + t.Errorf("expected definition on line 11, got line %d", locs[0].Range.Start.Line) + } + + // Cursor on "class()" at line 8 col 16 (the `class()` call of `
`) + locs = definitionAt(t, server, uri, 7, 16) + if len(locs) == 0 { + t.Fatal("expected go-to-definition for function 'class'") + } + // Should jump to line 12 where `class` is defined + if locs[0].Range.Start.Line != 12 { + t.Errorf("expected definition on line 12, got line %d", locs[0].Range.Start.Line) + } +} + func TestHover_AliasInjectedByUse(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -2857,6 +2921,50 @@ end } } +func TestReferences_HEEXNestedReference(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + // Nested module: defmodule MoneyResponse inside Money creates + // MyApp.Money.MoneyResponse, but the defmodule line says just "MoneyResponse" + src := `defmodule App do + use Phoenix.LiveView + + def foo(assigns), do: ~H"" + + def render(assigns) do + ~H""" + <.foo /> + + """ + end +end +` + indexFile(t, server.store, server.projectRoot, "lib/app.ex", src) + + uri := "file://" + filepath.Join(server.projectRoot, "lib", "app.ex") + server.docs.Set(uri, src) + + // Go-to-references on "foo" in the component (line 9, col 11) + locs := referencesAt(t, server, uri, 8, 11) + if len(locs) == 0 { + t.Fatal("expected references for function MyApp.foo") + } + if locs[0].Range.Start.Line != 8 { + t.Fatalf("expected reference on line 8, got line %d", locs[0].Range.Start.Line) + } + + // Go-to-references on "foo" in the <.foo /> line (line 8, col 6) + locs = referencesAt(t, server, uri, 7, 6) + if len(locs) == 0 { + t.Fatal("expected references for function .foo") + } + if locs[0].Range.Start.Line != 7 { + t.Fatalf("expected reference on line 7, got line %d", locs[0].Range.Start.Line) + } + +} + func TestDefinition_QualifiedCallOnNestedModule(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() diff --git a/internal/parser/parser_tokenized.go b/internal/parser/parser_tokenized.go index 04d6153..6ad7506 100644 --- a/internal/parser/parser_tokenized.go +++ b/internal/parser/parser_tokenized.go @@ 
-649,6 +649,15 @@ func parseTextFromTokens(path string, source []byte, tokens []Token) ([]Definiti case TokIdent: cm := currentModule() if cm != "" && len(injectors) > 0 { + isHEEXFunction := i > 1 && tokens[i-1].Kind == TokDot && + (tokens[i-2].Kind == TokHEEXOpenTag || tokens[i-2].Kind == TokHEEXCloseTag) + if isHEEXFunction { + name := tokenText(tok) + refs = append(refs, Reference{Module: cm, Function: name, Line: tok.Line, FilePath: path, Kind: "call"}) + i++ + continue + } + isStatementStart := i == 0 || tokens[i-1].Kind == TokEOL || tokens[i-1].Kind == TokComment if isStatementStart { name := tokenText(tok) diff --git a/internal/parser/testdata/fuzz/FuzzTokenizeHeex/4e129080bf679ec0 b/internal/parser/testdata/fuzz/FuzzTokenizeHeex/4e129080bf679ec0 new file mode 100644 index 0000000..2dee8a8 --- /dev/null +++ b/internal/parser/testdata/fuzz/FuzzTokenizeHeex/4e129080bf679ec0 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("<%") diff --git a/internal/parser/testdata/fuzz/FuzzTokenizeHeex/d312599b9dbee58c b/internal/parser/testdata/fuzz/FuzzTokenizeHeex/d312599b9dbee58c new file mode 100644 index 0000000..fecbcf9 --- /dev/null +++ b/internal/parser/testdata/fuzz/FuzzTokenizeHeex/d312599b9dbee58c @@ -0,0 +1,2 @@ +go test fuzz v1 +string("<0/") diff --git a/internal/parser/tokenizer.go b/internal/parser/tokenizer.go index c43471e..ac24d0f 100644 --- a/internal/parser/tokenizer.go +++ b/internal/parser/tokenizer.go @@ -1,11 +1,16 @@ package parser import ( + "bytes" + "fmt" + "strings" "unicode" "unicode/utf8" ) // TokenKind identifies the kind of a lexed token. +// +//go:generate stringer -type=TokenKind type TokenKind byte const ( @@ -60,6 +65,8 @@ const ( TokAssoc // => TokDoubleColon // :: TokPercent // % + TokHEEXOpenTag // < + TokHEEXCloseTag // = len(source) { return i, line } + // sigilLetter is the letter after ~ (e.g. 's' in ~s, 'S' in ~S). Uppercase sigil + // letters mean the content is "raw" — backslash is NOT an escape character. 
+ start := i + startLine := line + sigilLetter := source[i+1] + i += 2 // consume ~ and first letter + // Multi-char sigils: continue reading uppercase letters/numbers + if isUpper(sigilLetter) { + for i < len(source) && (isUpper(source[i]) || isDigit(source[i])) { + i++ + } + } + + sigilChars := string(source[start+1 : i]) + if i == len(source) { + return i, line + } + escapes := isLower(sigilLetter) // only lowercase sigils process escapes openCh := source[i] + var contentsStart, contentsEnd int + // Check for heredoc sigil: ~s""" or ~S""" if openCh == '"' && i+2 < len(source) && source[i+1] == '"' && source[i+2] == '"' { i += 3 // consume """ + contentsStart = i if escapes { i, line = scanHeredocContent(source, i, line, '"', lineStarts) } else { i, line = scanRawHeredocContent(source, i, line, '"', lineStarts) } - return i, line - } - if openCh == '\'' && i+2 < len(source) && source[i+1] == '\'' && source[i+2] == '\'' { + contentsEnd = i - 3 + } else if openCh == '\'' && i+2 < len(source) && source[i+1] == '\'' && source[i+2] == '\'' { i += 3 // consume ''' + contentsStart = i if escapes { i, line = scanHeredocContent(source, i, line, '\'', lineStarts) } else { i, line = scanRawHeredocContent(source, i, line, '\'', lineStarts) } + contentsEnd = i - 3 + } else { + i++ // consume opening delimiter + contentsStart = i + + var closeCh byte + nested := false + + switch openCh { + case '(': + closeCh = ')' + nested = true + case '[': + closeCh = ']' + nested = true + case '{': + closeCh = '}' + nested = true + case '<': + closeCh = '>' + nested = true + default: + closeCh = openCh + nested = false + } + + if nested { + depth := 1 + for i < len(source) && depth > 0 { + ch := source[i] + if ch == '\n' { + line++ + i++ + *lineStarts = append(*lineStarts, i) + } else if escapes && ch == '\\' && i+1 < len(source) { + if source[i+1] == '\n' { + line++ + *lineStarts = append(*lineStarts, i+2) + } + i += 2 + } else if ch == openCh { + depth++ + i++ + } else if ch == closeCh { 
+ depth-- + i++ + } else { + i++ + } + } + } else { + for i < len(source) { + ch := source[i] + if ch == '\n' { + line++ + i++ + *lineStarts = append(*lineStarts, i) + } else if escapes && ch == '\\' && i+1 < len(source) { + if source[i+1] == '\n' { + line++ + *lineStarts = append(*lineStarts, i+2) + } + i += 2 + } else if ch == closeCh { + i++ // consume closing delimiter + break + } else { + i++ + } + } + } + + contentsEnd = i - 1 + + // Consume trailing modifier letters (e.g. the 'i' in ~r/foo/i) + for i < len(source) && isLetter(source[i]) { + i++ + } + } + + // incomplete sigil at end of document + if contentsEnd < contentsStart { return i, line } - i++ // consume opening delimiter - - var closeCh byte - nested := false - - switch openCh { - case '(': - closeCh = ')' - nested = true - case '[': - closeCh = ']' - nested = true - case '{': - closeCh = '}' - nested = true - case '<': - closeCh = '>' - nested = true - default: - closeCh = openCh - nested = false + // emit tokens if requested + if tokens != nil { + if contentsEnd == contentsStart { + // empty sigil + *tokens = append(*tokens, Token{Kind: TokSigil, Start: start, End: i, Line: line}) + } else { + scanSigilContents(sigilChars, source, start, i, contentsStart, contentsEnd, startLine, lineStarts, tokens) + } + } + + return i, line +} + +func scanSigilContents(sigilChars string, source []byte, start, end, contentsStart, contentsEnd, line int, lineStarts *[]int, tokens *[]Token) { + // only scan the contents of HEEX `~H` sigils + if sigilChars != "H" { + *tokens = append(*tokens, Token{Kind: TokSigil, Start: start, End: end, Line: line}) + return + } + + // lineStarts has already been updated by `scanHeredocContent` / `scanRawHeredocContent` + result := TokenizeHeex(source[contentsStart:contentsEnd]) + for _, t := range result.Tokens { + if t.Kind != TokEOF { + *tokens = append(*tokens, Token{Kind: t.Kind, Start: t.Start + contentsStart, End: t.End + contentsStart, Line: t.Line + line - 1}) + } + } +} + 
+func TokenizeHeex(source []byte) TokenResult { + tokens := make([]Token, 0, len(source)/8) + lineStarts := make([]int, 1, 64) + lineStarts[0] = 0 // line 1 starts at byte 0 + line := 1 + i := 0 + + scanComment := func(delim string, i, line int, lineStarts *[]int) (int, int) { + for i < len(source) { + if bytes.HasPrefix(source[i:], []byte(delim)) { + i += len(delim) + break + } + if source[i] == '\n' { + line++ + *lineStarts = append(*lineStarts, i+1) + } + i++ + } + return i, line + } + + scanInterpolation := func(i, line int, terminator string, tokens *[]Token) (int, int) { + start := i + startLine := line + + // lineStarts has already been updated during heredoc scanning + i_, line_, result := tokenizeUntil(source[start:], []byte(terminator)) + for _, t := range result.Tokens { + if t.Kind != TokEOF { + *tokens = append(*tokens, Token{Kind: t.Kind, Start: t.Start + start, End: t.End + start, Line: t.Line + startLine - 1}) + } + } + + i += i_ + len(terminator) + line += line_ - 1 + + return i, line + } + + scanTagName := func(i, line int, tokens *[]Token) (int, int) { + for i < len(source) { + switch { + // <.foo + case source[i] == '.': + *tokens = append(*tokens, Token{Kind: TokDot, Start: i, End: i + 1, Line: line}) + i++ + + start := i + for i < len(source) && (isLetter(source[i]) || isDigit(source[i]) || source[i] == '_' || source[i] == '-') { + i++ + } + if i == start { + return i, line + } + + if isUpper(source[start]) { + *tokens = append(*tokens, Token{Kind: TokModule, Start: start, End: i, Line: line}) + } else { + *tokens = append(*tokens, Token{Kind: TokIdent, Start: start, End: i, Line: line}) + return i, line + } + + //
0 { + scanTagAttr := func(i, line int, lineStarts *[]int, tokens *[]Token) (int, int) { + var quoteChar byte + for i < len(source) { ch := source[i] - if ch == '\n' { + + switch { + case ch == '\n': + *tokens = append(*tokens, Token{Kind: TokEOL, Start: i, End: i + 1, Line: line}) line++ i++ *lineStarts = append(*lineStarts, i) - } else if escapes && ch == '\\' && i+1 < len(source) { - if source[i+1] == '\n' { - line++ - *lineStarts = append(*lineStarts, i+2) + + case quoteChar == 0 && isWhitespace(ch): + return i, line + + case quoteChar == 0 && (ch == '\'' || ch == '"'): + quoteChar = ch + i++ + + case quoteChar != 0 && ch == '\\': + i++ + if i < len(source) { + i++ } - i += 2 - } else if ch == openCh { - depth++ + + case quoteChar != 0 && ch == quoteChar: i++ - } else if ch == closeCh { - depth-- + return i, line + + case quoteChar != 0: i++ - } else { + + case ch == '{': + i++ + return scanInterpolation(i, line, "}", tokens) + + case ch == '>': + return i, line + + default: i++ } } - } else { + + return i, line + } + + scanTag := func(i, line int, lineStarts *[]int, tokens *[]Token) (int, int) { + hasName := false for i < len(source) { - ch := source[i] - if ch == '\n' { + switch { + case source[i] == '\n': + *tokens = append(*tokens, Token{Kind: TokEOL, Start: i, End: i + 1, Line: line}) line++ i++ *lineStarts = append(*lineStarts, i) - } else if escapes && ch == '\\' && i+1 < len(source) { - if source[i+1] == '\n' { - line++ - *lineStarts = append(*lineStarts, i+2) + + case isWhitespace(source[i]): + i++ + + // HTML tag name, function, or module/function + case !hasName: + i, line = scanTagName(i, line, tokens) + hasName = true + + // attribute + case isLetter(source[i]): + i, line = scanTagAttr(i, line, lineStarts, tokens) + + // HEEX special attribute ":if={}", ":for={}", ":let={}", ":type={}" + case source[i] == ':': + i++ + i, line = scanTagAttr(i, line, lineStarts, tokens) + + // self-closing tag + case source[i] == '/': + i++ + if i < len(source) && 
source[i] == '>' { + i++ + return i, line } - i += 2 - } else if ch == closeCh { - i++ // consume closing delimiter - break - } else { + + // finish open tag + case source[i] == '>': + i++ + return i, line + + default: i++ } } + + return i, line } - // Consume trailing modifier letters (e.g. the 'i' in ~r/foo/i) - for i < len(source) && isLetter(source[i]) { - i++ + for i < len(source) { + ch := source[i] + + switch { + case ch == '\n': + tokens = append(tokens, Token{Kind: TokEOL, Start: i, End: i + 1, Line: line}) + line++ + i++ + lineStarts = append(lineStarts, i) + + case isWhitespace(ch): + i++ + + case ch == '<': + // consume < + i++ + if i < len(source) { + if source[i] == '!' && i+2 < len(source) && source[i+1] == '-' && source[i+2] == '-' { + // HTML comment "", i, line, &lineStarts) + tokens = append(tokens, Token{Kind: TokComment, Start: start, End: i, Line: startLine}) + } else if source[i] == '%' { + // consume % + i++ + if i+2 < len(source) && source[i] == '!' && source[i+1] == '-' && source[i+2] == '-' { + // HEEX comment "<%!--" + i += 3 + start := i - 5 + startLine := line + i, line = scanComment("--%>", i, line, &lineStarts) + tokens = append(tokens, Token{Kind: TokComment, Start: start, End: i, Line: startLine}) + } else if i < len(source) { + // consume "=" output indicator from "<%=" special form prefix + if source[i] == '=' { + i++ + } + // EEX interpolation "<%" + // EEX special form "<% for", "<% if", "<% case", "<% cond", "<% else", "<% end", "<% _ ->" + i, line = scanInterpolation(i, line, "%>", &tokens) + } + } else if source[i] == '/' { + i++ + tokens = append(tokens, Token{Kind: TokHEEXCloseTag, Start: i - 2, End: i, Line: line}) + i, line = scanTagName(i, line, &tokens) + if i < len(source) && source[i] == '>' { + i++ + } + } else { + // HTML tag "= '0' && ch <= '9' } +// isWhitespace returns true for space, tab, and carriage return. 
+func isWhitespace(ch byte) bool { + return ch == ' ' || ch == '\t' || ch == '\r' +} + // isHexDigit returns true for [0-9a-fA-F]. func isHexDigit(ch byte) bool { return isDigit(ch) || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F') @@ -957,3 +1276,25 @@ func bytesEqual(b []byte, s string) bool { func isKeywordKey(source []byte, i int) bool { return i < len(source) && source[i] == ':' && (i+1 >= len(source) || source[i+1] != ':') } + +// DebugTokens returns a string represention similar to %+v for a slice of tokens. +func DebugTokens(source []byte, tokens []Token) string { + var s strings.Builder + + for _, t := range tokens { + s.WriteString(t.Debug(source)) + } + + return s.String() +} + +// Debug returns a string representation similar to %+v for a token. +func (token Token) Debug(source []byte) string { + switch token.Kind { + case TokDot, TokEOL, TokEOF, TokOpenBrace, TokCloseBrace, TokHEEXOpenTag, TokHEEXCloseTag: + return fmt.Sprintf("%s (%d:%d)\n", token.Kind.String(), token.Start, token.End) + + default: + return fmt.Sprintf("%s (%d:%d) %#v\n", token.Kind.String(), token.Start, token.End, TokenText(source, token)) + } +} diff --git a/internal/parser/tokenizer_test.go b/internal/parser/tokenizer_test.go index 7b38365..24515b1 100644 --- a/internal/parser/tokenizer_test.go +++ b/internal/parser/tokenizer_test.go @@ -1,9 +1,13 @@ package parser import ( + "context" "fmt" "strings" "testing" + "time" + + "github.com/google/go-cmp/cmp" ) // tokenizeNoEOF runs Tokenize and strips the trailing TokEOF for cleaner assertions. 
@@ -2077,9 +2081,8 @@ func TestTokenize_EscapedNewlineLineTracking(t *testing.T) { } func TestLineStartsAccuracy(t *testing.T) { - assertLineStarts := func(t *testing.T, src string) { + assertLineStarts := func(t *testing.T, src string, result TokenResult) { t.Helper() - result := TokenizeFull([]byte(src)) lineStarts := result.LineStarts lines := strings.Split(src, "\n") if len(lineStarts) != len(lines) { @@ -2100,9 +2103,8 @@ func TestLineStartsAccuracy(t *testing.T) { } } - assertTokenAt := func(t *testing.T, src string, line0, col int, wantKind TokenKind, wantText string) { + assertTokenAt := func(t *testing.T, src string, result TokenResult, line0, col int, wantKind TokenKind, wantText string) { t.Helper() - result := TokenizeFull([]byte(src)) offset := LineColToOffset(result.LineStarts, line0, col) idx := TokenAtOffset(result.Tokens, offset) if idx < 0 { @@ -2119,25 +2121,129 @@ func TestLineStartsAccuracy(t *testing.T) { t.Run("heredoc", func(t *testing.T) { src := "defmodule MyApp.Example do\n @moduledoc \"\"\"\n This is a long\n multiline heredoc\n with several lines\n of documentation.\n \"\"\"\n\n @type t :: %__MODULE__{\n name: String.t(),\n age: Integer.t()\n }\n\n def hello do\n :world\n end\nend" - assertLineStarts(t, src) - assertTokenAt(t, src, 9, 16, TokModule, "String") + result := TokenizeFull([]byte(src)) + assertLineStarts(t, src, result) + assertTokenAt(t, src, result, 9, 16, TokModule, "String") }) t.Run("multiline string", func(t *testing.T) { src := "x = \"line one\nline two\nline three\"\ny = Enum.map(list, fn x -> x end)" - assertLineStarts(t, src) - assertTokenAt(t, src, 3, 4, TokModule, "Enum") + result := TokenizeFull([]byte(src)) + assertLineStarts(t, src, result) + assertTokenAt(t, src, result, 3, 4, TokModule, "Enum") }) t.Run("sigil heredoc", func(t *testing.T) { src := "x = ~s\"\"\"\nline one\nline two\n\"\"\"\ny = MyModule.func()" - assertLineStarts(t, src) - assertTokenAt(t, src, 4, 4, TokModule, "MyModule") + result := 
TokenizeFull([]byte(src)) + assertLineStarts(t, src, result) + assertTokenAt(t, src, result, 4, 4, TokModule, "MyModule") }) t.Run("multiline interpolation", func(t *testing.T) { src := "x = \"hello #{\n some_func()\n}\"\ny = String.trim(x)" - assertLineStarts(t, src) - assertTokenAt(t, src, 3, 4, TokModule, "String") + result := TokenizeFull([]byte(src)) + assertLineStarts(t, src, result) + assertTokenAt(t, src, result, 3, 4, TokModule, "String") + }) + + t.Run("HEEX: comment", func(t *testing.T) { + src := "" + result := TokenizeHeex([]byte(src)) + assertLineStarts(t, src, result) + assertTokenAt(t, src, result, 0, 0, TokComment, "") + }) + + t.Run("HEEX: sigil contents", func(t *testing.T) { + src := "defmodule PageLive do\n def render(assigns) do\n ~H\"\"\"\n
\n \"\"\"\n end\nend" + result := TokenizeFull([]byte(src)) + assertLineStarts(t, src, result) + assertTokenAt(t, src, result, 6, 2, TokEnd, "end") + }) +} + +func TestTokenizeHeex(t *testing.T) { + tests := []struct { + src, want string + }{ + {"<%!-- hello, world! --%>", + `TokComment (0:24) "<%!-- hello, world! --%>" +TokEOF (24:24) +`}, + {"
hello!
", `TokHEEXOpenTag (0:1) +TokHEEXCloseTag (11:13) +TokEOF (17:17) +`}, + {"<.foo>", `TokHEEXOpenTag (0:1) +TokDot (1:2) +TokIdent (2:5) "foo" +TokHEEXCloseTag (6:8) +TokDot (8:9) +TokIdent (9:12) "foo" +TokEOF (13:13) +`}, + {"<.foo />", `TokHEEXOpenTag (0:1) +TokDot (1:2) +TokIdent (2:5) "foo" +TokEOF (8:8) +`}, + {"<.live_component id=\"foo\" module={Foo.Bar} no-value />", `TokHEEXOpenTag (0:1) +TokDot (1:2) +TokIdent (2:16) "live_component" +TokModule (34:37) "Foo" +TokDot (37:38) +TokModule (38:41) "Bar" +TokEOF (54:54) +`}, + {"
", `TokHEEXOpenTag (0:1) +TokString (12:16) "\"{}\"" +TokEOF (20:20) +`}, + } + + for _, tt := range tests { + err := withTimeout(2_000, func() { + result := TokenizeHeex([]byte(tt.src)) + got := DebugTokens([]byte(tt.src), result.Tokens) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("TokenizeHeex(src) (-want +got)\n\n%.512s\n\n%s", tt.src, diff) + } + }) + if err == context.DeadlineExceeded { + t.Errorf("TokenizeHeex(src) timeout after 2s\n\n%.512s", tt.src) + } + } +} + +func FuzzTokenizeHeex(f *testing.F) { + f.Fuzz(func(t *testing.T, src string) { + err := withTimeout(2_000, func() { + result := TokenizeHeex([]byte(src)) + // should always output at least TokEOF + if len(result.Tokens) == 0 { + t.Errorf("TokenizeHeex(src) empty output\n\n%.512s", src) + } + }) + if err == context.DeadlineExceeded { + t.Errorf("TokenizeHeex(src) timeout after 2s\n\n%.512s", src) + } }) } + +func withTimeout(ms time.Duration, cb func()) error { + ctx, cancel := context.WithTimeout(context.Background(), ms*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + cb() + done <- struct{}{} + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/internal/parser/tokenkind_string.go b/internal/parser/tokenkind_string.go new file mode 100644 index 0000000..f1f0ba0 --- /dev/null +++ b/internal/parser/tokenkind_string.go @@ -0,0 +1,81 @@ +// Code generated by "stringer -type=TokenKind"; DO NOT EDIT. + +package parser + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[TokDefmodule-0] + _ = x[TokDef-1] + _ = x[TokDefp-2] + _ = x[TokDefmacro-3] + _ = x[TokDefmacrop-4] + _ = x[TokDefguard-5] + _ = x[TokDefguardp-6] + _ = x[TokDefdelegate-7] + _ = x[TokDefprotocol-8] + _ = x[TokDefimpl-9] + _ = x[TokDefstruct-10] + _ = x[TokDefexception-11] + _ = x[TokAlias-12] + _ = x[TokImport-13] + _ = x[TokUse-14] + _ = x[TokRequire-15] + _ = x[TokDo-16] + _ = x[TokEnd-17] + _ = x[TokFn-18] + _ = x[TokWhen-19] + _ = x[TokIdent-20] + _ = x[TokModule-21] + _ = x[TokAttr-22] + _ = x[TokAttrDoc-23] + _ = x[TokAttrSpec-24] + _ = x[TokAttrType-25] + _ = x[TokAttrBehaviour-26] + _ = x[TokAttrCallback-27] + _ = x[TokString-28] + _ = x[TokHeredoc-29] + _ = x[TokSigil-30] + _ = x[TokCharLiteral-31] + _ = x[TokAtom-32] + _ = x[TokDot-33] + _ = x[TokComma-34] + _ = x[TokColon-35] + _ = x[TokOpenParen-36] + _ = x[TokCloseParen-37] + _ = x[TokOpenBracket-38] + _ = x[TokCloseBracket-39] + _ = x[TokOpenBrace-40] + _ = x[TokCloseBrace-41] + _ = x[TokOpenAngle-42] + _ = x[TokCloseAngle-43] + _ = x[TokPipe-44] + _ = x[TokBackslash-45] + _ = x[TokRightArrow-46] + _ = x[TokLeftArrow-47] + _ = x[TokAssoc-48] + _ = x[TokDoubleColon-49] + _ = x[TokPercent-50] + _ = x[TokHEEXOpenTag-51] + _ = x[TokHEEXCloseTag-52] + _ = x[TokNumber-53] + _ = x[TokComment-54] + _ = x[TokEOL-55] + _ = x[TokEOF-56] + _ = x[TokOther-57] +} + +const _TokenKind_name = "TokDefmoduleTokDefTokDefpTokDefmacroTokDefmacropTokDefguardTokDefguardpTokDefdelegateTokDefprotocolTokDefimplTokDefstructTokDefexceptionTokAliasTokImportTokUseTokRequireTokDoTokEndTokFnTokWhenTokIdentTokModuleTokAttrTokAttrDocTokAttrSpecTokAttrTypeTokAttrBehaviourTokAttrCallbackTokStringTokHeredocTokSigilTokCharLiteralTokAtomTokDotTokCommaTokColonTokOpenParenTokCloseParenTokOpenBracketTokCloseBracketTokOpenBraceTokCloseBraceTokOpenAngleTokCloseAngleTokPipeTokBackslashTokRightArrowTokLeftArrowTokAssocTokDoubleColonTokPercentTokHEEXOpenTagTokHEEXCloseTagTokNumberTokCommentTokEOLTokEOFTokOther" + +var 
_TokenKind_index = [...]uint16{0, 12, 18, 25, 36, 48, 59, 71, 85, 99, 109, 121, 136, 144, 153, 159, 169, 174, 180, 185, 192, 200, 209, 216, 226, 237, 248, 264, 279, 288, 298, 306, 320, 327, 333, 341, 349, 361, 374, 388, 403, 415, 428, 440, 453, 460, 472, 485, 497, 505, 519, 529, 543, 558, 567, 577, 583, 589, 597} + +func (i TokenKind) String() string { + idx := int(i) - 0 + if i < 0 || idx >= len(_TokenKind_index)-1 { + return "TokenKind(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _TokenKind_name[_TokenKind_index[idx]:_TokenKind_index[idx+1]] +} diff --git a/internal/treesitter/tree.go b/internal/treesitter/tree.go new file mode 100644 index 0000000..a6619f5 --- /dev/null +++ b/internal/treesitter/tree.go @@ -0,0 +1,306 @@ +package treesitter + +import ( + "unsafe" + + tree_sitter_heex "github.com/phoenixframework/tree-sitter-heex/bindings/go" + tree_sitter "github.com/tree-sitter/go-tree-sitter" + tree_sitter_elixir "github.com/tree-sitter/tree-sitter-elixir/bindings/go" +) + +// Tree contains a document trunk tree and a map of any branch sub-trees. +// For Elixir trunks, Branches is a map of `quoted_content` node IDs within sigils +// in the document tree to their corresponding HEEX sub-tree. For HEEX trunks, +// Branches is a map of `expression_value` node IDs within interpolated expressions +// in the document tree to their corresponding Elixir sub-tree. Sub-trees may +// be nested arbitrarily deep, though in practice it will typically be 1-3 levels. +// +// For nested sub-trees, Root points back to the parent tree that contains the +// sub-tree. Navigation is possible both up (using Parent()) and down (using Child(i)). +// +// Elixir->HEEX: (sigil (sigil_name) node: (quoted_content)) +// HEEX->Elixir: (expression node: (expression_value)) +type Tree struct { + Root *TreeNode + Trunk *tree_sitter.Tree + Branches map[uintptr]*Tree + Language Language +} + +// TrunkNode returns a TreeNode pointing to the root node of the trunk. 
+func (t *Tree) TrunkNode() *TreeNode { + return &TreeNode{Tree: t, Node: t.Trunk.RootNode()} +} + +// Close recursively closes the trunk tree and any branch sub-trees. +func (t *Tree) Close() { + for _, b := range t.Branches { + b.Close() + } + t.Trunk.Close() +} + +// TreeNode represents a node within a tree or sub-tree. +// This facilitates traversal between trunk trees and branch sub-trees. +type TreeNode struct { + Tree *Tree + Node *tree_sitter.Node +} + +// See tree_sitter.Node.Kind(). +func (tn *TreeNode) Kind() string { + return tn.Node.Kind() +} + +// See tree_sitter.Node.IsNamed(). +func (tn *TreeNode) IsNamed() bool { + return tn.Node.IsNamed() +} + +// See tree_sitter.Node.ToSexp(). +func (tn *TreeNode) ToSexp() string { + return tn.Node.ToSexp() +} + +// See tree_sitter.Node.StartByte(). +func (tn *TreeNode) StartByte() uint { + if tn.Tree.Root == nil { + return tn.Node.StartByte() + } + return tn.Tree.Root.StartByte() + tn.Node.StartByte() +} + +// See tree_sitter.Node.EndByte(). +func (tn *TreeNode) EndByte() uint { + if tn.Tree.Root == nil { + return tn.Node.EndByte() + } + return tn.Tree.Root.StartByte() + tn.Node.EndByte() +} + +// Parent returns the node containing the given node in the tree, or the node +// in the root tree that contains the node if the node is the root of a branch +// sub-tree. If the node is the top-most root, returns nil. +func (tn *TreeNode) Parent() *TreeNode { + if parent := tn.Node.Parent(); parent != nil { + return &TreeNode{Tree: tn.Tree, Node: parent} + } + return tn.Tree.Root +} + +// ChildCount returns the number of children for the given node, returning +// 1 for nodes that link to a branch sub-tree. +func (tn *TreeNode) ChildCount() uint { + if branch := tn.Tree.Branches[tn.Node.Id()]; branch != nil { + return 1 + } + return tn.Node.ChildCount() +} + +// Child returns the tree/child of the given node, moving into a sub-tree if +// the node links to a branch sub-tree. 
+func (tn *TreeNode) Child(i uint) *TreeNode { + if branch := tn.Tree.Branches[tn.Node.Id()]; branch != nil { + return branch.TrunkNode() + } + return &TreeNode{Tree: tn.Tree, Node: tn.Node.Child(i)} +} + +// StartPosition returns the (row, col) start position of the given node +// within the top-most root tree. +func (tn *TreeNode) StartPosition() tree_sitter.Point { + if tn.Tree.Root == nil { + return tn.Node.StartPosition() + } + p := tn.Tree.Root.StartPosition() + sp := tn.Node.StartPosition() + p.Row += sp.Row + if sp.Row == 0 { + p.Column += sp.Column + } else { + p.Column = sp.Column + } + return p +} + +// EndPosition returns the (row, col) end position of the given node +// within the top-most root tree. +func (tn *TreeNode) EndPosition() tree_sitter.Point { + if tn.Tree.Root == nil { + return tn.Node.EndPosition() + } + p := tn.Tree.Root.StartPosition() + ep := tn.Node.EndPosition() + p.Row += ep.Row + if ep.Row == 0 { + p.Column += ep.Column + } else { + p.Column = ep.Column + } + return p +} + +// Utf8Text returns the UTF-8 encoded string representation of the given node +// within the top-most root tree. +func (tn *TreeNode) Utf8Text(src []byte) string { + if tn.Tree.Root == nil { + return tn.Node.Utf8Text(src) + } + return tn.Tree.Root.Utf8Text(src)[tn.Node.StartByte():tn.Node.EndByte()] +} + +// ContainsPosition returns true if the node contains the given position +// in the top-most root tree. Tree-sitter end positions are exclusive, +// consistent with ChildAtPosition. +func (tn *TreeNode) ContainsPosition(line, col uint) bool { + start := tn.StartPosition() + end := tn.EndPosition() + if line < uint(start.Row) || line > uint(end.Row) { + return false + } + if line == uint(start.Row) && col < uint(start.Column) { + return false + } + if line == uint(end.Row) && col >= uint(end.Column) { + return false + } + return true +} + +// ChildAtPosition finds the deepest (most specific) child node at the given position +// within the top-most root tree. 
+func (tn *TreeNode) ChildAtPosition(line, col uint) *TreeNode { + // Check if position is within this node + if tn == nil || !tn.ContainsPosition(line, col) { + return nil + } + + // Try to find a more specific child + for i := uint(0); i < tn.ChildCount(); i++ { + if found := tn.Child(i).ChildAtPosition(line, col); found != nil { + return found + } + } + + return tn +} + +// NewTree creates parsers, parses src, parses nested HEEX templates, and returns the created trees. +// Used by the standalone (non-cached) entry points. Returns nil on failure. +func NewTree(src []byte) *Tree { + parsers := AllParsers() + if parsers == nil { + return nil + } + for _, p := range parsers { + defer p.Close() + } + return NewTreeWithParsers(src, parsers) +} + +// NewTreeWithParsers parses src, parses nested HEEX templates, and returns the created trees. +// Used by cached entry points. Returns nil on failure. +func NewTreeWithParsers(src []byte, parsers map[Language]*tree_sitter.Parser) *Tree { + return newTree(LangElixir, src, parsers) +} + +func newTree(lang Language, src []byte, parsers map[Language]*tree_sitter.Parser) *Tree { + trunk := parsers[lang].Parse(src, nil) + if trunk == nil { + return nil + } + + t := &Tree{ + Language: lang, + Trunk: trunk, + Branches: make(map[uintptr]*Tree), + } + + visitTree(trunk.RootNode(), func(node *tree_sitter.Node) { + // when visiting Elixir trees, parse nested ~H sigils as HEEX sub-trees + if lang == LangElixir && + node.Kind() == "quoted_content" && + node.Parent() != nil && node.Parent().Kind() == "sigil" && + /* sigil_name */ node.PrevNamedSibling() != nil && node.PrevNamedSibling().Utf8Text(src) == "H" { + if tree := newTree(LangHeex, src[node.StartByte():node.EndByte()], parsers); tree != nil { + tree.Root = &TreeNode{Tree: t, Node: node} + t.Branches[node.Id()] = tree + } + } + + // when visiting HEEX trees, parse nested expressions as Elixir sub-trees + if lang == LangHeex && node.Kind() == "expression_value" { + if tree := 
newTree(LangElixir, src[node.StartByte():node.EndByte()], parsers); tree != nil { + tree.Root = &TreeNode{Tree: t, Node: node} + t.Branches[node.Id()] = tree + } + } + }) + + return t +} + +type Language byte + +const ( + LangElixir Language = iota + LangHeex +) + +func NewParser(lang Language) *tree_sitter.Parser { + var language unsafe.Pointer + switch lang { + case LangElixir: + language = tree_sitter_elixir.Language() + case LangHeex: + language = tree_sitter_heex.Language() + } + + p := tree_sitter.NewParser() + if err := p.SetLanguage(tree_sitter.NewLanguage(language)); err != nil { + return nil + } + + return p +} + +func AllParsers() map[Language]*tree_sitter.Parser { + parsers := make(map[Language]*tree_sitter.Parser) + + for _, l := range []Language{LangElixir, LangHeex} { + p := NewParser(l) + if p == nil { + // if a parser fails to initialize, close any already-opened parsers + for _, pp := range parsers { + pp.Close() + } + return nil + } + parsers[l] = p + } + + return parsers +} + +func visitTree(root *tree_sitter.Node, onNode func(node *tree_sitter.Node)) { + cursor := root.Walk() + defer cursor.Close() + + for { + // visit current node + onNode(cursor.Node()) + + // traverse down one level, if possible + if cursor.GotoFirstChild() { + continue + } + + // traverse via siblings, if possible + for !cursor.GotoNextSibling() { + // move back up and recurse, returning once we're back to the root + if !cursor.GotoParent() { + return + } + } + } +} diff --git a/internal/treesitter/tree_test.go b/internal/treesitter/tree_test.go new file mode 100644 index 0000000..b6dbec2 --- /dev/null +++ b/internal/treesitter/tree_test.go @@ -0,0 +1,75 @@ +package treesitter + +import ( + "maps" + "slices" + "testing" +) + +func TestNewTree(t *testing.T) { + src := `def render(assigns) do + ~H""" +
+ <%= bar() %> +
+ """ +end` + tree := NewTree([]byte(src)) + if tree.Language != LangElixir { + t.Errorf("expected Elixir root tree, got %#v", tree.Language) + } + + heexNodeIds := slices.Collect(maps.Keys(tree.Branches)) + if len(heexNodeIds) != 1 { + t.Errorf("expected 1 Heex branch, got %d", len(heexNodeIds)) + } + heexTree := tree.Branches[heexNodeIds[0]] + if heexTree.Language != LangHeex { + t.Errorf("expected Heex branch sub-tree, got %#v", heexTree.Language) + } + if rootId := heexTree.Root.Node.Id(); rootId != heexNodeIds[0] { + t.Errorf("expected Heex root to match branch node ID %d, got %d", heexNodeIds[0], rootId) + } + wantHeex := "
\n <%= bar() %>\n
\n " + if heexText := heexTree.TrunkNode().Utf8Text([]byte(src)); heexText != wantHeex { + t.Errorf("unexpected Heex text (-want, +got)\n- %#v\n+ %#v", wantHeex, heexText) + } + + exNodeIds := slices.Collect(maps.Keys(heexTree.Branches)) + if len(exNodeIds) != 2 { + t.Errorf("expected 2 Elixir branch, got %d", len(exNodeIds)) + } + for _, branch := range heexTree.Branches { + if exText := branch.TrunkNode().Utf8Text([]byte(src)); !slices.Contains([]string{"foo()", "bar()"}, exText) { + t.Errorf("unexpected nested Elixir text, got %#v", exText) + } + } +} + +func TestTreeNode_ByteAndPosition(t *testing.T) { + src := `def render(assigns) do + ~H""" +
+ <%= bar() %> +
+ """ +end` + + tree := NewTree([]byte(src)) + // bar() on line 4 col 8 + node := tree.TrunkNode().ChildAtPosition(3, 8) + text := node.Utf8Text([]byte(src)) + if node.StartByte() != 61 { + t.Errorf("expected %#v to start at byte %d, got %d", text, 61, node.StartByte()) + } + if node.EndByte() != 64 { + t.Errorf("expected %#v to end at byte %d, got %d", text, 64, node.EndByte()) + } + + if sp := node.StartPosition(); sp.Row != 3 || sp.Column != 8 { + t.Errorf("expected %#v to start at position (Row: %d, Col: %d), got (Row: %d, Col: %d)", text, 0, 0, sp.Row, sp.Column) + } + if ep := node.EndPosition(); ep.Row != 3 || ep.Column != 11 { + t.Errorf("expected %#v to end at position (Row: %d, Col: %d), got (Row: %d, Col: %d)", text, 0, 0, ep.Row, ep.Column) + } +} diff --git a/internal/treesitter/variables.go b/internal/treesitter/variables.go index 282ee81..4d563a0 100644 --- a/internal/treesitter/variables.go +++ b/internal/treesitter/variables.go @@ -2,27 +2,8 @@ package treesitter import ( "strings" - - tree_sitter "github.com/tree-sitter/go-tree-sitter" - tree_sitter_elixir "github.com/tree-sitter/tree-sitter-elixir/bindings/go" ) -// parseElixir creates a parser, parses src, and returns the root node plus a -// cleanup function that closes both tree and parser. Used by the standalone -// (non-cached) entry points. Returns (nil, nil) on failure. -func parseElixir(src []byte) (root *tree_sitter.Node, cleanup func()) { - p := tree_sitter.NewParser() - if err := p.SetLanguage(tree_sitter.NewLanguage(tree_sitter_elixir.Language())); err != nil { - p.Close() - return nil, nil - } - tree := p.Parse(src, nil) - return tree.RootNode(), func() { - tree.Close() - p.Close() - } -} - // VariableOccurrence is a position where a variable name appears. type VariableOccurrence struct { Line uint // 0-based @@ -34,18 +15,18 @@ type VariableOccurrence struct { // occurrences of the variable at the given cursor position within the // enclosing function scope. 
Returns nil if the cursor is not on a variable. func FindVariableOccurrences(src []byte, line, col uint) []VariableOccurrence { - root, cleanup := parseElixir(src) - if root == nil { + tree := NewTree(src) + if tree == nil { return nil } - defer cleanup() - return FindVariableOccurrencesWithTree(root, src, line, col) + defer tree.Close() + return tree.FindVariableOccurrences(src, line, col) } // FindVariableOccurrencesWithTree is like FindVariableOccurrences but uses a // pre-parsed tree root, avoiding redundant parsing when a cached tree exists. -func FindVariableOccurrencesWithTree(root *tree_sitter.Node, src []byte, line, col uint) []VariableOccurrence { - resolved := resolveVariableScope(root, src, line, col) +func (t *Tree) FindVariableOccurrences(src []byte, line, col uint) []VariableOccurrence { + resolved := t.resolveVariableScope(src, line, col) if resolved == nil { return nil } @@ -93,8 +74,8 @@ func FindVariableOccurrencesWithTree(root *tree_sitter.Node, src []byte, line, c // // Bare identifiers that are zero-arity function calls (not bound as variables) // are NOT considered collisions — in Elixir, a variable simply shadows them. -func NameExistsInScopeOf(root *tree_sitter.Node, src []byte, line, col uint, newName string) bool { - resolved := resolveVariableScope(root, src, line, col) +func (t *Tree) NameExistsInScopeOf(src []byte, line, col uint, newName string) bool { + resolved := t.resolveVariableScope(src, line, col) if resolved == nil { return false } @@ -113,7 +94,7 @@ func NameExistsInScopeOf(root *tree_sitter.Node, src []byte, line, col uint, new // rather than a bare zero-arity function call. Reuses the full variable // resolution logic so the same scoping rules apply. 
pos := target.StartPosition() - return len(FindVariableOccurrencesWithTree(root, src, uint(pos.Row), uint(pos.Column))) > 0 + return len(t.FindVariableOccurrences(src, uint(pos.Row), uint(pos.Column))) > 0 } // findFirstNonCallIdentifier returns the first identifier node in the subtree @@ -122,11 +103,11 @@ func NameExistsInScopeOf(root *tree_sitter.Node, src []byte, line, col uint, new // descended into — a same-named binding inside one is not a collision in the // scope rooted at node. (The root itself may be such a def call when renaming a // function-local; that is the chosen scope and is always searched.) -func findFirstNonCallIdentifier(node *tree_sitter.Node, src []byte, name string) *tree_sitter.Node { +func findFirstNonCallIdentifier(node *TreeNode, src []byte, name string) *TreeNode { return findFirstNonCallIdentifierInScope(node, src, name, true) } -func findFirstNonCallIdentifierInScope(node *tree_sitter.Node, src []byte, name string, isRoot bool) *tree_sitter.Node { +func findFirstNonCallIdentifierInScope(node *TreeNode, src []byte, name string, isRoot bool) *TreeNode { if node == nil { return nil } @@ -146,8 +127,8 @@ func findFirstNonCallIdentifierInScope(node *tree_sitter.Node, src []byte, name // resolvedScope holds the result of locating a variable's scope. type resolvedScope struct { - cursorNode *tree_sitter.Node - scope *tree_sitter.Node + cursorNode *TreeNode + scope *TreeNode varName string moduleAttribute bool // true when the identifier is a module attribute (@foo) } @@ -155,8 +136,9 @@ type resolvedScope struct { // resolveVariableScope locates the cursor node at (line, col), validates it as // a variable or module attribute, and returns the enclosing scope. Returns nil // if the position is not on a renameable variable. 
-func resolveVariableScope(root *tree_sitter.Node, src []byte, line, col uint) *resolvedScope { - cursorNode := nodeAtPosition(root, line, col) +func (t *Tree) resolveVariableScope(src []byte, line, col uint) *resolvedScope { + cursorNode := t.TrunkNode().ChildAtPosition(line, col) + if cursorNode == nil || cursorNode.Kind() != "identifier" { return nil } @@ -202,7 +184,7 @@ func resolveVariableScope(root *tree_sitter.Node, src []byte, line, col uint) *r } // moduleAttributeExists returns true if @name appears in the subtree. -func moduleAttributeExists(node *tree_sitter.Node, src []byte, name string) bool { +func moduleAttributeExists(node *TreeNode, src []byte, name string) bool { if node == nil { return false } @@ -217,40 +199,10 @@ func moduleAttributeExists(node *tree_sitter.Node, src []byte, name string) bool return false } -// nodeAtPosition finds the deepest (most specific) node at the given position. -func nodeAtPosition(node *tree_sitter.Node, line, col uint) *tree_sitter.Node { - if node == nil { - return nil - } - start := node.StartPosition() - end := node.EndPosition() - - // Check if position is within this node - if line < uint(start.Row) || line > uint(end.Row) { - return nil - } - if line == uint(start.Row) && col < uint(start.Column) { - return nil - } - if line == uint(end.Row) && col >= uint(end.Column) { - return nil - } - - // Try to find a more specific child - for i := uint(0); i < uint(node.ChildCount()); i++ { - child := node.Child(i) - if found := nodeAtPosition(child, line, col); found != nil { - return found - } - } - - return node -} - // isFunctionNameInCall returns true if the identifier is the function name // in a call expression (e.g., `foo` in `foo(args)`) or a function name being // defined (e.g., `foo` in `def foo(args) do`). 
-func isFunctionNameInCall(node *tree_sitter.Node, src []byte) bool { +func isFunctionNameInCall(node *TreeNode, src []byte) bool { parent := node.Parent() if parent == nil { return false @@ -315,7 +267,7 @@ var moduleKeywords = map[string]bool{ // function definition do not leak to (and cannot reference) an enclosing // module/script scope, so traversals rooted at an outer scope must not descend // into these. -func isFunctionDefinitionCall(node *tree_sitter.Node, src []byte) bool { +func isFunctionDefinitionCall(node *TreeNode, src []byte) bool { if node.Kind() != "call" || node.ChildCount() == 0 { return false } @@ -325,7 +277,7 @@ func isFunctionDefinitionCall(node *tree_sitter.Node, src []byte) bool { // isModuleDefinitionCall reports whether node is a defmodule/defprotocol/defimpl // call, which opens a module-body scope. -func isModuleDefinitionCall(node *tree_sitter.Node, src []byte) bool { +func isModuleDefinitionCall(node *TreeNode, src []byte) bool { if node.Kind() != "call" || node.ChildCount() == 0 { return false } @@ -337,13 +289,13 @@ func isModuleDefinitionCall(node *tree_sitter.Node, src []byte) bool { // variable scope — a function or module definition. A traversal rooted at an // outer scope (a module body, or the whole file) must not descend into these, // or a rename/collision check would wrongly reach into an unrelated scope. -func definesNestedScope(node *tree_sitter.Node, src []byte) bool { +func definesNestedScope(node *TreeNode, src []byte) bool { return isFunctionDefinitionCall(node, src) || isModuleDefinitionCall(node, src) } // isAssignmentTarget returns true if node is on the left-hand side of a `=` // binary operator, meaning it is unambiguously a variable binding. 
-func isAssignmentTarget(node *tree_sitter.Node, src []byte) bool { +func isAssignmentTarget(node *TreeNode, src []byte) bool { parent := node.Parent() if parent == nil || parent.Kind() != "binary_operator" || parent.ChildCount() < 3 { return false @@ -360,7 +312,7 @@ func isAssignmentTarget(node *tree_sitter.Node, src []byte) bool { // at a position other than the cursor. A bare identifier that only appears // at the cursor position is ambiguous (could be a zero-arity function call) // and should not be treated as a variable. -func variableDefinedInScope(scope *tree_sitter.Node, src []byte, varName string, cursorLine, cursorCol uint) bool { +func variableDefinedInScope(scope *TreeNode, src []byte, varName string, cursorLine, cursorCol uint) bool { return identifierExistsElsewhere(scope, src, varName, cursorLine, cursorCol, true) } @@ -371,7 +323,7 @@ func variableDefinedInScope(scope *tree_sitter.Node, src []byte, varName string, // the chosen scope itself, which may be such a def call) — otherwise a bare // top-level call sharing a name with a function-local would be misread as a // variable. -func identifierExistsElsewhere(node *tree_sitter.Node, src []byte, name string, line, col uint, isRoot bool) bool { +func identifierExistsElsewhere(node *TreeNode, src []byte, name string, line, col uint, isRoot bool) bool { if node == nil { return false } @@ -400,7 +352,7 @@ func identifierExistsElsewhere(node *tree_sitter.Node, src []byte, name string, // boundary ONLY when the cursor is inside the do_block — not when it's on the // right side of a <- clause, which is evaluated in the outer scope. // Otherwise, the enclosing def/defp/defmacro/test call is the scope. 
-func findEnclosingScope(node *tree_sitter.Node, src []byte, varName string) *tree_sitter.Node { +func findEnclosingScope(node *TreeNode, src []byte, varName string) *TreeNode { prev := node current := node.Parent() for current != nil { @@ -437,7 +389,7 @@ func findEnclosingScope(node *tree_sitter.Node, src []byte, varName string) *tre } // Reached the file root without an inner scope: top-level script // bindings (e.g. config/runtime.exs) are scoped to the whole file. - if current.Kind() == "source" { + if current.Kind() == "source" && current.Parent() == nil { return current } prev = current @@ -447,7 +399,7 @@ func findEnclosingScope(node *tree_sitter.Node, src []byte, varName string) *tre } // nodeIsInsideDoBlock returns true if child is inside the do_block of callNode. -func nodeIsInsideDoBlock(callNode, child *tree_sitter.Node) bool { +func nodeIsInsideDoBlock(callNode, child *TreeNode) bool { for i := uint(0); i < uint(callNode.ChildCount()); i++ { block := callNode.Child(i) if block.Kind() == "do_block" && @@ -463,7 +415,7 @@ func nodeIsInsideDoBlock(callNode, child *tree_sitter.Node) bool { // given with/for call should act as a scope boundary: inside the do_block, // on a lvalue of <-/=, or on the rhs of clause N>0 (which references clause // N-1's binding, not the outer scope). -func cursorNeedsWithScope(callNode, prev, cursor *tree_sitter.Node, src []byte, varName string) bool { +func cursorNeedsWithScope(callNode, prev, cursor *TreeNode, src []byte, varName string) bool { if nodeIsInsideDoBlock(callNode, prev) { return true } @@ -491,7 +443,7 @@ func cursorNeedsWithScope(callNode, prev, cursor *tree_sitter.Node, src []byte, // varName within the given subtree, skipping function names in calls. // skipScopeCheck should be true when node is the scope root itself (so we // don't immediately bail out of the scope we chose). 
-func collectVariableOccurrences(node *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence, skipScopeCheck bool) { +func collectVariableOccurrences(node *TreeNode, src []byte, varName string, out *[]VariableOccurrence, skipScopeCheck bool) { if node == nil { return } @@ -551,7 +503,7 @@ func collectVariableOccurrences(node *tree_sitter.Node, src []byte, varName stri // stabBodyRebindsVariable returns true if the body of the stab_clause contains // an assignment (=) whose left-hand side unpinnedly binds varName. -func stabBodyRebindsVariable(stabClause *tree_sitter.Node, src []byte, varName string) bool { +func stabBodyRebindsVariable(stabClause *TreeNode, src []byte, varName string) bool { for i := uint(0); i < uint(stabClause.ChildCount()); i++ { child := stabClause.Child(i) if child.Kind() == "arguments" { @@ -566,7 +518,7 @@ func stabBodyRebindsVariable(stabClause *tree_sitter.Node, src []byte, varName s // subtreeContainsAssignmentOf returns true if the subtree has a binary "=" // whose lvalue unpinnedly binds varName. -func subtreeContainsAssignmentOf(node *tree_sitter.Node, src []byte, varName string) bool { +func subtreeContainsAssignmentOf(node *TreeNode, src []byte, varName string) bool { if node == nil { return false } @@ -587,7 +539,7 @@ func subtreeContainsAssignmentOf(node *tree_sitter.Node, src []byte, varName str // collectStabArgs collects variable occurrences from the args of a stab_clause // only (not the body). Used when the body rebinds the variable. 
-func collectStabArgs(stabClause *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence) { +func collectStabArgs(stabClause *TreeNode, src []byte, varName string, out *[]VariableOccurrence) { for i := uint(0); i < uint(stabClause.ChildCount()); i++ { child := stabClause.Child(i) if child.Kind() == "arguments" { @@ -601,7 +553,7 @@ func collectStabArgs(stabClause *tree_sitter.Node, src []byte, varName string, o // // @foo → unary_operator("@") → identifier("foo") // @foo value → unary_operator("@") → call → identifier("foo") … -func isModuleAttributeIdent(node *tree_sitter.Node, src []byte) bool { +func isModuleAttributeIdent(node *TreeNode, src []byte) bool { parent := node.Parent() if parent == nil { return false @@ -621,7 +573,7 @@ func isModuleAttributeIdent(node *tree_sitter.Node, src []byte) bool { } // isAtUnaryOp returns true if node is a unary_operator with the @ operator. -func isAtUnaryOp(node *tree_sitter.Node, src []byte) bool { +func isAtUnaryOp(node *TreeNode, src []byte) bool { if node.Kind() != "unary_operator" { return false } @@ -635,7 +587,7 @@ func isAtUnaryOp(node *tree_sitter.Node, src []byte) bool { } // findEnclosingModule walks up from node to find the nearest defmodule call. -func findEnclosingModule(node *tree_sitter.Node, src []byte) *tree_sitter.Node { +func findEnclosingModule(node *TreeNode, src []byte) *TreeNode { current := node.Parent() for current != nil { if current.Kind() == "call" && current.ChildCount() > 0 { @@ -652,7 +604,7 @@ func findEnclosingModule(node *tree_sitter.Node, src []byte) *tree_sitter.Node { // collectModuleAttributeOccurrences collects all @attrName occurrences within // the subtree — that is, identifier nodes named attrName that are part of a // module attribute expression (@attrName or @attrName value). 
-func collectModuleAttributeOccurrences(node *tree_sitter.Node, src []byte, attrName string, out *[]VariableOccurrence) { +func collectModuleAttributeOccurrences(node *TreeNode, src []byte, attrName string, out *[]VariableOccurrence) { if node == nil { return } @@ -673,23 +625,23 @@ func collectModuleAttributeOccurrences(node *tree_sitter.Node, src []byte, attrN // string search, this naturally skips strings, comments, atoms, and other // non-code contexts. func FindTokenOccurrences(src []byte, token string) []VariableOccurrence { - root, cleanup := parseElixir(src) - if root == nil { + tree := NewTree(src) + if tree == nil { return nil } - defer cleanup() - return FindTokenOccurrencesWithTree(root, src, token) + defer tree.Close() + return tree.FindTokenOccurrences(src, token) } // FindTokenOccurrencesWithTree is like FindTokenOccurrences but uses a // pre-parsed tree root. -func FindTokenOccurrencesWithTree(root *tree_sitter.Node, src []byte, token string) []VariableOccurrence { +func (t *Tree) FindTokenOccurrences(src []byte, token string) []VariableOccurrence { var occurrences []VariableOccurrence - collectTokenOccurrences(root, src, token, &occurrences) + collectTokenOccurrences(t.TrunkNode(), src, token, &occurrences) return occurrences } -func collectTokenOccurrences(node *tree_sitter.Node, src []byte, token string, out *[]VariableOccurrence) { +func collectTokenOccurrences(node *TreeNode, src []byte, token string, out *[]VariableOccurrence) { if node == nil { return } @@ -697,7 +649,7 @@ func collectTokenOccurrences(node *tree_sitter.Node, src []byte, token string, o kind := node.Kind() // Skip subtrees that can't contain meaningful identifier references - if kind == "string" || kind == "comment" || kind == "sigil" || kind == "charlist" { + if kind == "string" || kind == "comment" || kind == "charlist" { return } @@ -746,20 +698,20 @@ func collectTokenOccurrences(node *tree_sitter.Node, src []byte, token string, o // function scope. 
Respects clause boundaries: variables from other case/fn // clauses are excluded. Returns nil if the cursor is not inside a function. func FindVariablesInScope(src []byte, line, col uint) []string { - root, cleanup := parseElixir(src) - if root == nil { + tree := NewTree(src) + if tree == nil { return nil } - defer cleanup() - return FindVariablesInScopeWithTree(root, src, line, col) + defer tree.Close() + return tree.FindVariablesInScope(src, line, col) } // FindVariablesInScopeWithTree is like FindVariablesInScope but uses a // pre-parsed tree root. -func FindVariablesInScopeWithTree(root *tree_sitter.Node, src []byte, line, col uint) []string { - cursorNode := nodeAtPosition(root, line, col) +func (t *Tree) FindVariablesInScope(src []byte, line, col uint) []string { + cursorNode := t.TrunkNode().ChildAtPosition(line, col) if cursorNode == nil && col > 0 { - cursorNode = nodeAtPosition(root, line, col-1) + cursorNode = t.TrunkNode().ChildAtPosition(line, col-1) } if cursorNode == nil { return nil @@ -777,7 +729,7 @@ func FindVariablesInScopeWithTree(root *tree_sitter.Node, src []byte, line, col } // findEnclosingFunction walks up from node to find the nearest def/defp/etc scope. -func findEnclosingFunction(node *tree_sitter.Node, src []byte) *tree_sitter.Node { +func findEnclosingFunction(node *TreeNode, src []byte) *TreeNode { current := node.Parent() for current != nil { if current.Kind() == "call" && current.ChildCount() > 0 { @@ -795,12 +747,12 @@ func findEnclosingFunction(node *tree_sitter.Node, src []byte) *tree_sitter.Node // excluding function names, definition keywords, and module attributes. // Skips stab_clauses and do..end calls that don't contain the cursor, // since variables don't leak out of those scopes in Elixir. 
-func collectVariableNames(node *tree_sitter.Node, src []byte, seen map[string]bool, out *[]string, cursorLine, cursorCol uint) { +func collectVariableNames(node *TreeNode, src []byte, seen map[string]bool, out *[]string, cursorLine, cursorCol uint) { if node == nil { return } - if !nodeContainsPosition(node, cursorLine, cursorCol) { + if !node.ContainsPosition(cursorLine, cursorCol) { // Variables in other case/fn clauses are not in scope. if node.Kind() == "stab_clause" { return @@ -831,8 +783,8 @@ func collectVariableNames(node *tree_sitter.Node, src []byte, seen map[string]bo // extractArrowClauses returns the binary_operator nodes for <- and = in the // call's arguments, in source order. -func extractArrowClauses(callNode *tree_sitter.Node, src []byte) []*tree_sitter.Node { - var clauses []*tree_sitter.Node +func extractArrowClauses(callNode *TreeNode, src []byte) []*TreeNode { + var clauses []*TreeNode for i := uint(0); i < uint(callNode.ChildCount()); i++ { child := callNode.Child(i) if child.Kind() != "arguments" { @@ -861,7 +813,7 @@ func extractArrowClauses(callNode *tree_sitter.Node, src []byte) []*tree_sitter. 
// - Cursor on rhs1: uses lhs0's binding — collect lhs0 + rhs1 (+ further rhs until rebind) + body // - Cursor on lhs1: collect lhs1 + body // - Cursor in body: uses last clause's binding — collect last lhs + body -func collectWithOccurrences(callNode, cursor *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence) { +func collectWithOccurrences(callNode, cursor *TreeNode, src []byte, varName string, out *[]VariableOccurrence) { clauses := extractArrowClauses(callNode, src) // Find which clause and side the cursor is on @@ -884,7 +836,7 @@ func collectWithOccurrences(callNode, cursor *tree_sitter.Node, src []byte, varN } // Find the do_block - var doBlock *tree_sitter.Node + var doBlock *TreeNode for i := uint(0); i < uint(callNode.ChildCount()); i++ { child := callNode.Child(i) if child.Kind() == "do_block" { @@ -957,7 +909,7 @@ func collectWithOccurrences(callNode, cursor *tree_sitter.Node, src []byte, varN // of =/← binary operators in a call's arguments, processing clauses // sequentially. Once a clause's pattern (left side) rebinds varName, // subsequent clauses and the do_block use the new binding — so we stop. -func collectPatternExpressionOccurrences(callNode *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence) { +func collectPatternExpressionOccurrences(callNode *TreeNode, src []byte, varName string, out *[]VariableOccurrence) { for i := uint(0); i < uint(callNode.ChildCount()); i++ { child := callNode.Child(i) if child.Kind() != "arguments" { @@ -987,7 +939,7 @@ func collectPatternExpressionOccurrences(callNode *tree_sitter.Node, src []byte, // callArgumentPatternsBindVariable checks whether a call's argument patterns // (left side of = or <- operators) contain an unpinned binding of varName. 
-func callArgumentPatternsBindVariable(node *tree_sitter.Node, src []byte, varName string) bool { +func callArgumentPatternsBindVariable(node *TreeNode, src []byte, varName string) bool { for i := uint(0); i < uint(node.ChildCount()); i++ { child := node.Child(i) if child.Kind() != "arguments" { @@ -1008,7 +960,7 @@ func callArgumentPatternsBindVariable(node *tree_sitter.Node, src []byte, varNam return false } -func callHasDoBlock(node *tree_sitter.Node) bool { +func callHasDoBlock(node *TreeNode) bool { for i := uint(0); i < uint(node.ChildCount()); i++ { if node.Child(i).Kind() == "do_block" { return true @@ -1017,28 +969,11 @@ func callHasDoBlock(node *tree_sitter.Node) bool { return false } -// nodeContainsPosition returns true if the node's range includes the given position. -// Tree-sitter end positions are exclusive, consistent with nodeAtPosition. -func nodeContainsPosition(node *tree_sitter.Node, line, col uint) bool { - start := node.StartPosition() - end := node.EndPosition() - if line < uint(start.Row) || line > uint(end.Row) { - return false - } - if line == uint(start.Row) && col < uint(start.Column) { - return false - } - if line == uint(end.Row) && col >= uint(end.Column) { - return false - } - return true -} - // stabBindsVariable returns true if the stab_clause's arguments (pattern) // contain an unpinned identifier matching varName, meaning it creates a new // binding. Pinned variables (^varName) reference the outer scope and do NOT // create a new binding. 
-func stabBindsVariable(stabClause *tree_sitter.Node, src []byte, varName string) bool { +func stabBindsVariable(stabClause *TreeNode, src []byte, varName string) bool { for i := uint(0); i < uint(stabClause.ChildCount()); i++ { child := stabClause.Child(i) if child.Kind() == "arguments" { @@ -1051,7 +986,7 @@ func stabBindsVariable(stabClause *tree_sitter.Node, src []byte, varName string) // subtreeContainsUnpinnedIdentifier returns true if any identifier node in the // subtree has the given name AND is not pinned (^varName). Pinned variables // reference an outer binding and do not create a new one. -func subtreeContainsUnpinnedIdentifier(node *tree_sitter.Node, src []byte, name string) bool { +func subtreeContainsUnpinnedIdentifier(node *TreeNode, src []byte, name string) bool { if node == nil { return false } @@ -1071,7 +1006,7 @@ func subtreeContainsUnpinnedIdentifier(node *tree_sitter.Node, src []byte, name } // isPinOperator returns true if node is a unary_operator with the ^ operator. 
-func isPinOperator(node *tree_sitter.Node, src []byte) bool { +func isPinOperator(node *TreeNode, src []byte) bool { if node.Kind() != "unary_operator" { return false } diff --git a/internal/treesitter/variables_test.go b/internal/treesitter/variables_test.go index 4400e29..a61173b 100644 --- a/internal/treesitter/variables_test.go +++ b/internal/treesitter/variables_test.go @@ -655,11 +655,11 @@ func TestFindVariableOccurrences_FullWorkerFile(t *testing.T) { defdelegate backoff(job), to: MyApp.Oban.EmailWorker end`) - root, cleanup := parseElixir(src) - if root == nil { + tree := NewTree(src) + if tree == nil { t.Fatal("failed to parse") } - defer cleanup() + defer tree.Close() // Find the actual line for "transfer_amount = Money.new" in this test source lines := strings.Split(string(src), "\n") @@ -675,7 +675,7 @@ end`) } t.Logf("transfer_amount rebind is at line %d: %q", transferLine, lines[transferLine]) - occs := FindVariableOccurrences(src, uint(transferLine), 6) + occs := tree.FindVariableOccurrences(src, uint(transferLine), 6) t.Logf("transfer_amount from line %d col 6: %d occs: %+v", transferLine, len(occs), occs) if occs == nil { t.Fatal("expected variable occurrences for 'transfer_amount', got nil") @@ -1459,12 +1459,15 @@ end apply(config) `) - root, cleanup := parseElixir(src) - defer cleanup() + tree := NewTree(src) + if tree == nil { + t.Fatal("failed to parse") + } + defer tree.Close() // Renaming top-level "config" to "other" is safe: "other" only exists as a // def-local, which is a different scope. 
- if NameExistsInScopeOf(root, src, 0, 0, "other") { + if tree.NameExistsInScopeOf(src, 0, 0, "other") { t.Error("false-positive collision: 'other' is a def-local, not in the top-level scope") } } @@ -1555,5 +1558,6 @@ config :app, value: some_helper() occs := FindVariableOccurrences(src, 2, uint(len("config :app, value: "))) if occs != nil { t.Errorf("expected nil for bare top-level call, got %d occurrences: %+v", len(occs), occs) + } } diff --git a/internal/version/version.go b/internal/version/version.go index 2e2c620..416e634 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -5,4 +5,4 @@ const Version = "0.7.0" // IndexVersion is incremented whenever the index schema or parser changes in a // way that requires a full rebuild. Bump this alongside Version when releasing // a change that makes existing indexes stale. -const IndexVersion = 12 +const IndexVersion = 13