Skip to content

Commit 0109528

Browse files
committed
feat: add position details to linter
- Adds the ability for a linter rule to return a range specifying where the violation occurs. - Maps that range onto the corresponding LSP diagnostic. - Adds the range to the built-in NoSelectStar rule.
1 parent 63fd842 commit 0109528

9 files changed

Lines changed: 288 additions & 93 deletions

File tree

sqlmesh/core/linter/definition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
99
from functools import reduce
1010
from sqlmesh.core.model import Model
11-
from sqlmesh.core.linter.rule import Rule, RuleViolation
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
1212
from sqlmesh.core.console import LinterConsole, get_console
1313

1414
if t.TYPE_CHECKING:
@@ -74,6 +74,7 @@ def lint_model(
7474
violation_msg=violation.violation_msg,
7575
model=model,
7676
violation_type="error",
77+
violation_range=violation.violation_range,
7778
)
7879
for violation in error_violations
7980
] + [
@@ -82,6 +83,7 @@ def lint_model(
8283
violation_msg=violation.violation_msg,
8384
model=model,
8485
violation_type="warning",
86+
violation_range=violation.violation_range,
8587
)
8688
for violation in warn_violations
8789
]
@@ -149,7 +151,8 @@ def __init__(
149151
violation_msg: str,
150152
model: Model,
151153
violation_type: t.Literal["error", "warning"],
154+
violation_range: t.Optional["Range"] = None,
152155
) -> None:
153-
super().__init__(rule, violation_msg)
156+
super().__init__(rule, violation_msg, violation_range)
154157
self.model = model
155158
self.violation_type = violation_type

sqlmesh/core/linter/helpers.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from pathlib import Path
2+
3+
from sqlmesh.core.linter.rule import Position, Range
4+
from sqlmesh.utils.pydantic import PydanticModel
5+
import typing as t
6+
7+
8+
class TokenPositionDetails(PydanticModel):
    """
    Position metadata for one token, in the structure provided by SQLGlot.

    Attributes:
        line (int): The line that the token ends on.
        col (int): The column that the token ends on.
        start (int): The start index of the token.
        end (int): The ending index of the token.
    """

    line: int
    col: int
    start: int
    end: int

    @staticmethod
    def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
        """Build an instance from a metadata mapping holding line, col, start and end."""
        # Keys are looked up in declaration order, so a missing key fails on the first absent one.
        fields = {key: meta[key] for key in ("line", "col", "start", "end")}
        return TokenPositionDetails(**fields)

    def to_range(self, read_file: t.Optional[t.List[str]]) -> Range:
        """
        Turn this token position into a Range.

        When the start and end indexes are equal, the range is computed from the
        line and column alone and read_file is not consulted. That shortcut exists
        to avoid unnecessary file reads and is meant for tokens that stand for a
        single character or position.

        When the indexes differ, the lines of the file are needed to walk back from
        the token's end to where it starts.

        :param read_file: List of lines from the file. Optional
        :return: A Range object representing the token's position
        :raises ValueError: if read_file is None while start and end differ
        """
        # The line is shifted from 1-indexed to 0-indexed; the column is used as the end as-is.
        last_line = self.line - 1

        if self.start == self.end:
            return Range(
                start=Position(line=last_line, character=self.col - 1),
                end=Position(line=last_line, character=self.col),
            )

        if read_file is None:
            raise ValueError("read_file must be provided when start and end positions differ.")

        # Start on the token's final line and step back by the token's length.
        token_length = self.end - self.start + 1
        first_line = last_line
        first_col = self.col - token_length

        # A negative column means the token began on an earlier line: keep moving up,
        # crediting each line's length, plus one for the newline while still short.
        while first_col < 0 and first_line > 0:
            first_line -= 1
            first_col += len(read_file[first_line])
            if first_col < 0:
                first_col += 1

        # Never report a negative column, even if the top of the file was reached.
        first_col = max(0, first_col)
        return Range(
            start=Position(line=first_line, character=first_col),
            end=Position(line=last_line, character=self.col),
        )
84+
85+
86+
def read_range_from_file(file: Path, text_range: "Range") -> str:
    """
    Read the file and return the content within the specified range.

    The start line is clamped to 0 and the end line to the last line of the
    file. On the first line of the range the text is cut from the start
    character, on the last line it is cut at the end character (exclusive),
    and every line in between is returned whole.

    Args:
        file: Path to the file to read
        text_range: The range of text to extract

    Returns:
        The content within the specified range, or an empty string when the
        range selects no lines of the file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so non-ASCII content reads the same on every machine.
    with file.open("r", encoding="utf-8") as f:
        lines = f.readlines()

    first, last = text_range.start, text_range.end

    # Ensure the range is within bounds
    start_line = max(0, first.line)
    end_line = min(len(lines), last.line + 1)
    if start_line >= end_line:
        return ""

    # Extract the relevant portion of each line
    pieces = []
    for index in range(start_line, end_line):
        line = lines[index]
        begin = first.character if index == first.line else 0
        stop = last.character if index == last.line else len(line)
        pieces.append(line[begin:stop])

    return "".join(pieces)

sqlmesh/core/linter/rule.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@ class RuleLocation(PydanticModel):
2222
start_line: t.Optional[int] = None
2323

2424

25+
class Position(PydanticModel):
    """A single point in a file where a rule violation is located, following the LSP standard."""

    # Line component of the position.
    line: int
    # Character component of the position within that line.
    character: int
30+
31+
32+
class Range(PydanticModel):
    """The span of a rule violation in a file, given as two positions, following the LSP standard."""

    # Position where the span begins.
    start: Position
    # Position where the span ends.
    end: Position
37+
38+
2539
class _Rule(abc.ABCMeta):
2640
def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule:
2741
attrs["name"] = clsname.lower()
@@ -45,9 +59,15 @@ def summary(self) -> str:
4559
"""A summary of what this rule checks for."""
4660
return self.__doc__ or ""
4761

48-
def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation:
62+
def violation(
63+
self,
64+
violation_msg: t.Optional[str] = None,
65+
violation_range: t.Optional[Range] = None,
66+
) -> RuleViolation:
4967
"""Create a RuleViolation instance for this rule"""
50-
return RuleViolation(rule=self, violation_msg=violation_msg or self.summary)
68+
return RuleViolation(
69+
rule=self, violation_msg=violation_msg or self.summary, violation_range=violation_range
70+
)
5171

5272
def get_definition_location(self) -> RuleLocation:
5373
"""Return the file path and position information for this rule.
@@ -79,9 +99,12 @@ def __repr__(self) -> str:
7999

