|
5 | 5 | import types |
6 | 6 | import typing as t |
7 | 7 | from enum import Enum |
8 | | -from functools import reduce |
| 8 | +from functools import lru_cache, reduce |
9 | 9 | from itertools import chain |
10 | 10 | from pathlib import Path |
11 | 11 | from string import Template |
@@ -237,7 +237,7 @@ def evaluate_macros( |
237 | 237 | self.transform(value) if isinstance(value, exp.Expression) else value |
238 | 238 | ) |
239 | 239 | if isinstance(node, exp.Identifier) and "@" in node.this: |
240 | | - text = self.template(node.this, self.locals) |
| 240 | + text = self.template(node.this, {}) |
241 | 241 | if node.this != text: |
242 | 242 | changed = True |
243 | 243 | node.args["this"] = text |
@@ -287,18 +287,9 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str: |
287 | 287 | for k, v in chain(variables.items(), self.locals.items(), local_variables.items()): |
288 | 288 | # try to convert all variables into sqlglot expressions |
289 | 289 | # because they're going to be converted into strings in sql |
290 | | - # we use bare Exception instead of ValueError because there's |
291 | | - # a recursive error with MagicMock. |
292 | 290 | # we don't convert strings because that would result in adding quotes |
293 | | - if not isinstance(v, str): |
294 | | - try: |
295 | | - v = exp.convert(v) |
296 | | - except Exception: |
297 | | - pass |
298 | | - |
299 | | - if isinstance(v, exp.Expression): |
300 | | - v = v.sql(dialect=self.dialect) |
301 | | - mapping[k] = v |
| 291 | + if k != c.SQLMESH_VARS: |
| 292 | + mapping[k] = convert_sql(v, self.dialect) |
302 | 293 |
|
303 | 294 | return MacroStrTemplate(str(text)).safe_substitute(mapping) |
304 | 295 |
|
@@ -1378,3 +1369,30 @@ def _coerce( |
1378 | 1369 | f"Coercion of expression '{expr}' to type '{typ}' failed. Using non coerced expression at '{path}'", |
1379 | 1370 | ) |
1380 | 1371 | return expr |
| 1372 | + |
| 1373 | + |
| 1374 | +def convert_sql(v: t.Any, dialect: DialectType) -> t.Any: |
| 1375 | + try: |
| 1376 | + return _cache_convert_sql(v, dialect, v.__class__) |
| 1377 | + # dicts aren't hashable but are convertable |
| 1378 | + except TypeError: |
| 1379 | + return _convert_sql(v, dialect) |
| 1380 | + |
| 1381 | + |
| 1382 | +def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any: |
| 1383 | + if not isinstance(v, str): |
| 1384 | + try: |
| 1385 | + v = exp.convert(v) |
| 1386 | + # we use bare Exception instead of ValueError because there's |
| 1387 | + # a recursive error with MagicMock. |
| 1388 | + except Exception: |
| 1389 | + pass |
| 1390 | + |
| 1391 | + if isinstance(v, exp.Expression): |
| 1392 | + v = v.sql(dialect=dialect) |
| 1393 | + return v |
| 1394 | + |
| 1395 | + |
| 1396 | +@lru_cache(maxsize=1028) |
| 1397 | +def _cache_convert_sql(v: t.Any, dialect: DialectType, t: type) -> t.Any: |
| 1398 | + return _convert_sql(v, dialect) |
0 commit comments