from pathlib import Path
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
from sqlmesh.core.linter.rule import Range, Position, TextEdit
from sqlmesh.utils.pydantic import PydanticModel
from sqlglot import tokenize, TokenType
import typing as t


class TokenPositionDetails(PydanticModel):
    """
    Details about a token's position in the source code, in the structure provided by SQLGlot.

    Attributes:
        line (int): The line that the token ends on.
        col (int): The column that the token ends on.
        start (int): The start index of the token.
        end (int): The ending index of the token.
    """

    line: int
    col: int
    start: int
    end: int

    @staticmethod
    def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
        return TokenPositionDetails(
            line=meta["line"],
            col=meta["col"],
            start=meta["start"],
            end=meta["end"],
        )
    def to_range(self, read_file: t.Optional[t.List[str]]) -> Range:
        """
        Convert a TokenPositionDetails object to a Range object.

        When the token's start and end positions are the same, there is no need for the
        read_file parameter, as the range can be derived from the token's line and column
        alone. This is an optimization to avoid unnecessary file reads and should only be
        used when the token represents a single character or position in the file.

        If the token's start and end positions differ, the read_file parameter is required.

        :param read_file: List of lines from the file. Optional.
        :return: A Range object representing the token's position.
        """
        if self.start == self.end:
            # If the start and end positions are the same, the range can be created directly
            # from the (1-indexed) line and column of the token's end.
            return Range(
                start=Position(line=self.line - 1, character=self.col - 1),
                end=Position(line=self.line - 1, character=self.col),
            )
        if read_file is None:
            raise ValueError("read_file must be provided when start and end positions differ.")

        # Convert from 1-indexed to 0-indexed for the line only.
        end_line_0 = self.line - 1
        end_col_0 = self.col

        # Find the start line and column by counting backwards from the end position.
        start_pos = self.start
        end_pos = self.end

        # Initialize with the end position.
        start_line_0 = end_line_0
        start_col_0 = end_col_0 - (end_pos - start_pos + 1)

        # If start_col_0 is negative, walk back through previous lines.
        while start_col_0 < 0 and start_line_0 > 0:
            start_line_0 -= 1
            start_col_0 += len(read_file[start_line_0])
            # Account for the newline character
            if start_col_0 >= 0:
                break
            start_col_0 += 1  # For the newline character

        # Ensure we don't end up with negative values.
        start_col_0 = max(0, start_col_0)
        return Range(
            start=Position(line=start_line_0, character=start_col_0),
            end=Position(line=end_line_0, character=end_col_0),
        )
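
# Illustrative sketch (not part of the module): how TokenPositionDetails maps
# SQLGlot's 1-indexed, end-anchored token coordinates onto a 0-indexed Range.
# The literal values below assume the source text "SELECT 1;", where the token
# "SELECT" spans character indices 0..5 and ends at line 1, column 6:
#
#     source = "SELECT 1;"
#     details = TokenPositionDetails.from_meta(
#         {"line": 1, "col": 6, "start": 0, "end": 5}
#     )
#     rng = details.to_range(source.splitlines())
#     # rng.start == Position(line=0, character=0)
#     # rng.end == Position(line=0, character=6)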


def read_range_from_string(content: str, text_range: Range) -> str:
    """Return the substring of `content` covered by `text_range`.

    Note that for multi-line ranges, the selected line segments are concatenated
    without newline separators.
    """
    lines = content.splitlines(keepends=False)

    # Ensure the range is within bounds
    start_line = max(0, text_range.start.line)
    end_line = min(len(lines), text_range.end.line + 1)
    if start_line >= end_line:
        return ""

    # Extract the relevant portion of each line
    result = []
    for i in range(start_line, end_line):
        line = lines[i]
        start_char = text_range.start.character if i == text_range.start.line else 0
        end_char = text_range.end.character if i == text_range.end.line else len(line)
        result.append(line[start_char:end_char])
    return "".join(result)


def read_range_from_file(file: Path, text_range: Range) -> str:
    """
    Read the file and return the content within the specified range.

    Args:
        file: Path to the file to read
        text_range: The range of text to extract

    Returns:
        The content within the specified range
    """
    with file.open("r", encoding="utf-8") as f:
        lines = f.readlines()
    return read_range_from_string("".join(lines), text_range)
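
# Usage sketch (the path is hypothetical):
#
#     rng = Range(start=Position(line=0, character=0), end=Position(line=0, character=6))
#     read_range_from_file(Path("models/example.sql"), rng)  # first 6 characters of line 1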


def get_range_of_model_block(
    sql: str,
    dialect: str,
) -> t.Optional[Range]:
    """
    Get the range of the MODEL block in an SQL file.
    """
    tokens = tokenize(sql, dialect=dialect)

    # Find the start of the model block
    start = next(
        (t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
        None,
    )
    end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)
    if start is None or end is None:
        return None

    start_position = TokenPositionDetails(
        line=start.line,
        col=start.col,
        start=start.start,
        end=start.end,
    )
    end_position = TokenPositionDetails(
        line=end.line,
        col=end.col,
        start=end.start,
        end=end.end,
    )
    splitlines = sql.splitlines()
    return Range(
        start=start_position.to_range(splitlines).start,
        end=end_position.to_range(splitlines).end,
    )
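
# Illustrative sketch (invented model definition): the returned range runs from
# the MODEL keyword through the semicolon that closes the block:
#
#     sql = "MODEL (\n  name db.example\n);\nSELECT 1 AS col"
#     rng = get_range_of_model_block(sql, dialect="duckdb")
#     # rng spans from line 0, character 0 to the ";" on 0-indexed line 2,
#     # i.e. it covers "MODEL (\n  name db.example\n);"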


def get_range_of_a_key_in_model_block(
    sql: str,
    dialect: str,
    key: str,
) -> t.Optional[Range]:
    """
    Get the range of a specific key in the MODEL block of an SQL file.
    """
    tokens = tokenize(sql, dialect=dialect)
    if tokens is None:
        return None

    # Find the start of the model block
    start_index = next(
        (
            i
            for i, t in enumerate(tokens)
            if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"
        ),
        None,
    )
    end_index = next(
        (i for i, t in enumerate(tokens) if t.token_type is TokenType.SEMICOLON),
        None,
    )
    if start_index is None or end_index is None:
        return None
    if start_index >= end_index:
        return None
    tokens_of_interest = tokens[start_index + 1 : end_index]

    # Find the key token
    key_token = next(
        (
            t
            for t in tokens_of_interest
            if t.token_type is TokenType.VAR and t.text.upper() == key.upper()
        ),
        None,
    )
    if key_token is None:
        return None

    position = TokenPositionDetails(
        line=key_token.line,
        col=key_token.col,
        start=key_token.start,
        end=key_token.end,
    )
    return position.to_range(sql.splitlines())
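
# Sketch of locating a key (values invented): find where `name` is declared
# inside the MODEL block so a linter fix can target just that token:
#
#     sql = "MODEL (\n  name db.example\n);\nSELECT 1 AS col"
#     rng = get_range_of_a_key_in_model_block(sql, dialect="duckdb", key="name")
#     # rng covers the "name" token on 0-indexed line 1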


def apply_text_edits(path: Path, edits: t.Sequence[TextEdit]) -> None:
    """Apply a sequence of TextEdits to a file."""
    if not edits:
        return

    with open(path, "r", encoding="utf-8") as file:
        content = file.read()

    # Map (line, character) positions to absolute offsets in the content.
    lines = content.splitlines(keepends=True)
    offsets = [0]
    for line in lines:
        offsets.append(offsets[-1] + len(line))

    def to_offset(pos: Position) -> int:
        line = min(pos.line, len(lines) - 1)
        char = min(pos.character, len(lines[line]))
        return offsets[line] + char

    # Apply edits in reverse document order so earlier offsets remain valid
    # as the content changes.
    sorted_edits = sorted(
        edits, key=lambda e: (e.range.start.line, e.range.start.character), reverse=True
    )
    for edit in sorted_edits:
        start = to_offset(edit.range.start)
        end = to_offset(edit.range.end)
        content = content[:start] + edit.new_text + content[end:]

    with open(path, "w", encoding="utf-8") as file:
        file.write(content)
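
# Usage sketch, assuming TextEdit carries the `path`, `range`, and `new_text`
# fields used in this module (the path and replacement text are invented):
#
#     edit = TextEdit(
#         path=Path("models/example.sql"),
#         range=Range(
#             start=Position(line=1, character=2),
#             end=Position(line=1, character=6),
#         ),
#         new_text="renamed",
#     )
#     apply_text_edits(Path("models/example.sql"), [edit])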


def apply_fixes(violations: t.Iterable[AnnotatedRuleViolation]) -> None:
    """Apply fixes from the provided violations."""
    # Group edits by target file so each file is read and written only once.
    edits_by_path: dict[Path, list[TextEdit]] = {}
    for violation in violations:
        for fix in violation.fixes:
            for edit in fix.edits:
                edits_by_path.setdefault(edit.path, []).append(edit)

    for path, edits in edits_by_path.items():
        apply_text_edits(path, edits)