Skip to content

Commit 472452f

Browse files
add mock executor; fix loader; adapt unit tests
1 parent 34c1a1c commit 472452f

5 files changed

Lines changed: 84 additions & 41 deletions

File tree

sqlmesh/core/loader.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from __future__ import annotations
22

33
import abc
4+
import concurrent.futures
45
import glob
56
import itertools
67
import linecache
78
import multiprocessing as mp
89
import os
910
import re
1011
import typing as t
12+
import concurrent
1113
from collections import Counter, defaultdict
1214
from dataclasses import dataclass
1315
from pathlib import Path
1416
from pydantic import ValidationError
15-
from concurrent.futures import ProcessPoolExecutor, as_completed
1617

1718
from sqlglot.errors import SqlglotError
1819
from sqlglot import exp
@@ -478,20 +479,15 @@ def _load_models(
478479
audits into a Dict and creates the dag
479480
"""
480481
cache = SqlMeshLoader._Cache(self, self.config_path)
481-
import time
482482

483-
now = time.time()
484483
sql_models = self._load_sql_models(macros, jinja_macros, audits, signals, cache, gateway)
485-
print("sql models", time.time() - now)
486-
now = time.time()
487484
external_models = self._load_external_models(audits, cache, gateway)
488-
print("external models", time.time() - now)
489485
python_models = self._load_python_models(macros, jinja_macros, audits, signals)
490486

491487
all_model_names = list(sql_models) + list(external_models) + list(python_models)
492488
duplicates = [name for name, count in Counter(all_model_names).items() if count > 1]
493489
if duplicates:
494-
raise ValueError(f"Duplicate model name(s) found: {', '.join(duplicates)}.")
490+
raise ConfigError(f"Duplicate model name(s) found: {', '.join(duplicates)}.")
495491

496492
return UniqueKeyDict("models", **sql_models, **external_models, **python_models)
497493

@@ -506,8 +502,7 @@ def _load_sql_models(
506502
) -> UniqueKeyDict[str, Model]:
507503
"""Loads the sql models into a Dict"""
508504
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
509-
510-
paths = set()
505+
paths: t.Set[Path] = set()
511506

512507
for path in self._glob_paths(
513508
self.config_path / c.MODELS,
@@ -522,14 +517,11 @@ def _load_sql_models(
522517

523518
for path in paths.copy():
524519
cached_models = cache.get(path)
525-
526520
if cached_models:
527521
paths.remove(path)
528-
529522
for model in cached_models:
530-
models[model.fqn] = model
531-
532-
error = False
523+
if model.enabled:
524+
models[model.fqn] = model
533525

534526
if paths:
535527
defaults = dict(
@@ -550,31 +542,31 @@ def _load_sql_models(
550542
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
551543
)
552544

553-
with ProcessPoolExecutor(
545+
errors: t.List[str] = []
546+
with concurrent.futures.ProcessPoolExecutor(
554547
mp_context=mp.get_context("fork"),
555548
initializer=_init_model_defaults,
556549
initargs=(self.config, gateway, defaults, cache),
557550
max_workers=c.MAX_FORK_WORKERS,
558551
) as pool:
559-
for fut in as_completed(pool.submit(load_sql_models, path) for path in paths):
552+
futures_to_paths = {pool.submit(load_sql_models, path): path for path in paths}
553+
for fut, path in futures_to_paths.items():
560554
try:
561-
path, loaded = fut.result()
562-
555+
_, loaded = fut.result()
563556
if loaded:
564557
for model in loaded:
565-
model._path = path
566-
models[model.fqn] = model
558+
if model.enabled:
559+
model._path = path
560+
models[model.fqn] = model
567561
else:
568562
for model in cache.get(path):
569-
models[model.fqn] = model
563+
if model.enabled:
564+
models[model.fqn] = model
570565
except Exception as ex:
571-
self._console.log_error(
572-
f"Failed to load model definition at '{path}'.\n{ex}"
573-
)
574-
error = True
566+
errors.append(f"Failed to load model definition at '{path}'.\n\n{ex}")
575567

576-
if error:
577-
raise ConfigError("Failed to load models")
568+
if errors:
569+
raise ConfigError(f"Failed to load models\n\n{'\n'.join(errors)}")
578570

579571
return models
580572

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,3 +506,15 @@ def _make_function(table_name: str, random_id: str) -> exp.Table:
506506
return temp_table
507507

508508
return _make_function
509+
510+
511+
@pytest.fixture(autouse=True)
512+
def patch_process_pool_executor(mocker: MockerFixture, request):
513+
"""Patch ProcessPoolExecutor with MockProcessPoolExecutor in all tests except test_forking.py."""
514+
# Skip mocking for test_forking.py
515+
if request.node.fspath.basename == "test_forking.py":
516+
return
517+
518+
from tests.mock_executor import MockProcessPoolExecutor
519+
520+
mocker.patch("concurrent.futures.ProcessPoolExecutor", MockProcessPoolExecutor)

tests/core/test_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ def test_duplicate_model_names_different_kind(tmp_path: Path, sample_models):
8585
path_3.write_text(model_3["contents"])
8686

8787
with pytest.raises(
88-
ValueError, match=r'Duplicate model name\(s\) found: "memory"."test_schema"."test_model".'
88+
ConfigError, match=r'Duplicate model name\(s\) found: "memory"."test_schema"."test_model".'
8989
):
9090
Context(paths=tmp_path, config=config)
9191

9292

9393
@pytest.mark.parametrize("sample_models", ["sql", "external"], indirect=True)
9494
def test_duplicate_model_names_same_kind(tmp_path: Path, sample_models):
95-
"""Test same (SQL and external) models with duplicate model names raises ValueError."""
95+
"""Test same (SQL and external) models with duplicate model names raises ConfigError."""
9696

9797
def duplicate_model_path(fpath):
9898
return Path(fpath).parent / ("duplicate" + Path(fpath).suffix)

tests/core/test_model.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,26 +2941,33 @@ def test_model_cache(tmp_path: Path, mocker: MockerFixture):
29412941
expressions = d.parse(
29422942
"""
29432943
MODEL (
2944-
name db.seed,
2944+
name db.model_sql,
29452945
);
29462946
SELECT 1, ds;
29472947
"""
29482948
)
29492949

29502950
model = load_sql_based_model([e for e in expressions if e])
29512951

2952-
loader = mocker.Mock(return_value=[model])
2953-
2954-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2955-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2952+
assert cache.put([model], "test_model", "test_entry_a")
2953+
assert cache.get("test_model", "test_entry_a")[0].dict() == model.dict()
29562954

2957-
assert cache.get_or_load("test_model", "test_entry_b", loader=loader)[0].dict() == model.dict()
2958-
assert cache.get_or_load("test_model", "test_entry_b", loader=loader)[0].dict() == model.dict()
2955+
expressions = d.parse(
2956+
"""
2957+
MODEL (
2958+
name db.model_seed,
2959+
kind SEED (
2960+
path '../seeds/waiter_names.csv',
2961+
),
2962+
);
2963+
"""
2964+
)
29592965

2960-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2961-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2966+
seed_model = load_sql_based_model(
2967+
expressions, path=Path("./examples/sushi/models/test_model.sql")
2968+
)
29622969

2963-
assert loader.call_count == 2
2970+
assert not cache.put([seed_model], "test_model", "test_entry_b")
29642971

29652972

29662973
@pytest.mark.slow
@@ -2983,7 +2990,7 @@ def test_model_cache_gateway(tmp_path: Path, mocker: MockerFixture):
29832990
assert patched_cache_put.call_count == 0
29842991

29852992
Context(paths=tmp_path, config=config, gateway="secondary")
2986-
assert patched_cache_put.call_count == 4
2993+
assert patched_cache_put.call_count == 2
29872994

29882995

29892996
@pytest.mark.slow
@@ -3001,7 +3008,7 @@ def test_model_cache_default_catalog(tmp_path: Path, mocker: MockerFixture):
30013008
PropertyMock(return_value=None),
30023009
):
30033010
Context(paths=tmp_path)
3004-
assert patched_cache_put.call_count == 4
3011+
assert patched_cache_put.call_count == 2
30053012

30063013

30073014
def test_model_ctas_query():

tests/mock_executor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from concurrent.futures import Future
2+
3+
4+
class MockProcessPoolExecutor:
5+
"""A mock implementation of ProcessPoolExecutor for use in tests.
6+
7+
This executor runs functions synchronously in the same process, avoiding the issues
8+
with forking in test environments.
9+
"""
10+
11+
def __init__(self, max_workers=None, mp_context=None, initializer=None, initargs=()):
12+
if initializer is not None:
13+
try:
14+
initializer(*initargs)
15+
except BaseException as ex:
16+
raise RuntimeError(f"Exception in initializer: {ex}")
17+
18+
def __enter__(self):
19+
return self
20+
21+
def __exit__(self, *args):
22+
return True
23+
24+
def submit(self, fn, *args, **kwargs):
25+
"""Execute the function synchronously and return a Future with the result."""
26+
future = Future()
27+
try:
28+
result = fn(*args, **kwargs)
29+
future.set_result(result)
30+
except Exception as e:
31+
future.set_exception(e)
32+
return future

0 commit comments

Comments
 (0)