Skip to content

Commit 03a07f1

Browse files
authored
Fix: avoid redundant object traversals when building python envs (#4295)
1 parent 6e99e93 commit 03a07f1

2 files changed

Lines changed: 81 additions & 33 deletions

File tree

sqlmesh/utils/metaprogramming.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -338,43 +338,43 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
338338
):
339339
env[name] = (obj, is_metadata)
340340
return
341+
342+
if inspect.isclass(obj):
343+
for var in decorator_vars(obj):
344+
if obj_module and var in obj_module.__dict__:
345+
walk(obj_module.__dict__[var], var, is_metadata)
346+
347+
for base in obj.__bases__:
348+
walk(base, base.__qualname__, is_metadata)
349+
350+
for k, v in obj.__dict__.items():
351+
if k.startswith("__"):
352+
continue
353+
354+
# Traverse methods in a class to find global references
355+
if isinstance(v, (classmethod, staticmethod)):
356+
v = v.__func__
357+
358+
if callable(v):
359+
# Walk the method if it's part of the object, else it's a global function and we just store it
360+
if v.__qualname__.startswith(obj.__qualname__):
361+
for k, v in func_globals(v).items():
362+
walk(v, k, is_metadata)
363+
else:
364+
walk(v, v.__name__, is_metadata)
365+
elif callable(obj):
366+
for k, v in func_globals(obj).items():
367+
walk(v, k, is_metadata)
368+
369+
# We store the object in the environment after its dependencies, because otherwise we
370+
# could crash at environment hydration time, since dicts are ordered and the top-level
371+
# objects would be loaded before their dependencies.
372+
env[name] = (obj, is_metadata)
341373
elif not _globals_match(env[name][0], obj):
342374
raise SQLMeshError(
343375
f"Cannot store {obj} in environment, duplicate definitions found for '{name}'"
344376
)
345377

346-
if inspect.isclass(obj):
347-
for var in decorator_vars(obj):
348-
if obj_module and var in obj_module.__dict__:
349-
walk(obj_module.__dict__[var], var, is_metadata)
350-
351-
for base in obj.__bases__:
352-
walk(base, base.__qualname__, is_metadata)
353-
354-
for k, v in obj.__dict__.items():
355-
if k.startswith("__"):
356-
continue
357-
358-
# Traverse methods in a class to find global references
359-
if isinstance(v, (classmethod, staticmethod)):
360-
v = v.__func__
361-
362-
if callable(v):
363-
# Walk the method if it's part of the object, else it's a global function and we just store it
364-
if v.__qualname__.startswith(obj.__qualname__):
365-
for k, v in func_globals(v).items():
366-
walk(v, k, is_metadata)
367-
else:
368-
walk(v, v.__name__, is_metadata)
369-
elif callable(obj):
370-
for k, v in func_globals(obj).items():
371-
walk(v, k, is_metadata)
372-
373-
# We store the object in the environment after its dependencies, because otherwise we
374-
# could crash at environment hydration time, since dicts are ordered and the top-level
375-
# objects would be loaded before their dependencies.
376-
env[name] = (obj, is_metadata)
377-
378378
# The "metadata only" annotation of the object is transitive
379379
walk(obj, name, is_metadata_obj or getattr(obj, c.SQLMESH_METADATA, None))
380380

tests/utils/test_metaprogramming.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tests.utils.test_date as test_date
1818
from sqlmesh.core.dialect import normalize_model_name
1919
from sqlmesh.core import constants as c
20+
from sqlmesh.core.macros import RuntimeStage
2021
from sqlmesh.utils.errors import SQLMeshError
2122
from sqlmesh.utils.metaprogramming import (
2223
Executable,
@@ -47,7 +48,7 @@ def test_print_exception(mocker: MockerFixture):
4748
except Exception as ex:
4849
print_exception(ex, test_env, out_mock)
4950

50-
expected_message = r""" File ".*?/tests/utils/test_metaprogramming\.py", line 46, in test_print_exception
51+
expected_message = r""" File ".*?/tests/utils/test_metaprogramming\.py", line 47, in test_print_exception
5152
eval\("test_fun\(\)", env\)
5253
5354
File "<string>", line 1, in <module>
@@ -140,6 +141,18 @@ def closure(z: int) -> int:
140141
return closure(y) + other_func(Y)
141142

142143

144+
def macro1() -> str:
145+
print("macro1 hello there")
146+
print(RuntimeStage.CREATING)
147+
return "1"
148+
149+
150+
def macro2() -> str:
151+
print("macro2 hello there")
152+
print(RuntimeStage.LOADING)
153+
return "2"
154+
155+
143156
def test_func_globals() -> None:
144157
assert func_globals(main_func) == {
145158
"Y": 2,
@@ -405,3 +418,38 @@ def function_with_custom_decorator():
405418
# Every object is treated as "metadata only", transitively
406419
assert all(is_metadata for (_, is_metadata) in env.values())
407420
assert serialized_env == expected_env
421+
422+
423+
def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
424+
path = Path("tests/utils")
425+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
426+
427+
build_env(macro1, env=env, name="macro1", path=path)
428+
build_env(macro2, env=env, name="macro2", path=path)
429+
430+
serialized_env = serialize_env(env, path=path) # type: ignore
431+
assert prepare_env(serialized_env)
432+
433+
expected_env = {
434+
"RuntimeStage": Executable(
435+
payload="from sqlmesh.core.macros import RuntimeStage", kind=ExecutableKind.IMPORT
436+
),
437+
"macro1": Executable(
438+
payload="""def macro1():
439+
print('macro1 hello there')
440+
print(RuntimeStage.CREATING)
441+
return '1'""",
442+
name="macro1",
443+
path="test_metaprogramming.py",
444+
),
445+
"macro2": Executable(
446+
payload="""def macro2():
447+
print('macro2 hello there')
448+
print(RuntimeStage.LOADING)
449+
return '2'""",
450+
name="macro2",
451+
path="test_metaprogramming.py",
452+
),
453+
}
454+
455+
assert serialized_env == expected_env

0 commit comments

Comments
 (0)