Skip to content

Commit a38b49a

Browse files
committed
Refactor to track metadata objects using a separate value in env
1 parent a4a4123 commit a38b49a

5 files changed

Lines changed: 44 additions & 56 deletions

File tree

sqlmesh/core/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def make_python_env(
4242
) -> t.Dict[str, Executable]:
4343
python_env = {} if python_env is None else python_env
4444
variables = variables or {}
45-
env: t.Dict[str, t.Any] = {}
45+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
4646
used_macros = {}
4747
used_variables = (used_variables or set()).copy()
4848

sqlmesh/core/model/decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def model(
125125
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
126126
) -> Model:
127127
"""Get the model registered by this function."""
128-
env: t.Dict[str, t.Any] = {}
128+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
129129
entrypoint = self.func.__name__
130130

131131
if not self.name_provided and not infer_names:

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2482,7 +2482,7 @@ def _create_model(
24822482
dialect=dialect,
24832483
)
24842484

2485-
env: t.Dict[str, t.Any] = {}
2485+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
24862486

24872487
for signal_name, _ in model.signals:
24882488
if signal_definitions and signal_name in signal_definitions:

sqlmesh/utils/metaprogramming.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def normalize_source(obj: t.Any) -> str:
266266
def build_env(
267267
obj: t.Any,
268268
*,
269-
env: t.Dict[str, t.Any],
269+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]],
270270
name: str,
271271
path: Path,
272272
) -> None:
@@ -291,16 +291,12 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
291291
visited.add(name)
292292
name_missing_from_env = name not in env
293293

294-
if name_missing_from_env or (
295-
not is_metadata and env[name] == obj and getattr(env[name], c.SQLMESH_METADATA, None)
296-
):
294+
if name_missing_from_env or (not is_metadata and env[name] == (obj, True)):
297295
if not name_missing_from_env:
298296
# The existing object in the env is "metadata only" but we're walking it again as a
299297
# non-"metadata only" dependency, so we update this flag to ensure all transitive
300298
# dependencies are also not marked as "metadata only"
301-
is_metadata = False
302-
if hasattr(obj, c.SQLMESH_METADATA):
303-
delattr(obj, c.SQLMESH_METADATA)
299+
is_metadata = None
304300

305301
if hasattr(obj, c.SQLMESH_MACRO):
306302
# We only need to add the undecorated code of @macro() functions in env, which
@@ -320,14 +316,9 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
320316
or not hasattr(obj_module, "__file__")
321317
or not _is_relative_to(obj_module.__file__, path)
322318
):
323-
if is_metadata:
324-
setattr(obj, c.SQLMESH_METADATA, True)
325-
elif hasattr(obj, c.SQLMESH_METADATA):
326-
delattr(obj, c.SQLMESH_METADATA)
327-
328-
env[name] = obj
319+
env[name] = (obj, is_metadata)
329320
return
330-
elif env[name] != obj:
321+
elif env[name][0] != obj:
331322
raise SQLMeshError(
332323
f"Cannot store {obj} in environment, duplicate definitions found for '{name}'"
333324
)
@@ -359,13 +350,10 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
359350
for k, v in func_globals(obj).items():
360351
walk(v, k, is_metadata)
361352

362-
if is_metadata:
363-
setattr(obj, c.SQLMESH_METADATA, True)
364-
365353
# We store the object in the environment after its dependencies, because otherwise we
366354
# could crash at environment hydration time, since dicts are ordered and the top-level
367355
# objects would be loaded before their dependencies.
368-
env[name] = obj
356+
env[name] = (obj, is_metadata)
369357

370358
# The "metadata only" annotation of the object is transitive
371359
walk(obj, name, getattr(obj, c.SQLMESH_METADATA, None))
@@ -416,8 +404,8 @@ def is_value(self) -> bool:
416404
return self.kind == ExecutableKind.VALUE
417405

418406
@classmethod
419-
def value(cls, v: t.Any) -> Executable:
420-
return Executable(payload=repr(v), kind=ExecutableKind.VALUE)
407+
def value(cls, v: t.Any, is_metadata: t.Optional[bool] = None) -> Executable:
408+
return Executable(payload=repr(v), kind=ExecutableKind.VALUE, is_metadata=is_metadata)
421409

422410

