Skip to content

Commit a602c71

Browse files
adapt the query_cache_pool
1 parent 80487e4 commit a602c71

9 files changed

Lines changed: 101 additions & 103 deletions

File tree

sqlmesh/core/loader.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
from __future__ import annotations
22

33
import abc
4-
import concurrent.futures
54
import glob
65
import itertools
76
import linecache
8-
import multiprocessing as mp
97
import os
108
import re
119
import typing as t
12-
import concurrent
1310
from collections import Counter, defaultdict
1411
from dataclasses import dataclass
1512
from pathlib import Path
@@ -40,6 +37,7 @@
4037
from sqlmesh.utils.errors import ConfigError
4138
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
4239
from sqlmesh.utils.metaprogramming import import_python_file
40+
from sqlmesh.utils.process import create_process_pool_executor
4341
from sqlmesh.utils.yaml import YAML, load as yaml_load
4442

4543

@@ -531,8 +529,7 @@ def _load_sql_models(
531529
)
532530

533531
errors: t.List[str] = []
534-
with concurrent.futures.ProcessPoolExecutor(
535-
mp_context=mp.get_context("fork"),
532+
with create_process_pool_executor(
536533
initializer=_init_model_defaults,
537534
initargs=(self.config, gateway, defaults, cache),
538535
max_workers=c.MAX_FORK_WORKERS,

sqlmesh/core/model/cache.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4-
import multiprocessing as mp
54
import typing as t
6-
from concurrent.futures import ProcessPoolExecutor
75
from pathlib import Path
86

97
from sqlglot import exp
@@ -15,6 +13,7 @@
1513
from sqlmesh.core.model.definition import ExternalModel, Model, SqlModel
1614
from sqlmesh.utils.cache import FileCache
1715
from sqlmesh.utils.hashing import crc32
16+
from sqlmesh.utils.process import PoolExecutor, create_process_pool_executor
1817

1918
from dataclasses import dataclass
2019

@@ -135,9 +134,8 @@ def _entry_name(model: SqlModel) -> str:
135134
return f"{model.name}_{crc32(hash_data)}"
136135

137136

138-
def optimized_query_cache_pool(optimized_query_cache: OptimizedQueryCache) -> ProcessPoolExecutor:
139-
return ProcessPoolExecutor(
140-
mp_context=mp.get_context("fork"),
137+
def optimized_query_cache_pool(optimized_query_cache: OptimizedQueryCache) -> PoolExecutor:
138+
return create_process_pool_executor(
141139
initializer=_init_optimized_query_cache,
142140
initargs=(optimized_query_cache,),
143141
max_workers=c.MAX_FORK_WORKERS,

sqlmesh/core/model/definition.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -811,8 +811,10 @@ def convert_to_time_column(
811811
return exp.convert(time)
812812

813813
def set_mapping_schema(self, schema: t.Dict) -> None:
814+
# Make a shallow copy to avoid modifying the original in case they're the same
815+
temp_schema = schema.copy()
814816
self.mapping_schema.clear()
815-
self.mapping_schema.update(schema)
817+
self.mapping_schema.update(temp_schema)
816818

817819
def update_schema(self, schema: MappingSchema) -> None:
818820
"""Updates the schema for this model's dependencies based on the given mapping schema."""

sqlmesh/core/model/schema.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ def update_model_schemas(
2828
schema = MappingSchema(normalize=False)
2929
optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
3030

31-
if c.MAX_FORK_WORKERS == 1:
32-
_update_model_schemas_sequential(dag, models, schema, optimized_query_cache)
33-
else:
34-
_update_model_schemas_parallel(dag, models, schema, optimized_query_cache)
31+
_update_model_schemas(dag, models, schema, optimized_query_cache)
3532

3633

3734
def _update_schema_with_model(schema: MappingSchema, model: Model) -> None:
@@ -49,25 +46,7 @@ def _update_schema_with_model(schema: MappingSchema, model: Model) -> None:
4946
raise
5047

5148

52-
def _update_model_schemas_sequential(
53-
dag: DAG[str],
54-
models: UniqueKeyDict[str, Model],
55-
schema: MappingSchema,
56-
optimized_query_cache: OptimizedQueryCache,
57-
) -> None:
58-
for name in dag.sorted:
59-
model = models.get(name)
60-
61-
# External models don't exist in the context, so we need to skip them
62-
if not model:
63-
continue
64-
65-
model.update_schema(schema)
66-
optimized_query_cache.with_optimized_query(model)
67-
_update_schema_with_model(schema, model)
68-
69-
70-
def _update_model_schemas_parallel(
49+
def _update_model_schemas(
7150
dag: DAG[str],
7251
models: UniqueKeyDict[str, Model],
7352
schema: MappingSchema,
@@ -102,17 +81,24 @@ def process_models(completed_model: t.Optional[Model] = None) -> None:
10281
)
10382
)
10483

84+
errors: t.List[str] = []
10585
with optimized_query_cache_pool(optimized_query_cache) as executor:
10686
process_models()
10787

108-
while futures:
88+
while futures and not errors:
10989
for future in as_completed(futures):
110-
futures.remove(future)
111-
fqn, entry_name, data_hash, metadata_hash, mapping_schema = future.result()
112-
model = models[fqn]
113-
model._data_hash = data_hash
114-
model._metadata_hash = metadata_hash
115-
model.set_mapping_schema(mapping_schema)
116-
optimized_query_cache.with_optimized_query(model, entry_name)
117-
_update_schema_with_model(schema, model)
118-
process_models(completed_model=model)
90+
try:
91+
futures.remove(future)
92+
fqn, entry_name, data_hash, metadata_hash, mapping_schema = future.result()
93+
model = models[fqn]
94+
model._data_hash = data_hash
95+
model._metadata_hash = metadata_hash
96+
model.set_mapping_schema(mapping_schema)
97+
optimized_query_cache.with_optimized_query(model, entry_name)
98+
_update_schema_with_model(schema, model)
99+
process_models(completed_model=model)
100+
except Exception as ex:
101+
errors.append(f"{ex}")
102+
103+
if errors:
104+
raise SchemaError("Failed to update model schemas\n\n" + "\n".join(errors))

sqlmesh/core/snapshot/cache.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,19 @@ def get_or_load(
5555
for snapshot in loaded_snapshots:
5656
snapshots[snapshot.snapshot_id] = snapshot
5757

58-
if c.MAX_FORK_WORKERS != 1:
59-
with optimized_query_cache_pool(self._optimized_query_cache) as executor:
60-
for key, entry_name in executor.map(
61-
load_optimized_query,
62-
(
63-
(snapshot.model, s_id)
64-
for s_id, snapshot in snapshots.items()
65-
if snapshot.is_model
66-
),
67-
):
68-
if entry_name:
69-
self._optimized_query_cache.with_optimized_query(
70-
snapshots[key].model, entry_name
71-
)
58+
with optimized_query_cache_pool(self._optimized_query_cache) as executor:
59+
for key, entry_name in executor.map(
60+
load_optimized_query,
61+
(
62+
(snapshot.model, s_id)
63+
for s_id, snapshot in snapshots.items()
64+
if snapshot.is_model
65+
),
66+
):
67+
if entry_name:
68+
self._optimized_query_cache.with_optimized_query(
69+
snapshots[key].model, entry_name
70+
)
7271

7372
for snapshot in snapshots.values():
7473
self._update_node_hash_cache(snapshot)

sqlmesh/utils/process.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# mypy: disable-error-code=no-untyped-def
2+
3+
from concurrent.futures import Future, ProcessPoolExecutor
4+
import typing as t
5+
import multiprocessing as mp
6+
from sqlmesh.core import constants as c
7+
8+
9+
class SynchronousPoolExecutor:
    """A same-process, drop-in stand-in for ``ProcessPoolExecutor``.

    Runs every task synchronously in the calling process. Used when forking
    isn't possible (non-posix platforms) or isn't wanted (``max_workers == 1``,
    test environments).
    """

    def __init__(self, max_workers=None, mp_context=None, initializer=None, initargs=()):
        # ``max_workers`` and ``mp_context`` are accepted (and ignored) purely
        # so the signature matches ProcessPoolExecutor and call sites don't
        # need to branch on the executor type.
        if initializer is not None:
            try:
                initializer(*initargs)
            except BaseException as ex:
                # Mirror ProcessPoolExecutor, which surfaces initializer
                # failures as a RuntimeError-like break; chain the cause so
                # the original traceback isn't lost.
                raise RuntimeError(f"Exception in initializer: {ex}") from ex

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Must return a falsy value so exceptions raised inside the ``with``
        # block propagate, exactly as concurrent.futures executors behave.
        # (Returning True here would silently swallow every exception.)
        self.shutdown()
        return False

    def submit(self, fn, *args, **kwargs):
        """Run ``fn`` immediately and return an already-completed ``Future``."""
        future: Future = Future()
        try:
            future.set_result(fn(*args, **kwargs))
        except Exception as e:
            # Store the failure on the future, matching executor semantics:
            # callers see it via .result() / .exception(), not at submit time.
            future.set_exception(e)
        return future

    def map(self, fn, *iterables, timeout=None, chunksize=1):
        """Synchronous ``Executor.map``: built-in ``map`` over the inputs.

        ``timeout`` and ``chunksize`` are accepted for signature compatibility
        but have no effect in a same-process executor.
        """
        return map(fn, *iterables)

    def shutdown(self, wait=True, cancel_futures=False):
        """No-op: there are no worker processes to tear down."""
46+
47+
48+
# Either executor can be returned; call sites only rely on the shared
# submit/map/context-manager surface.
PoolExecutor = t.Union[SynchronousPoolExecutor, ProcessPoolExecutor]


def create_process_pool_executor(
    initializer: t.Callable, initargs: t.Tuple, max_workers: t.Optional[int] = c.MAX_FORK_WORKERS
) -> PoolExecutor:
    """Return an executor for fan-out work with a per-worker ``initializer``.

    Args:
        initializer: Callable run once in each worker (or once in-process for
            the synchronous fallback) before any task executes.
        initargs: Positional arguments passed to ``initializer``.
        max_workers: Worker count; ``1`` selects the synchronous executor.

    Returns:
        A fork-based ``ProcessPoolExecutor``, or ``SynchronousPoolExecutor``
        when only one worker is requested or the platform cannot fork.
    """
    if max_workers == 1 or "fork" not in mp.get_all_start_methods():
        # Decide on the fallback *before* touching the fork context:
        # mp.get_context("fork") raises ValueError on platforms (e.g. Windows)
        # where the fork start method doesn't exist.
        return SynchronousPoolExecutor(
            initializer=initializer,
            initargs=initargs,
            max_workers=max_workers,
        )
    return ProcessPoolExecutor(
        mp_context=mp.get_context("fork"),
        initializer=initializer,
        initargs=initargs,
        max_workers=max_workers,
    )

tests/conftest.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -506,15 +506,3 @@ def _make_function(table_name: str, random_id: str) -> exp.Table:
506506
return temp_table
507507

508508
return _make_function
509-
510-
511-
@pytest.fixture(autouse=True)
512-
def patch_process_pool_executor(mocker: MockerFixture, request):
513-
"""Patch ProcessPoolExecutor with MockProcessPoolExecutor in all tests except test_forking.py."""
514-
# Skip mocking for test_forking.py
515-
if request.node.fspath.basename == "test_forking.py":
516-
return
517-
518-
from tests.mock_executor import MockProcessPoolExecutor
519-
520-
mocker.patch("concurrent.futures.ProcessPoolExecutor", MockProcessPoolExecutor)

tests/mock_executor.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

tests/test_forking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def test_parallel_load(assert_exp_eq, mocker):
1212
mocker.patch("sqlmesh.core.constants.MAX_FORK_WORKERS", 2)
13-
spy = mocker.spy(schema, "_update_model_schemas_parallel")
13+
spy = mocker.spy(schema, "_update_model_schemas")
1414
context = Context(paths="examples/sushi")
1515

1616
if hasattr(os, "fork"):

0 commit comments

Comments
 (0)