0 {
+ scanTagAttr := func(i, line int, lineStarts *[]int, tokens *[]Token) (int, int) {
+ var quoteChar byte
+ for i < len(source) {
ch := source[i]
- if ch == '\n' {
+
+ switch {
+ case ch == '\n':
+ *tokens = append(*tokens, Token{Kind: TokEOL, Start: i, End: i + 1, Line: line})
line++
i++
*lineStarts = append(*lineStarts, i)
- } else if escapes && ch == '\\' && i+1 < len(source) {
- if source[i+1] == '\n' {
- line++
- *lineStarts = append(*lineStarts, i+2)
+
+ case quoteChar == 0 && isWhitespace(ch):
+ return i, line
+
+ case quoteChar == 0 && (ch == '\'' || ch == '"'):
+ quoteChar = ch
+ i++
+
+ case quoteChar != 0 && ch == '\\':
+ i++
+ if i < len(source) {
+ i++
}
- i += 2
- } else if ch == openCh {
- depth++
+
+ case quoteChar != 0 && ch == quoteChar:
i++
- } else if ch == closeCh {
- depth--
+ return i, line
+
+ case quoteChar != 0:
i++
- } else {
+
+ case ch == '{':
+ i++
+ return scanInterpolation(i, line, "}", tokens)
+
+ case ch == '>':
+ return i, line
+
+ default:
i++
}
}
- } else {
+
+ return i, line
+ }
+
+ scanTag := func(i, line int, lineStarts *[]int, tokens *[]Token) (int, int) {
+ hasName := false
for i < len(source) {
- ch := source[i]
- if ch == '\n' {
+ switch {
+ case source[i] == '\n':
+ *tokens = append(*tokens, Token{Kind: TokEOL, Start: i, End: i + 1, Line: line})
line++
i++
*lineStarts = append(*lineStarts, i)
- } else if escapes && ch == '\\' && i+1 < len(source) {
- if source[i+1] == '\n' {
- line++
- *lineStarts = append(*lineStarts, i+2)
+
+ case isWhitespace(source[i]):
+ i++
+
+ // HTML tag name, function, or module/function
+ case !hasName:
+ i, line = scanTagName(i, line, tokens)
+ hasName = true
+
+ // attribute
+ case isLetter(source[i]):
+ i, line = scanTagAttr(i, line, lineStarts, tokens)
+
+ // HEEX special attribute ":if={}", ":for={}", ":let={}", ":type={}"
+ case source[i] == ':':
+ i++
+ i, line = scanTagAttr(i, line, lineStarts, tokens)
+
+ // self-closing tag
+ case source[i] == '/':
+ i++
+ if i < len(source) && source[i] == '>' {
+ i++
+ return i, line
}
- i += 2
- } else if ch == closeCh {
- i++ // consume closing delimiter
- break
- } else {
+
+ // finish open tag
+ case source[i] == '>':
+ i++
+ return i, line
+
+ default:
i++
}
}
+
+ return i, line
}
- // Consume trailing modifier letters (e.g. the 'i' in ~r/foo/i)
- for i < len(source) && isLetter(source[i]) {
- i++
+ for i < len(source) {
+ ch := source[i]
+
+ switch {
+ case ch == '\n':
+ tokens = append(tokens, Token{Kind: TokEOL, Start: i, End: i + 1, Line: line})
+ line++
+ i++
+ lineStarts = append(lineStarts, i)
+
+ case isWhitespace(ch):
+ i++
+
+ case ch == '<':
+ // consume <
+ i++
+ if i < len(source) {
+ if source[i] == '!' && i+2 < len(source) && source[i+1] == '-' && source[i+2] == '-' {
+ // HTML comment "", i, line, &lineStarts)
+ tokens = append(tokens, Token{Kind: TokComment, Start: start, End: i, Line: startLine})
+ } else if source[i] == '%' {
+ // consume %
+ i++
+ if i+2 < len(source) && source[i] == '!' && source[i+1] == '-' && source[i+2] == '-' {
+ // HEEX comment "<%!--"
+ i += 3
+ start := i - 5
+ startLine := line
+ i, line = scanComment("--%>", i, line, &lineStarts)
+ tokens = append(tokens, Token{Kind: TokComment, Start: start, End: i, Line: startLine})
+ } else if i < len(source) {
+ // consume "=" output indicator from "<%=" special form prefix
+ if source[i] == '=' {
+ i++
+ }
+ // EEX interpolation "<%"
+ // EEX special form "<% for", "<% if", "<% case", "<% cond", "<% else", "<% end", "<% _ ->"
+ i, line = scanInterpolation(i, line, "%>", &tokens)
+ }
+ } else if source[i] == '/' {
+ i++
+ tokens = append(tokens, Token{Kind: TokHEEXCloseTag, Start: i - 2, End: i, Line: line})
+ i, line = scanTagName(i, line, &tokens)
+ if i < len(source) && source[i] == '>' {
+ i++
+ }
+ } else {
+ // HTML tag "
= '0' && ch <= '9'
}
+// isWhitespace returns true for space, tab, and carriage return.
+func isWhitespace(ch byte) bool {
+ return ch == ' ' || ch == '\t' || ch == '\r'
+}
+
// isHexDigit returns true for [0-9a-fA-F].
func isHexDigit(ch byte) bool {
return isDigit(ch) || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F')
@@ -957,3 +1276,25 @@ func bytesEqual(b []byte, s string) bool {
func isKeywordKey(source []byte, i int) bool {
return i < len(source) && source[i] == ':' && (i+1 >= len(source) || source[i+1] != ':')
}
+
+// DebugTokens returns a string represention similar to %+v for a slice of tokens.
+func DebugTokens(source []byte, tokens []Token) string {
+ var s strings.Builder
+
+ for _, t := range tokens {
+ s.WriteString(t.Debug(source))
+ }
+
+ return s.String()
+}
+
+// Debug returns a string representation similar to %+v for a token.
+func (token Token) Debug(source []byte) string {
+ switch token.Kind {
+ case TokDot, TokEOL, TokEOF, TokOpenBrace, TokCloseBrace, TokHEEXOpenTag, TokHEEXCloseTag:
+ return fmt.Sprintf("%s (%d:%d)\n", token.Kind.String(), token.Start, token.End)
+
+ default:
+ return fmt.Sprintf("%s (%d:%d) %#v\n", token.Kind.String(), token.Start, token.End, TokenText(source, token))
+ }
+}
diff --git a/internal/parser/tokenizer_test.go b/internal/parser/tokenizer_test.go
index 7b38365..24515b1 100644
--- a/internal/parser/tokenizer_test.go
+++ b/internal/parser/tokenizer_test.go
@@ -1,9 +1,13 @@
package parser
import (
+ "context"
"fmt"
"strings"
"testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
)
// tokenizeNoEOF runs Tokenize and strips the trailing TokEOF for cleaner assertions.
@@ -2077,9 +2081,8 @@ func TestTokenize_EscapedNewlineLineTracking(t *testing.T) {
}
func TestLineStartsAccuracy(t *testing.T) {
- assertLineStarts := func(t *testing.T, src string) {
+ assertLineStarts := func(t *testing.T, src string, result TokenResult) {
t.Helper()
- result := TokenizeFull([]byte(src))
lineStarts := result.LineStarts
lines := strings.Split(src, "\n")
if len(lineStarts) != len(lines) {
@@ -2100,9 +2103,8 @@ func TestLineStartsAccuracy(t *testing.T) {
}
}
- assertTokenAt := func(t *testing.T, src string, line0, col int, wantKind TokenKind, wantText string) {
+ assertTokenAt := func(t *testing.T, src string, result TokenResult, line0, col int, wantKind TokenKind, wantText string) {
t.Helper()
- result := TokenizeFull([]byte(src))
offset := LineColToOffset(result.LineStarts, line0, col)
idx := TokenAtOffset(result.Tokens, offset)
if idx < 0 {
@@ -2119,25 +2121,129 @@ func TestLineStartsAccuracy(t *testing.T) {
t.Run("heredoc", func(t *testing.T) {
src := "defmodule MyApp.Example do\n @moduledoc \"\"\"\n This is a long\n multiline heredoc\n with several lines\n of documentation.\n \"\"\"\n\n @type t :: %__MODULE__{\n name: String.t(),\n age: Integer.t()\n }\n\n def hello do\n :world\n end\nend"
- assertLineStarts(t, src)
- assertTokenAt(t, src, 9, 16, TokModule, "String")
+ result := TokenizeFull([]byte(src))
+ assertLineStarts(t, src, result)
+ assertTokenAt(t, src, result, 9, 16, TokModule, "String")
})
t.Run("multiline string", func(t *testing.T) {
src := "x = \"line one\nline two\nline three\"\ny = Enum.map(list, fn x -> x end)"
- assertLineStarts(t, src)
- assertTokenAt(t, src, 3, 4, TokModule, "Enum")
+ result := TokenizeFull([]byte(src))
+ assertLineStarts(t, src, result)
+ assertTokenAt(t, src, result, 3, 4, TokModule, "Enum")
})
t.Run("sigil heredoc", func(t *testing.T) {
src := "x = ~s\"\"\"\nline one\nline two\n\"\"\"\ny = MyModule.func()"
- assertLineStarts(t, src)
- assertTokenAt(t, src, 4, 4, TokModule, "MyModule")
+ result := TokenizeFull([]byte(src))
+ assertLineStarts(t, src, result)
+ assertTokenAt(t, src, result, 4, 4, TokModule, "MyModule")
})
t.Run("multiline interpolation", func(t *testing.T) {
src := "x = \"hello #{\n some_func()\n}\"\ny = String.trim(x)"
- assertLineStarts(t, src)
- assertTokenAt(t, src, 3, 4, TokModule, "String")
+ result := TokenizeFull([]byte(src))
+ assertLineStarts(t, src, result)
+ assertTokenAt(t, src, result, 3, 4, TokModule, "String")
+ })
+
+ t.Run("HEEX: comment", func(t *testing.T) {
+ src := ""
+ result := TokenizeHeex([]byte(src))
+ assertLineStarts(t, src, result)
+ assertTokenAt(t, src, result, 0, 0, TokComment, "")
+ })
+
+ t.Run("HEEX: sigil contents", func(t *testing.T) {
+ src := "defmodule PageLive do\n def render(assigns) do\n ~H\"\"\"\n
\n \"\"\"\n end\nend"
+ result := TokenizeFull([]byte(src))
+ assertLineStarts(t, src, result)
+ assertTokenAt(t, src, result, 6, 2, TokEnd, "end")
+ })
+}
+
+func TestTokenizeHeex(t *testing.T) {
+ tests := []struct {
+ src, want string
+ }{
+ {"<%!-- hello, world! --%>",
+ `TokComment (0:24) "<%!-- hello, world! --%>"
+TokEOF (24:24)
+`},
+ {"
hello!
", `TokHEEXOpenTag (0:1)
+TokHEEXCloseTag (11:13)
+TokEOF (17:17)
+`},
+ {"<.foo>", `TokHEEXOpenTag (0:1)
+TokDot (1:2)
+TokIdent (2:5) "foo"
+TokHEEXCloseTag (6:8)
+TokDot (8:9)
+TokIdent (9:12) "foo"
+TokEOF (13:13)
+`},
+ {"<.foo />", `TokHEEXOpenTag (0:1)
+TokDot (1:2)
+TokIdent (2:5) "foo"
+TokEOF (8:8)
+`},
+ {"<.live_component id=\"foo\" module={Foo.Bar} no-value />", `TokHEEXOpenTag (0:1)
+TokDot (1:2)
+TokIdent (2:16) "live_component"
+TokModule (34:37) "Foo"
+TokDot (37:38)
+TokModule (38:41) "Bar"
+TokEOF (54:54)
+`},
+ {"
", `TokHEEXOpenTag (0:1)
+TokString (12:16) "\"{}\""
+TokEOF (20:20)
+`},
+ }
+
+ for _, tt := range tests {
+ err := withTimeout(2_000, func() {
+ result := TokenizeHeex([]byte(tt.src))
+ got := DebugTokens([]byte(tt.src), result.Tokens)
+ if diff := cmp.Diff(tt.want, got); diff != "" {
+ t.Errorf("TokenizeHeex(src) (-want +got)\n\n%.512s\n\n%s", tt.src, diff)
+ }
+ })
+ if err == context.DeadlineExceeded {
+ t.Errorf("TokenizeHeex(src) timeout after 2s\n\n%.512s", tt.src)
+ }
+ }
+}
+
+func FuzzTokenizeHeex(f *testing.F) {
+ f.Fuzz(func(t *testing.T, src string) {
+ err := withTimeout(2_000, func() {
+ result := TokenizeHeex([]byte(src))
+ // should always output at least TokEOF
+ if len(result.Tokens) == 0 {
+ t.Errorf("TokenizeHeex(src) empty output\n\n%.512s", src)
+ }
+ })
+ if err == context.DeadlineExceeded {
+ t.Errorf("TokenizeHeex(src) timeout after 2s\n\n%.512s", src)
+ }
})
}
+
+func withTimeout(ms time.Duration, cb func()) error {
+ ctx, cancel := context.WithTimeout(context.Background(), ms*time.Millisecond)
+ defer cancel()
+
+ done := make(chan struct{})
+ go func() {
+ cb()
+ done <- struct{}{}
+ }()
+
+ select {
+ case <-done:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
diff --git a/internal/parser/tokenkind_string.go b/internal/parser/tokenkind_string.go
new file mode 100644
index 0000000..f1f0ba0
--- /dev/null
+++ b/internal/parser/tokenkind_string.go
@@ -0,0 +1,81 @@
+// Code generated by "stringer -type=TokenKind"; DO NOT EDIT.
+
+package parser
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[TokDefmodule-0]
+ _ = x[TokDef-1]
+ _ = x[TokDefp-2]
+ _ = x[TokDefmacro-3]
+ _ = x[TokDefmacrop-4]
+ _ = x[TokDefguard-5]
+ _ = x[TokDefguardp-6]
+ _ = x[TokDefdelegate-7]
+ _ = x[TokDefprotocol-8]
+ _ = x[TokDefimpl-9]
+ _ = x[TokDefstruct-10]
+ _ = x[TokDefexception-11]
+ _ = x[TokAlias-12]
+ _ = x[TokImport-13]
+ _ = x[TokUse-14]
+ _ = x[TokRequire-15]
+ _ = x[TokDo-16]
+ _ = x[TokEnd-17]
+ _ = x[TokFn-18]
+ _ = x[TokWhen-19]
+ _ = x[TokIdent-20]
+ _ = x[TokModule-21]
+ _ = x[TokAttr-22]
+ _ = x[TokAttrDoc-23]
+ _ = x[TokAttrSpec-24]
+ _ = x[TokAttrType-25]
+ _ = x[TokAttrBehaviour-26]
+ _ = x[TokAttrCallback-27]
+ _ = x[TokString-28]
+ _ = x[TokHeredoc-29]
+ _ = x[TokSigil-30]
+ _ = x[TokCharLiteral-31]
+ _ = x[TokAtom-32]
+ _ = x[TokDot-33]
+ _ = x[TokComma-34]
+ _ = x[TokColon-35]
+ _ = x[TokOpenParen-36]
+ _ = x[TokCloseParen-37]
+ _ = x[TokOpenBracket-38]
+ _ = x[TokCloseBracket-39]
+ _ = x[TokOpenBrace-40]
+ _ = x[TokCloseBrace-41]
+ _ = x[TokOpenAngle-42]
+ _ = x[TokCloseAngle-43]
+ _ = x[TokPipe-44]
+ _ = x[TokBackslash-45]
+ _ = x[TokRightArrow-46]
+ _ = x[TokLeftArrow-47]
+ _ = x[TokAssoc-48]
+ _ = x[TokDoubleColon-49]
+ _ = x[TokPercent-50]
+ _ = x[TokHEEXOpenTag-51]
+ _ = x[TokHEEXCloseTag-52]
+ _ = x[TokNumber-53]
+ _ = x[TokComment-54]
+ _ = x[TokEOL-55]
+ _ = x[TokEOF-56]
+ _ = x[TokOther-57]
+}
+
+const _TokenKind_name = "TokDefmoduleTokDefTokDefpTokDefmacroTokDefmacropTokDefguardTokDefguardpTokDefdelegateTokDefprotocolTokDefimplTokDefstructTokDefexceptionTokAliasTokImportTokUseTokRequireTokDoTokEndTokFnTokWhenTokIdentTokModuleTokAttrTokAttrDocTokAttrSpecTokAttrTypeTokAttrBehaviourTokAttrCallbackTokStringTokHeredocTokSigilTokCharLiteralTokAtomTokDotTokCommaTokColonTokOpenParenTokCloseParenTokOpenBracketTokCloseBracketTokOpenBraceTokCloseBraceTokOpenAngleTokCloseAngleTokPipeTokBackslashTokRightArrowTokLeftArrowTokAssocTokDoubleColonTokPercentTokHEEXOpenTagTokHEEXCloseTagTokNumberTokCommentTokEOLTokEOFTokOther"
+
+var _TokenKind_index = [...]uint16{0, 12, 18, 25, 36, 48, 59, 71, 85, 99, 109, 121, 136, 144, 153, 159, 169, 174, 180, 185, 192, 200, 209, 216, 226, 237, 248, 264, 279, 288, 298, 306, 320, 327, 333, 341, 349, 361, 374, 388, 403, 415, 428, 440, 453, 460, 472, 485, 497, 505, 519, 529, 543, 558, 567, 577, 583, 589, 597}
+
+func (i TokenKind) String() string {
+ idx := int(i) - 0
+ if i < 0 || idx >= len(_TokenKind_index)-1 {
+ return "TokenKind(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _TokenKind_name[_TokenKind_index[idx]:_TokenKind_index[idx+1]]
+}
diff --git a/internal/treesitter/tree.go b/internal/treesitter/tree.go
new file mode 100644
index 0000000..a6619f5
--- /dev/null
+++ b/internal/treesitter/tree.go
@@ -0,0 +1,306 @@
+package treesitter
+
+import (
+ "unsafe"
+
+ tree_sitter_heex "github.com/phoenixframework/tree-sitter-heex/bindings/go"
+ tree_sitter "github.com/tree-sitter/go-tree-sitter"
+ tree_sitter_elixir "github.com/tree-sitter/tree-sitter-elixir/bindings/go"
+)
+
+// Tree contains a document trunk tree and a map of any branch sub-trees.
+// For Elixir trunks, Branches is a map of `quoted_content` node IDs within sigils
+// in the document tree to their corresponding HEEX sub-tree. For HEEX trunks,
+// Branches is a map of `expression_value` node IDs within interpolated expressions
+// in the document tree to their corresponding Elixir sub-tree. Sub-trees may
+// be nested arbitrarily deep, though in practice it will typically be 1-3 levels.
+//
+// For nested sub-trees, Root points back to the parent tree that contains the
+// sub-tree. Navigation is possible both up (using Parent()) and down (using Child(i)).
+//
+// Elixir->HEEX: (sigil (sigil_name) node: (quoted_content))
+// HEEX->Elixir: (expression node: (expression_value))
+type Tree struct {
+ Root *TreeNode
+ Trunk *tree_sitter.Tree
+ Branches map[uintptr]*Tree
+ Language Language
+}
+
+// TrunkNode returns a TreeNode pointing to the root node of the trunk.
+func (t *Tree) TrunkNode() *TreeNode {
+ return &TreeNode{Tree: t, Node: t.Trunk.RootNode()}
+}
+
+// Close recursively closes the trunk tree and any branch sub-trees.
+func (t *Tree) Close() {
+ for _, b := range t.Branches {
+ b.Close()
+ }
+ t.Trunk.Close()
+}
+
+// TreeNode represents a node within a tree or sub-tree.
+// This facilitates traversal between trunk trees and branch sub-trees.
+type TreeNode struct {
+ Tree *Tree
+ Node *tree_sitter.Node
+}
+
+// See tree_sitter.Node.Kind().
+func (tn *TreeNode) Kind() string {
+ return tn.Node.Kind()
+}
+
+// See tree_sitter.Node.IsNamed().
+func (tn *TreeNode) IsNamed() bool {
+ return tn.Node.IsNamed()
+}
+
+// See tree_sitter.Node.ToSexp().
+func (tn *TreeNode) ToSexp() string {
+ return tn.Node.ToSexp()
+}
+
+// See tree_sitter.Node.StartByte().
+func (tn *TreeNode) StartByte() uint {
+ if tn.Tree.Root == nil {
+ return tn.Node.StartByte()
+ }
+ return tn.Tree.Root.StartByte() + tn.Node.StartByte()
+}
+
+// See tree_sitter.Node.EndByte().
+func (tn *TreeNode) EndByte() uint {
+ if tn.Tree.Root == nil {
+ return tn.Node.EndByte()
+ }
+ return tn.Tree.Root.StartByte() + tn.Node.EndByte()
+}
+
+// Parent returns the node containing the given node in the tree, or the node
+// in the root tree that contains the node if the node is the root of a branch
+// sub-tree. If the node is the top-most root, returns nil.
+func (tn *TreeNode) Parent() *TreeNode {
+ if parent := tn.Node.Parent(); parent != nil {
+ return &TreeNode{Tree: tn.Tree, Node: parent}
+ }
+ return tn.Tree.Root
+}
+
+// ChildCount returns the number of children for the given node, returning
+// 1 for nodes that link to a branch sub-tree.
+func (tn *TreeNode) ChildCount() uint {
+ if branch := tn.Tree.Branches[tn.Node.Id()]; branch != nil {
+ return 1
+ }
+ return tn.Node.ChildCount()
+}
+
+// Child returns the tree/child of the given node, moving into a sub-tree if
+// the node links to a branch sub-tree.
+func (tn *TreeNode) Child(i uint) *TreeNode {
+ if branch := tn.Tree.Branches[tn.Node.Id()]; branch != nil {
+ return branch.TrunkNode()
+ }
+ return &TreeNode{Tree: tn.Tree, Node: tn.Node.Child(i)}
+}
+
+// StartPosition returns the (row, col) start position of the given node
+// within the top-most root tree.
+func (tn *TreeNode) StartPosition() tree_sitter.Point {
+ if tn.Tree.Root == nil {
+ return tn.Node.StartPosition()
+ }
+ p := tn.Tree.Root.StartPosition()
+ sp := tn.Node.StartPosition()
+ p.Row += sp.Row
+ if sp.Row == 0 {
+ p.Column += sp.Column
+ } else {
+ p.Column = sp.Column
+ }
+ return p
+}
+
+// EndPosition returns the (row, col) end position of the given node
+// within the top-most root tree.
+func (tn *TreeNode) EndPosition() tree_sitter.Point {
+ if tn.Tree.Root == nil {
+ return tn.Node.EndPosition()
+ }
+ p := tn.Tree.Root.StartPosition()
+ ep := tn.Node.EndPosition()
+ p.Row += ep.Row
+ if ep.Row == 0 {
+ p.Column += ep.Column
+ } else {
+ p.Column = ep.Column
+ }
+ return p
+}
+
+// Utf8Text returns the UTF-8 encoded string representation of the given node
+// within the top-most root tree.
+func (tn *TreeNode) Utf8Text(src []byte) string {
+ if tn.Tree.Root == nil {
+ return tn.Node.Utf8Text(src)
+ }
+ return tn.Tree.Root.Utf8Text(src)[tn.Node.StartByte():tn.Node.EndByte()]
+}
+
+// ContainsPosition returns true if the node contains the given position
+// in the top-most root tree. Tree-sitter end positions are exclusive,
+// consistent with nodeAtPosition.
+func (tn *TreeNode) ContainsPosition(line, col uint) bool {
+ start := tn.StartPosition()
+ end := tn.EndPosition()
+ if line < uint(start.Row) || line > uint(end.Row) {
+ return false
+ }
+ if line == uint(start.Row) && col < uint(start.Column) {
+ return false
+ }
+ if line == uint(end.Row) && col >= uint(end.Column) {
+ return false
+ }
+ return true
+}
+
+// ChildAtPosition find the deepest (most specific) child node at the given position
+// within the top-most root tree.
+func (tn *TreeNode) ChildAtPosition(line, col uint) *TreeNode {
+ // Check if position is within this node
+ if tn == nil || !tn.ContainsPosition(line, col) {
+ return nil
+ }
+
+ // Try to find a more specific child
+ for i := uint(0); i < tn.ChildCount(); i++ {
+ if found := tn.Child(i).ChildAtPosition(line, col); found != nil {
+ return found
+ }
+ }
+
+ return tn
+}
+
+// NewTree creates parsers, parses src, parses nested HEEX templates, and returns the created trees.
+// Used by the standalone (non-cached) entry points. Returns nil on failure.
+func NewTree(src []byte) *Tree {
+ parsers := AllParsers()
+ if parsers == nil {
+ return nil
+ }
+ for _, p := range parsers {
+ defer p.Close()
+ }
+ return NewTreeWithParsers(src, parsers)
+}
+
+// NewTreeWithParsers parses src, parses nested HEEX templates, and returns the created trees.
+// Used by cached entry points . Returns nil on failure.
+func NewTreeWithParsers(src []byte, parsers map[Language]*tree_sitter.Parser) *Tree {
+ return newTree(LangElixir, src, parsers)
+}
+
+func newTree(lang Language, src []byte, parsers map[Language]*tree_sitter.Parser) *Tree {
+ trunk := parsers[lang].Parse(src, nil)
+ if trunk == nil {
+ return nil
+ }
+
+ t := &Tree{
+ Language: lang,
+ Trunk: trunk,
+ Branches: make(map[uintptr]*Tree),
+ }
+
+ visitTree(trunk.RootNode(), func(node *tree_sitter.Node) {
+ // when visiting Elixir trees, parse nested ~H sigils as HEEX sub-trees
+ if lang == LangElixir &&
+ node.Kind() == "quoted_content" &&
+ node.Parent() != nil && node.Parent().Kind() == "sigil" &&
+ /* sigil_name */ node.PrevNamedSibling() != nil && node.PrevNamedSibling().Utf8Text(src) == "H" {
+ if tree := newTree(LangHeex, src[node.StartByte():node.EndByte()], parsers); tree != nil {
+ tree.Root = &TreeNode{Tree: t, Node: node}
+ t.Branches[node.Id()] = tree
+ }
+ }
+
+ // when visiting HEEX trees, parse nested expressions as Elixir sub-trees
+ if lang == LangHeex && node.Kind() == "expression_value" {
+ if tree := newTree(LangElixir, src[node.StartByte():node.EndByte()], parsers); tree != nil {
+ tree.Root = &TreeNode{Tree: t, Node: node}
+ t.Branches[node.Id()] = tree
+ }
+ }
+ })
+
+ return t
+}
+
+type Language byte
+
+const (
+ LangElixir Language = iota
+ LangHeex
+)
+
+func NewParser(lang Language) *tree_sitter.Parser {
+ var language unsafe.Pointer
+ switch lang {
+ case LangElixir:
+ language = tree_sitter_elixir.Language()
+ case LangHeex:
+ language = tree_sitter_heex.Language()
+ }
+
+ p := tree_sitter.NewParser()
+ if err := p.SetLanguage(tree_sitter.NewLanguage(language)); err != nil {
+ return nil
+ }
+
+ return p
+}
+
+func AllParsers() map[Language]*tree_sitter.Parser {
+ parsers := make(map[Language]*tree_sitter.Parser)
+
+ for _, l := range []Language{LangElixir, LangHeex} {
+ p := NewParser(l)
+ if p == nil {
+ // if a parser fails to initialize, close any already-opened parsers
+ for _, pp := range parsers {
+ pp.Close()
+ }
+ return nil
+ }
+ parsers[l] = p
+ }
+
+ return parsers
+}
+
+func visitTree(root *tree_sitter.Node, onNode func(node *tree_sitter.Node)) {
+ cursor := root.Walk()
+ defer cursor.Close()
+
+ for {
+ // visit current node
+ onNode(cursor.Node())
+
+ // traverse down one level, if possible
+ if cursor.GotoFirstChild() {
+ continue
+ }
+
+ // traverse via siblings, if possible
+ for !cursor.GotoNextSibling() {
+ // move back up and recurse, returning once we're back to the root
+ if !cursor.GotoParent() {
+ return
+ }
+ }
+ }
+}
diff --git a/internal/treesitter/tree_test.go b/internal/treesitter/tree_test.go
new file mode 100644
index 0000000..b6dbec2
--- /dev/null
+++ b/internal/treesitter/tree_test.go
@@ -0,0 +1,75 @@
+package treesitter
+
+import (
+ "maps"
+ "slices"
+ "testing"
+)
+
+func TestNewTree(t *testing.T) {
+ src := `def render(assigns) do
+ ~H"""
+
+ <%= bar() %>
+
+ """
+end`
+ tree := NewTree([]byte(src))
+ if tree.Language != LangElixir {
+ t.Errorf("expected Elixir root tree, got %#v", tree.Language)
+ }
+
+ heexNodeIds := slices.Collect(maps.Keys(tree.Branches))
+ if len(heexNodeIds) != 1 {
+ t.Errorf("expected 1 Heex branch, got %d", len(heexNodeIds))
+ }
+ heexTree := tree.Branches[heexNodeIds[0]]
+ if heexTree.Language != LangHeex {
+ t.Errorf("expected Heex branch sub-tree, got %#v", heexTree.Language)
+ }
+ if rootId := heexTree.Root.Node.Id(); rootId != heexNodeIds[0] {
+ t.Errorf("expected Heex root to match branch node ID %d, got %d", heexNodeIds[0], rootId)
+ }
+ wantHeex := "
\n <%= bar() %>\n
\n "
+ if heexText := heexTree.TrunkNode().Utf8Text([]byte(src)); heexText != wantHeex {
+ t.Errorf("unexpected Heex text (-want, +got)\n- %#v\n+ %#v", wantHeex, heexText)
+ }
+
+ exNodeIds := slices.Collect(maps.Keys(heexTree.Branches))
+ if len(exNodeIds) != 2 {
+ t.Errorf("expected 2 Elixir branch, got %d", len(exNodeIds))
+ }
+ for _, branch := range heexTree.Branches {
+ if exText := branch.TrunkNode().Utf8Text([]byte(src)); !slices.Contains([]string{"foo()", "bar()"}, exText) {
+ t.Errorf("unexpected nested Elixir text, got %#v", exText)
+ }
+ }
+}
+
+func TestTreeNode_ByteAndPosition(t *testing.T) {
+ src := `def render(assigns) do
+ ~H"""
+
+ <%= bar() %>
+
+ """
+end`
+
+ tree := NewTree([]byte(src))
+ // bar() on line 4 col 8
+ node := tree.TrunkNode().ChildAtPosition(3, 8)
+ text := node.Utf8Text([]byte(src))
+ if node.StartByte() != 61 {
+ t.Errorf("expected %#v to start at byte %d, got %d", text, 61, node.StartByte())
+ }
+ if node.EndByte() != 64 {
+ t.Errorf("expected %#v to end at byte %d, got %d", text, 64, node.EndByte())
+ }
+
+ if sp := node.StartPosition(); sp.Row != 3 || sp.Column != 8 {
+ t.Errorf("expected %#v to start at position (Row: %d, Col: %d), got (Row: %d, Col: %d)", text, 0, 0, sp.Row, sp.Column)
+ }
+ if ep := node.EndPosition(); ep.Row != 3 || ep.Column != 11 {
+ t.Errorf("expected %#v to end at position (Row: %d, Col: %d), got (Row: %d, Col: %d)", text, 0, 0, ep.Row, ep.Column)
+ }
+}
diff --git a/internal/treesitter/variables.go b/internal/treesitter/variables.go
index 282ee81..4d563a0 100644
--- a/internal/treesitter/variables.go
+++ b/internal/treesitter/variables.go
@@ -2,27 +2,8 @@ package treesitter
import (
"strings"
-
- tree_sitter "github.com/tree-sitter/go-tree-sitter"
- tree_sitter_elixir "github.com/tree-sitter/tree-sitter-elixir/bindings/go"
)
-// parseElixir creates a parser, parses src, and returns the root node plus a
-// cleanup function that closes both tree and parser. Used by the standalone
-// (non-cached) entry points. Returns (nil, nil) on failure.
-func parseElixir(src []byte) (root *tree_sitter.Node, cleanup func()) {
- p := tree_sitter.NewParser()
- if err := p.SetLanguage(tree_sitter.NewLanguage(tree_sitter_elixir.Language())); err != nil {
- p.Close()
- return nil, nil
- }
- tree := p.Parse(src, nil)
- return tree.RootNode(), func() {
- tree.Close()
- p.Close()
- }
-}
-
// VariableOccurrence is a position where a variable name appears.
type VariableOccurrence struct {
Line uint // 0-based
@@ -34,18 +15,18 @@ type VariableOccurrence struct {
// occurrences of the variable at the given cursor position within the
// enclosing function scope. Returns nil if the cursor is not on a variable.
func FindVariableOccurrences(src []byte, line, col uint) []VariableOccurrence {
- root, cleanup := parseElixir(src)
- if root == nil {
+ tree := NewTree(src)
+ if tree == nil {
return nil
}
- defer cleanup()
- return FindVariableOccurrencesWithTree(root, src, line, col)
+ defer tree.Close()
+ return tree.FindVariableOccurrences(src, line, col)
}
// FindVariableOccurrencesWithTree is like FindVariableOccurrences but uses a
// pre-parsed tree root, avoiding redundant parsing when a cached tree exists.
-func FindVariableOccurrencesWithTree(root *tree_sitter.Node, src []byte, line, col uint) []VariableOccurrence {
- resolved := resolveVariableScope(root, src, line, col)
+func (t *Tree) FindVariableOccurrences(src []byte, line, col uint) []VariableOccurrence {
+ resolved := t.resolveVariableScope(src, line, col)
if resolved == nil {
return nil
}
@@ -93,8 +74,8 @@ func FindVariableOccurrencesWithTree(root *tree_sitter.Node, src []byte, line, c
//
// Bare identifiers that are zero-arity function calls (not bound as variables)
// are NOT considered collisions — in Elixir, a variable simply shadows them.
-func NameExistsInScopeOf(root *tree_sitter.Node, src []byte, line, col uint, newName string) bool {
- resolved := resolveVariableScope(root, src, line, col)
+func (t *Tree) NameExistsInScopeOf(src []byte, line, col uint, newName string) bool {
+ resolved := t.resolveVariableScope(src, line, col)
if resolved == nil {
return false
}
@@ -113,7 +94,7 @@ func NameExistsInScopeOf(root *tree_sitter.Node, src []byte, line, col uint, new
// rather than a bare zero-arity function call. Reuses the full variable
// resolution logic so the same scoping rules apply.
pos := target.StartPosition()
- return len(FindVariableOccurrencesWithTree(root, src, uint(pos.Row), uint(pos.Column))) > 0
+ return len(t.FindVariableOccurrences(src, uint(pos.Row), uint(pos.Column))) > 0
}
// findFirstNonCallIdentifier returns the first identifier node in the subtree
@@ -122,11 +103,11 @@ func NameExistsInScopeOf(root *tree_sitter.Node, src []byte, line, col uint, new
// descended into — a same-named binding inside one is not a collision in the
// scope rooted at node. (The root itself may be such a def call when renaming a
// function-local; that is the chosen scope and is always searched.)
-func findFirstNonCallIdentifier(node *tree_sitter.Node, src []byte, name string) *tree_sitter.Node {
+func findFirstNonCallIdentifier(node *TreeNode, src []byte, name string) *TreeNode {
return findFirstNonCallIdentifierInScope(node, src, name, true)
}
-func findFirstNonCallIdentifierInScope(node *tree_sitter.Node, src []byte, name string, isRoot bool) *tree_sitter.Node {
+func findFirstNonCallIdentifierInScope(node *TreeNode, src []byte, name string, isRoot bool) *TreeNode {
if node == nil {
return nil
}
@@ -146,8 +127,8 @@ func findFirstNonCallIdentifierInScope(node *tree_sitter.Node, src []byte, name
// resolvedScope holds the result of locating a variable's scope.
type resolvedScope struct {
- cursorNode *tree_sitter.Node
- scope *tree_sitter.Node
+ cursorNode *TreeNode
+ scope *TreeNode
varName string
moduleAttribute bool // true when the identifier is a module attribute (@foo)
}
@@ -155,8 +136,9 @@ type resolvedScope struct {
// resolveVariableScope locates the cursor node at (line, col), validates it as
// a variable or module attribute, and returns the enclosing scope. Returns nil
// if the position is not on a renameable variable.
-func resolveVariableScope(root *tree_sitter.Node, src []byte, line, col uint) *resolvedScope {
- cursorNode := nodeAtPosition(root, line, col)
+func (t *Tree) resolveVariableScope(src []byte, line, col uint) *resolvedScope {
+ cursorNode := t.TrunkNode().ChildAtPosition(line, col)
+
if cursorNode == nil || cursorNode.Kind() != "identifier" {
return nil
}
@@ -202,7 +184,7 @@ func resolveVariableScope(root *tree_sitter.Node, src []byte, line, col uint) *r
}
// moduleAttributeExists returns true if @name appears in the subtree.
-func moduleAttributeExists(node *tree_sitter.Node, src []byte, name string) bool {
+func moduleAttributeExists(node *TreeNode, src []byte, name string) bool {
if node == nil {
return false
}
@@ -217,40 +199,10 @@ func moduleAttributeExists(node *tree_sitter.Node, src []byte, name string) bool
return false
}
-// nodeAtPosition finds the deepest (most specific) node at the given position.
-func nodeAtPosition(node *tree_sitter.Node, line, col uint) *tree_sitter.Node {
- if node == nil {
- return nil
- }
- start := node.StartPosition()
- end := node.EndPosition()
-
- // Check if position is within this node
- if line < uint(start.Row) || line > uint(end.Row) {
- return nil
- }
- if line == uint(start.Row) && col < uint(start.Column) {
- return nil
- }
- if line == uint(end.Row) && col >= uint(end.Column) {
- return nil
- }
-
- // Try to find a more specific child
- for i := uint(0); i < uint(node.ChildCount()); i++ {
- child := node.Child(i)
- if found := nodeAtPosition(child, line, col); found != nil {
- return found
- }
- }
-
- return node
-}
-
// isFunctionNameInCall returns true if the identifier is the function name
// in a call expression (e.g., `foo` in `foo(args)`) or a function name being
// defined (e.g., `foo` in `def foo(args) do`).
-func isFunctionNameInCall(node *tree_sitter.Node, src []byte) bool {
+func isFunctionNameInCall(node *TreeNode, src []byte) bool {
parent := node.Parent()
if parent == nil {
return false
@@ -315,7 +267,7 @@ var moduleKeywords = map[string]bool{
// function definition do not leak to (and cannot reference) an enclosing
// module/script scope, so traversals rooted at an outer scope must not descend
// into these.
-func isFunctionDefinitionCall(node *tree_sitter.Node, src []byte) bool {
+func isFunctionDefinitionCall(node *TreeNode, src []byte) bool {
if node.Kind() != "call" || node.ChildCount() == 0 {
return false
}
@@ -325,7 +277,7 @@ func isFunctionDefinitionCall(node *tree_sitter.Node, src []byte) bool {
// isModuleDefinitionCall reports whether node is a defmodule/defprotocol/defimpl
// call, which opens a module-body scope.
-func isModuleDefinitionCall(node *tree_sitter.Node, src []byte) bool {
+func isModuleDefinitionCall(node *TreeNode, src []byte) bool {
if node.Kind() != "call" || node.ChildCount() == 0 {
return false
}
@@ -337,13 +289,13 @@ func isModuleDefinitionCall(node *tree_sitter.Node, src []byte) bool {
// variable scope — a function or module definition. A traversal rooted at an
// outer scope (a module body, or the whole file) must not descend into these,
// or a rename/collision check would wrongly reach into an unrelated scope.
-func definesNestedScope(node *tree_sitter.Node, src []byte) bool {
+func definesNestedScope(node *TreeNode, src []byte) bool {
return isFunctionDefinitionCall(node, src) || isModuleDefinitionCall(node, src)
}
// isAssignmentTarget returns true if node is on the left-hand side of a `=`
// binary operator, meaning it is unambiguously a variable binding.
-func isAssignmentTarget(node *tree_sitter.Node, src []byte) bool {
+func isAssignmentTarget(node *TreeNode, src []byte) bool {
parent := node.Parent()
if parent == nil || parent.Kind() != "binary_operator" || parent.ChildCount() < 3 {
return false
@@ -360,7 +312,7 @@ func isAssignmentTarget(node *tree_sitter.Node, src []byte) bool {
// at a position other than the cursor. A bare identifier that only appears
// at the cursor position is ambiguous (could be a zero-arity function call)
// and should not be treated as a variable.
-func variableDefinedInScope(scope *tree_sitter.Node, src []byte, varName string, cursorLine, cursorCol uint) bool {
+func variableDefinedInScope(scope *TreeNode, src []byte, varName string, cursorLine, cursorCol uint) bool {
return identifierExistsElsewhere(scope, src, varName, cursorLine, cursorCol, true)
}
@@ -371,7 +323,7 @@ func variableDefinedInScope(scope *tree_sitter.Node, src []byte, varName string,
// the chosen scope itself, which may be such a def call) — otherwise a bare
// top-level call sharing a name with a function-local would be misread as a
// variable.
-func identifierExistsElsewhere(node *tree_sitter.Node, src []byte, name string, line, col uint, isRoot bool) bool {
+func identifierExistsElsewhere(node *TreeNode, src []byte, name string, line, col uint, isRoot bool) bool {
if node == nil {
return false
}
@@ -400,7 +352,7 @@ func identifierExistsElsewhere(node *tree_sitter.Node, src []byte, name string,
// boundary ONLY when the cursor is inside the do_block — not when it's on the
// right side of a <- clause, which is evaluated in the outer scope.
// Otherwise, the enclosing def/defp/defmacro/test call is the scope.
-func findEnclosingScope(node *tree_sitter.Node, src []byte, varName string) *tree_sitter.Node {
+func findEnclosingScope(node *TreeNode, src []byte, varName string) *TreeNode {
prev := node
current := node.Parent()
for current != nil {
@@ -437,7 +389,7 @@ func findEnclosingScope(node *tree_sitter.Node, src []byte, varName string) *tre
}
// Reached the file root without an inner scope: top-level script
// bindings (e.g. config/runtime.exs) are scoped to the whole file.
- if current.Kind() == "source" {
+ if current.Kind() == "source" && current.Parent() == nil {
return current
}
prev = current
@@ -447,7 +399,7 @@ func findEnclosingScope(node *tree_sitter.Node, src []byte, varName string) *tre
}
// nodeIsInsideDoBlock returns true if child is inside the do_block of callNode.
-func nodeIsInsideDoBlock(callNode, child *tree_sitter.Node) bool {
+func nodeIsInsideDoBlock(callNode, child *TreeNode) bool {
for i := uint(0); i < uint(callNode.ChildCount()); i++ {
block := callNode.Child(i)
if block.Kind() == "do_block" &&
@@ -463,7 +415,7 @@ func nodeIsInsideDoBlock(callNode, child *tree_sitter.Node) bool {
// given with/for call should act as a scope boundary: inside the do_block,
// on a lvalue of <-/=, or on the rhs of clause N>0 (which references clause
// N-1's binding, not the outer scope).
-func cursorNeedsWithScope(callNode, prev, cursor *tree_sitter.Node, src []byte, varName string) bool {
+func cursorNeedsWithScope(callNode, prev, cursor *TreeNode, src []byte, varName string) bool {
if nodeIsInsideDoBlock(callNode, prev) {
return true
}
@@ -491,7 +443,7 @@ func cursorNeedsWithScope(callNode, prev, cursor *tree_sitter.Node, src []byte,
// varName within the given subtree, skipping function names in calls.
// skipScopeCheck should be true when node is the scope root itself (so we
// don't immediately bail out of the scope we chose).
-func collectVariableOccurrences(node *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence, skipScopeCheck bool) {
+func collectVariableOccurrences(node *TreeNode, src []byte, varName string, out *[]VariableOccurrence, skipScopeCheck bool) {
if node == nil {
return
}
@@ -551,7 +503,7 @@ func collectVariableOccurrences(node *tree_sitter.Node, src []byte, varName stri
// stabBodyRebindsVariable returns true if the body of the stab_clause contains
// an assignment (=) whose left-hand side unpinnedly binds varName.
-func stabBodyRebindsVariable(stabClause *tree_sitter.Node, src []byte, varName string) bool {
+func stabBodyRebindsVariable(stabClause *TreeNode, src []byte, varName string) bool {
for i := uint(0); i < uint(stabClause.ChildCount()); i++ {
child := stabClause.Child(i)
if child.Kind() == "arguments" {
@@ -566,7 +518,7 @@ func stabBodyRebindsVariable(stabClause *tree_sitter.Node, src []byte, varName s
// subtreeContainsAssignmentOf returns true if the subtree has a binary "="
// whose lvalue unpinnedly binds varName.
-func subtreeContainsAssignmentOf(node *tree_sitter.Node, src []byte, varName string) bool {
+func subtreeContainsAssignmentOf(node *TreeNode, src []byte, varName string) bool {
if node == nil {
return false
}
@@ -587,7 +539,7 @@ func subtreeContainsAssignmentOf(node *tree_sitter.Node, src []byte, varName str
// collectStabArgs collects variable occurrences from the args of a stab_clause
// only (not the body). Used when the body rebinds the variable.
-func collectStabArgs(stabClause *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence) {
+func collectStabArgs(stabClause *TreeNode, src []byte, varName string, out *[]VariableOccurrence) {
for i := uint(0); i < uint(stabClause.ChildCount()); i++ {
child := stabClause.Child(i)
if child.Kind() == "arguments" {
@@ -601,7 +553,7 @@ func collectStabArgs(stabClause *tree_sitter.Node, src []byte, varName string, o
//
// @foo → unary_operator("@") → identifier("foo")
// @foo value → unary_operator("@") → call → identifier("foo") …
-func isModuleAttributeIdent(node *tree_sitter.Node, src []byte) bool {
+func isModuleAttributeIdent(node *TreeNode, src []byte) bool {
parent := node.Parent()
if parent == nil {
return false
@@ -621,7 +573,7 @@ func isModuleAttributeIdent(node *tree_sitter.Node, src []byte) bool {
}
// isAtUnaryOp returns true if node is a unary_operator with the @ operator.
-func isAtUnaryOp(node *tree_sitter.Node, src []byte) bool {
+func isAtUnaryOp(node *TreeNode, src []byte) bool {
if node.Kind() != "unary_operator" {
return false
}
@@ -635,7 +587,7 @@ func isAtUnaryOp(node *tree_sitter.Node, src []byte) bool {
}
// findEnclosingModule walks up from node to find the nearest defmodule call.
-func findEnclosingModule(node *tree_sitter.Node, src []byte) *tree_sitter.Node {
+func findEnclosingModule(node *TreeNode, src []byte) *TreeNode {
current := node.Parent()
for current != nil {
if current.Kind() == "call" && current.ChildCount() > 0 {
@@ -652,7 +604,7 @@ func findEnclosingModule(node *tree_sitter.Node, src []byte) *tree_sitter.Node {
// collectModuleAttributeOccurrences collects all @attrName occurrences within
// the subtree — that is, identifier nodes named attrName that are part of a
// module attribute expression (@attrName or @attrName value).
-func collectModuleAttributeOccurrences(node *tree_sitter.Node, src []byte, attrName string, out *[]VariableOccurrence) {
+func collectModuleAttributeOccurrences(node *TreeNode, src []byte, attrName string, out *[]VariableOccurrence) {
if node == nil {
return
}
@@ -673,23 +625,23 @@ func collectModuleAttributeOccurrences(node *tree_sitter.Node, src []byte, attrN
// string search, this naturally skips strings, comments, atoms, and other
// non-code contexts.
func FindTokenOccurrences(src []byte, token string) []VariableOccurrence {
- root, cleanup := parseElixir(src)
- if root == nil {
+ tree := NewTree(src)
+ if tree == nil {
return nil
}
- defer cleanup()
- return FindTokenOccurrencesWithTree(root, src, token)
+ defer tree.Close()
+ return tree.FindTokenOccurrences(src, token)
}
// FindTokenOccurrencesWithTree is like FindTokenOccurrences but uses a
// pre-parsed tree root.
-func FindTokenOccurrencesWithTree(root *tree_sitter.Node, src []byte, token string) []VariableOccurrence {
+func (t *Tree) FindTokenOccurrences(src []byte, token string) []VariableOccurrence {
var occurrences []VariableOccurrence
- collectTokenOccurrences(root, src, token, &occurrences)
+ collectTokenOccurrences(t.TrunkNode(), src, token, &occurrences)
return occurrences
}
-func collectTokenOccurrences(node *tree_sitter.Node, src []byte, token string, out *[]VariableOccurrence) {
+func collectTokenOccurrences(node *TreeNode, src []byte, token string, out *[]VariableOccurrence) {
if node == nil {
return
}
@@ -697,7 +649,7 @@ func collectTokenOccurrences(node *tree_sitter.Node, src []byte, token string, o
kind := node.Kind()
// Skip subtrees that can't contain meaningful identifier references
- if kind == "string" || kind == "comment" || kind == "sigil" || kind == "charlist" {
+ if kind == "string" || kind == "comment" || kind == "charlist" {
return
}
@@ -746,20 +698,20 @@ func collectTokenOccurrences(node *tree_sitter.Node, src []byte, token string, o
// function scope. Respects clause boundaries: variables from other case/fn
// clauses are excluded. Returns nil if the cursor is not inside a function.
func FindVariablesInScope(src []byte, line, col uint) []string {
- root, cleanup := parseElixir(src)
- if root == nil {
+ tree := NewTree(src)
+ if tree == nil {
return nil
}
- defer cleanup()
- return FindVariablesInScopeWithTree(root, src, line, col)
+ defer tree.Close()
+ return tree.FindVariablesInScope(src, line, col)
}
// FindVariablesInScopeWithTree is like FindVariablesInScope but uses a
// pre-parsed tree root.
-func FindVariablesInScopeWithTree(root *tree_sitter.Node, src []byte, line, col uint) []string {
- cursorNode := nodeAtPosition(root, line, col)
+func (t *Tree) FindVariablesInScope(src []byte, line, col uint) []string {
+ cursorNode := t.TrunkNode().ChildAtPosition(line, col)
if cursorNode == nil && col > 0 {
- cursorNode = nodeAtPosition(root, line, col-1)
+ cursorNode = t.TrunkNode().ChildAtPosition(line, col-1)
}
if cursorNode == nil {
return nil
@@ -777,7 +729,7 @@ func FindVariablesInScopeWithTree(root *tree_sitter.Node, src []byte, line, col
}
// findEnclosingFunction walks up from node to find the nearest def/defp/etc scope.
-func findEnclosingFunction(node *tree_sitter.Node, src []byte) *tree_sitter.Node {
+func findEnclosingFunction(node *TreeNode, src []byte) *TreeNode {
current := node.Parent()
for current != nil {
if current.Kind() == "call" && current.ChildCount() > 0 {
@@ -795,12 +747,12 @@ func findEnclosingFunction(node *tree_sitter.Node, src []byte) *tree_sitter.Node
// excluding function names, definition keywords, and module attributes.
// Skips stab_clauses and do..end calls that don't contain the cursor,
// since variables don't leak out of those scopes in Elixir.
-func collectVariableNames(node *tree_sitter.Node, src []byte, seen map[string]bool, out *[]string, cursorLine, cursorCol uint) {
+func collectVariableNames(node *TreeNode, src []byte, seen map[string]bool, out *[]string, cursorLine, cursorCol uint) {
if node == nil {
return
}
- if !nodeContainsPosition(node, cursorLine, cursorCol) {
+ if !node.ContainsPosition(cursorLine, cursorCol) {
// Variables in other case/fn clauses are not in scope.
if node.Kind() == "stab_clause" {
return
@@ -831,8 +783,8 @@ func collectVariableNames(node *tree_sitter.Node, src []byte, seen map[string]bo
// extractArrowClauses returns the binary_operator nodes for <- and = in the
// call's arguments, in source order.
-func extractArrowClauses(callNode *tree_sitter.Node, src []byte) []*tree_sitter.Node {
- var clauses []*tree_sitter.Node
+func extractArrowClauses(callNode *TreeNode, src []byte) []*TreeNode {
+ var clauses []*TreeNode
for i := uint(0); i < uint(callNode.ChildCount()); i++ {
child := callNode.Child(i)
if child.Kind() != "arguments" {
@@ -861,7 +813,7 @@ func extractArrowClauses(callNode *tree_sitter.Node, src []byte) []*tree_sitter.
// - Cursor on rhs1: uses lhs0's binding — collect lhs0 + rhs1 (+ further rhs until rebind) + body
// - Cursor on lhs1: collect lhs1 + body
// - Cursor in body: uses last clause's binding — collect last lhs + body
-func collectWithOccurrences(callNode, cursor *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence) {
+func collectWithOccurrences(callNode, cursor *TreeNode, src []byte, varName string, out *[]VariableOccurrence) {
clauses := extractArrowClauses(callNode, src)
// Find which clause and side the cursor is on
@@ -884,7 +836,7 @@ func collectWithOccurrences(callNode, cursor *tree_sitter.Node, src []byte, varN
}
// Find the do_block
- var doBlock *tree_sitter.Node
+ var doBlock *TreeNode
for i := uint(0); i < uint(callNode.ChildCount()); i++ {
child := callNode.Child(i)
if child.Kind() == "do_block" {
@@ -957,7 +909,7 @@ func collectWithOccurrences(callNode, cursor *tree_sitter.Node, src []byte, varN
// of =/← binary operators in a call's arguments, processing clauses
// sequentially. Once a clause's pattern (left side) rebinds varName,
// subsequent clauses and the do_block use the new binding — so we stop.
-func collectPatternExpressionOccurrences(callNode *tree_sitter.Node, src []byte, varName string, out *[]VariableOccurrence) {
+func collectPatternExpressionOccurrences(callNode *TreeNode, src []byte, varName string, out *[]VariableOccurrence) {
for i := uint(0); i < uint(callNode.ChildCount()); i++ {
child := callNode.Child(i)
if child.Kind() != "arguments" {
@@ -987,7 +939,7 @@ func collectPatternExpressionOccurrences(callNode *tree_sitter.Node, src []byte,
// callArgumentPatternsBindVariable checks whether a call's argument patterns
// (left side of = or <- operators) contain an unpinned binding of varName.
-func callArgumentPatternsBindVariable(node *tree_sitter.Node, src []byte, varName string) bool {
+func callArgumentPatternsBindVariable(node *TreeNode, src []byte, varName string) bool {
for i := uint(0); i < uint(node.ChildCount()); i++ {
child := node.Child(i)
if child.Kind() != "arguments" {
@@ -1008,7 +960,7 @@ func callArgumentPatternsBindVariable(node *tree_sitter.Node, src []byte, varNam
return false
}
-func callHasDoBlock(node *tree_sitter.Node) bool {
+func callHasDoBlock(node *TreeNode) bool {
for i := uint(0); i < uint(node.ChildCount()); i++ {
if node.Child(i).Kind() == "do_block" {
return true
@@ -1017,28 +969,11 @@ func callHasDoBlock(node *tree_sitter.Node) bool {
return false
}
-// nodeContainsPosition returns true if the node's range includes the given position.
-// Tree-sitter end positions are exclusive, consistent with nodeAtPosition.
-func nodeContainsPosition(node *tree_sitter.Node, line, col uint) bool {
- start := node.StartPosition()
- end := node.EndPosition()
- if line < uint(start.Row) || line > uint(end.Row) {
- return false
- }
- if line == uint(start.Row) && col < uint(start.Column) {
- return false
- }
- if line == uint(end.Row) && col >= uint(end.Column) {
- return false
- }
- return true
-}
-
// stabBindsVariable returns true if the stab_clause's arguments (pattern)
// contain an unpinned identifier matching varName, meaning it creates a new
// binding. Pinned variables (^varName) reference the outer scope and do NOT
// create a new binding.
-func stabBindsVariable(stabClause *tree_sitter.Node, src []byte, varName string) bool {
+func stabBindsVariable(stabClause *TreeNode, src []byte, varName string) bool {
for i := uint(0); i < uint(stabClause.ChildCount()); i++ {
child := stabClause.Child(i)
if child.Kind() == "arguments" {
@@ -1051,7 +986,7 @@ func stabBindsVariable(stabClause *tree_sitter.Node, src []byte, varName string)
// subtreeContainsUnpinnedIdentifier returns true if any identifier node in the
// subtree has the given name AND is not pinned (^varName). Pinned variables
// reference an outer binding and do not create a new one.
-func subtreeContainsUnpinnedIdentifier(node *tree_sitter.Node, src []byte, name string) bool {
+func subtreeContainsUnpinnedIdentifier(node *TreeNode, src []byte, name string) bool {
if node == nil {
return false
}
@@ -1071,7 +1006,7 @@ func subtreeContainsUnpinnedIdentifier(node *tree_sitter.Node, src []byte, name
}
// isPinOperator returns true if node is a unary_operator with the ^ operator.
-func isPinOperator(node *tree_sitter.Node, src []byte) bool {
+func isPinOperator(node *TreeNode, src []byte) bool {
if node.Kind() != "unary_operator" {
return false
}
diff --git a/internal/treesitter/variables_test.go b/internal/treesitter/variables_test.go
index 4400e29..a61173b 100644
--- a/internal/treesitter/variables_test.go
+++ b/internal/treesitter/variables_test.go
@@ -655,11 +655,11 @@ func TestFindVariableOccurrences_FullWorkerFile(t *testing.T) {
defdelegate backoff(job), to: MyApp.Oban.EmailWorker
end`)
- root, cleanup := parseElixir(src)
- if root == nil {
+ tree := NewTree(src)
+ if tree == nil {
t.Fatal("failed to parse")
}
- defer cleanup()
+ defer tree.Close()
// Find the actual line for "transfer_amount = Money.new" in this test source
lines := strings.Split(string(src), "\n")
@@ -675,7 +675,7 @@ end`)
}
t.Logf("transfer_amount rebind is at line %d: %q", transferLine, lines[transferLine])
- occs := FindVariableOccurrences(src, uint(transferLine), 6)
+ occs := tree.FindVariableOccurrences(src, uint(transferLine), 6)
t.Logf("transfer_amount from line %d col 6: %d occs: %+v", transferLine, len(occs), occs)
if occs == nil {
t.Fatal("expected variable occurrences for 'transfer_amount', got nil")
@@ -1459,12 +1459,15 @@ end
apply(config)
`)
- root, cleanup := parseElixir(src)
- defer cleanup()
+ tree := NewTree(src)
+ if tree == nil {
+ t.Fatal("failed to parse")
+ }
+ defer tree.Close()
// Renaming top-level "config" to "other" is safe: "other" only exists as a
// def-local, which is a different scope.
- if NameExistsInScopeOf(root, src, 0, 0, "other") {
+ if tree.NameExistsInScopeOf(src, 0, 0, "other") {
t.Error("false-positive collision: 'other' is a def-local, not in the top-level scope")
}
}
@@ -1555,5 +1558,6 @@ config :app, value: some_helper()
occs := FindVariableOccurrences(src, 2, uint(len("config :app, value: ")))
if occs != nil {
t.Errorf("expected nil for bare top-level call, got %d occurrences: %+v", len(occs), occs)
+
}
}
diff --git a/internal/version/version.go b/internal/version/version.go
index 2e2c620..416e634 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -5,4 +5,4 @@ const Version = "0.7.0"
// IndexVersion is incremented whenever the index schema or parser changes in a
// way that requires a full rebuild. Bump this alongside Version when releasing
// a change that makes existing indexes stale.
-const IndexVersion = 12
+const IndexVersion = 13