423411
def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]:
@@ -431,11 +419,9 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
431419
"""
432420
serialized = {}
433421

434-
for k, v in env.items():
435-
is_metadata = getattr(v, c.SQLMESH_METADATA, None)
436-
422+
for k, (v, is_metadata) in env.items():
437423
if isinstance(v, LITERALS) or v is None:
438-
serialized[k] = Executable.value(v)
424+
serialized[k] = Executable.value(v, is_metadata=is_metadata)
439425
elif inspect.ismodule(v):
440426
name = v.__name__
441427
if hasattr(v, "__file__") and _is_relative_to(v.__file__, path):
@@ -471,7 +457,6 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
471457
v = wrapped
472458
file_path = Path(inspect.getfile(wrapped))
473459
relative_obj_file_path = _is_relative_to(file_path, path)
474-
is_metadata = is_metadata or getattr(v, c.SQLMESH_METADATA, None)
475460
except TypeError:
476461
file_path = None
477462
relative_obj_file_path = False

tests/utils/test_metaprogramming.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pytest_mock.plugin import MockerFixture
1212
from sqlglot import exp
1313
from sqlglot import exp as expressions
14-
from sqlglot.expressions import to_table
14+
from sqlglot.expressions import SQLGLOT_META, to_table
1515
from sqlglot.optimizer.pushdown_projections import SELECT_ALL
1616

1717
import tests.utils.test_date as test_date
@@ -100,13 +100,6 @@ def other_func(a: int) -> int:
100100
return X + a + W
101101

102102

103-
def noop_metadata() -> None:
104-
return None
105-
106-
107-
setattr(noop_metadata, c.SQLMESH_METADATA, True)
108-
109-
110103
@contextmanager
111104
def test_context_manager():
112105
yield
@@ -134,8 +127,7 @@ def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2)
134127
sqlglot.parse_one("1")
135128
MyClass()
136129
DataClass(x=y)
137-
noop_metadata()
138-
normalize_model_name("test")
130+
normalize_model_name("test" + SQLGLOT_META)
139131
fetch_data()
140132
function_with_custom_decorator()
141133

@@ -154,7 +146,6 @@ def test_func_globals() -> None:
154146
"Z": 3,
155147
"DataClass": DataClass,
156148
"MyClass": MyClass,
157-
"noop_metadata": noop_metadata,
158149
"normalize_model_name": normalize_model_name,
159150
"other_func": other_func,
160151
"sqlglot": sqlglot,
@@ -163,6 +154,7 @@ def test_func_globals() -> None:
163154
"fetch_data": fetch_data,
164155
"test_context_manager": test_context_manager,
165156
"function_with_custom_decorator": function_with_custom_decorator,
157+
"SQLGLOT_META": SQLGLOT_META,
166158
}
167159
assert func_globals(other_func) == {
168160
"X": 1,
@@ -194,8 +186,7 @@ def test_normalize_source() -> None:
194186
sqlglot.parse_one('1')
195187
MyClass()
196188
DataClass(x=y)
197-
noop_metadata()
198-
normalize_model_name('test')
189+
normalize_model_name('test' + SQLGLOT_META)
199190
fetch_data()
200191
function_with_custom_decorator()
201192
@@ -221,20 +212,21 @@ def closure(z: int):
221212
def test_serialize_env_error() -> None:
222213
with pytest.raises(SQLMeshError):
223214
# pretend to be the module pandas
224-
serialize_env({"test_date": test_date}, path=Path("tests/utils"))
215+
serialize_env({"test_date": (test_date, None)}, path=Path("tests/utils"))
225216

226217
with pytest.raises(SQLMeshError):
227-
serialize_env({"select_all": SELECT_ALL}, path=Path("tests/utils"))
218+
serialize_env({"select_all": (SELECT_ALL, None)}, path=Path("tests/utils"))
228219

229220

230221
def test_serialize_env() -> None:
231-
env: t.Dict[str, t.Any] = {}
232222
path = Path("tests/utils")
223+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
224+
233225
build_env(main_func, env=env, name="MAIN", path=path)
234-
env = serialize_env(env, path=path) # type: ignore
226+
serialized_env = serialize_env(env, path=path) # type: ignore
227+
assert prepare_env(serialized_env)
235228

236-
assert prepare_env(env)
237-
assert env == {
229+
expected_env = {
238230
"MAIN": Executable(
239231
name="main_func",
240232
alias="MAIN",
@@ -244,8 +236,7 @@ def test_serialize_env() -> None:
244236
sqlglot.parse_one('1')
245237
MyClass()
246238
DataClass(x=y)
247-
noop_metadata()
248-
normalize_model_name('test')
239+
normalize_model_name('test' + SQLGLOT_META)
249240
fetch_data()
250241
function_with_custom_decorator()
251242
@@ -319,13 +310,6 @@ def test_context_manager():
319310
path="test_metaprogramming.py",
320311
payload="my_lambda = lambda : print('z')",
321312
),
322-
"noop_metadata": Executable(
323-
name="noop_metadata",
324-
path="test_metaprogramming.py",
325-
payload="""def noop_metadata():
326-
return None""",
327-
is_metadata=True,
328-
),
329313
"normalize_model_name": Executable(
330314
payload="from sqlmesh.core.dialect import normalize_model_name",
331315
kind=ExecutableKind.IMPORT,
@@ -401,4 +385,23 @@ def function_with_custom_decorator():
401385
return""",
402386
alias="_func",
403387
),
388+
"SQLGLOT_META": Executable.value("sqlglot.meta"),
404389
}
390+
391+
assert all(is_metadata is None for (_, is_metadata) in env.values())
392+
assert serialized_env == expected_env
393+
394+
# Annotate the entrypoint as "metadata only" to show how it propagates
395+
setattr(main_func, c.SQLMESH_METADATA, True)
396+
397+
env = {}
398+
399+
build_env(main_func, env=env, name="MAIN", path=path)
400+
serialized_env = serialize_env(env, path=path) # type: ignore
401+
assert prepare_env(serialized_env)
402+
403+
expected_env = {k: Executable(**v.dict(), is_metadata=True) for k, v in expected_env.items()}
404+
405+
# Every object is treated as "metadata only", transitively
406+
assert all(is_metadata for (_, is_metadata) in env.values())
407+
assert serialized_env == expected_env

0 commit comments

Comments
 (0)