80100

81101
class RuleViolation:
82-
def __init__(self, rule: Rule, violation_msg: str) -> None:
102+
def __init__(
103+
self, rule: Rule, violation_msg: str, violation_range: t.Optional[Range] = None
104+
) -> None:
83105
self.rule = rule
84106
self.violation_msg = violation_msg
107+
self.violation_range = violation_range
85108

86109
def __repr__(self) -> str:
87110
return f"{self.rule.name}: {self.violation_msg}"

sqlmesh/core/linter/rules/builtin.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,39 @@
44

55
import typing as t
66

7+
from sqlglot.expressions import Star
78
from sqlglot.helper import subclasses
89

9-
from sqlmesh.core.linter.rule import Rule, RuleViolation
10+
from sqlmesh.core.linter.helpers import TokenPositionDetails
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
1012
from sqlmesh.core.linter.definition import RuleSet
1113
from sqlmesh.core.model import Model, SqlModel
1214

1315

1416
class NoSelectStar(Rule):
15-
"""Query should not contain SELECT * on its outer most projections, even if it can be expanded."""
17+
"""Query should not contain SELECT * on its outermost projections, even if it can be expanded."""
1618

1719
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
20+
# Only applies to SQL models, as other model types do not have a query.
1821
if not isinstance(model, SqlModel):
1922
return None
20-
21-
return self.violation() if model.query.is_star else None
23+
query = model.query
24+
if query.is_star:
25+
violation_range = self._get_range(model)
26+
return self.violation(violation_range=violation_range)
27+
return None
28+
29+
def _get_range(self, model: SqlModel) -> t.Optional[Range]:
    """Return the range of the offending star projection, or None when it cannot be determined."""
    try:
        projections = model.query.expressions
        # Only a query whose sole projection is a star is located.
        if len(projections) != 1:
            return None
        candidate = projections[0]
        if not isinstance(candidate, Star):
            return None

        position = TokenPositionDetails.from_meta(candidate.meta)
        # No file lines are passed in, so only tokens whose start and end
        # indexes coincide can be turned into a range here.
        return position.to_range(None) if position.start == position.end else None
    except Exception:
        # Best effort: any failure while locating the token means "no range".
        return None
2240

2341

2442
class InvalidSelectStarExpansion(Rule):

sqlmesh/lsp/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlmesh.core.model.definition import SqlModel
77
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
8-
from sqlmesh.lsp.custom import RenderModelEntry, ModelForRendering
8+
from sqlmesh.lsp.custom import ModelForRendering
99
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1010
from sqlmesh.lsp.uri import URI
1111

sqlmesh/lsp/main.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,8 +618,24 @@ def _diagnostic_to_lsp_diagnostic(
618618
) -> t.Optional[types.Diagnostic]:
619619
if diagnostic.model._path is None:
620620
return None
621-
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
622-
lines = file.readlines()
621+
if not diagnostic.violation_range:
622+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
623+
lines = file.readlines()
624+
range = types.Range(
625+
start=types.Position(line=0, character=0),
626+
end=types.Position(line=len(lines) - 1, character=len(lines[-1])),
627+
)
628+
else:
629+
range = types.Range(
630+
start=types.Position(
631+
line=diagnostic.violation_range.start.line,
632+
character=diagnostic.violation_range.start.character,
633+
),
634+
end=types.Position(
635+
line=diagnostic.violation_range.end.line,
636+
character=diagnostic.violation_range.end.character,
637+
),
638+
)
623639

624640
# Get rule definition location for diagnostics link
625641
rule_location = diagnostic.rule.get_definition_location()
@@ -628,10 +644,7 @@ def _diagnostic_to_lsp_diagnostic(
628644

629645
# Use URI format to create a link for "related information"
630646
return types.Diagnostic(
631-
range=types.Range(
632-
start=types.Position(line=0, character=0),
633-
end=types.Position(line=len(lines), character=len(lines[-1])),
634-
),
647+
range=range,
635648
message=diagnostic.violation_msg,
636649
severity=types.DiagnosticSeverity.Error
637650
if diagnostic.violation_type == "error"

0 commit comments

Comments
 (0)