Skip to content

Commit 27d09c4

Browse files
Fix: When initialising multiple connections pass concurrent tasks (#4176)
1 parent 786033f commit 27d09c4

3 files changed

Lines changed: 61 additions & 6 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,15 @@ def connection_validator(self) -> t.Callable[[], None]:
109109
"""A function that validates the connection configuration"""
110110
return self.create_engine_adapter().ping
111111

112-
def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
112+
def create_engine_adapter(
113+
self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
114+
) -> EngineAdapter:
113115
"""Returns a new instance of the Engine Adapter."""
116+
117+
concurrent_tasks = concurrent_tasks or self.concurrent_tasks
114118
return self._engine_adapter(
115119
self._connection_factory_with_kwargs,
116-
multithreaded=self.concurrent_tasks > 1,
120+
multithreaded=concurrent_tasks > 1,
117121
default_catalog=self.get_catalog(),
118122
cursor_init=self._cursor_init,
119123
register_comments=register_comments_override or self.register_comments,
@@ -284,7 +288,9 @@ def init(cursor: duckdb.DuckDBPyConnection) -> None:
284288

285289
return init
286290

287-
def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
291+
def create_engine_adapter(
292+
self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
293+
) -> EngineAdapter:
288294
"""Checks if another engine adapter has already been created that shares a catalog that points to the same data
289295
file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
290296
associated with the new adapter will be ignored."""
@@ -315,7 +321,9 @@ def create_engine_adapter(self, register_comments_override: bool = False) -> Eng
315321
logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
316322
else:
317323
logger.info("Creating new DuckDB adapter for in-memory database")
318-
adapter = super().create_engine_adapter(register_comments_override)
324+
adapter = super().create_engine_adapter(
325+
register_comments_override, concurrent_tasks=concurrent_tasks
326+
)
319327
for data_file in data_files:
320328
key = data_file if isinstance(data_file, str) else data_file.path
321329
BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2254,7 +2254,7 @@ def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
22542254
for gateway_name in self.config.gateways:
22552255
if gateway_name != self.selected_gateway:
22562256
connection = self.config.get_connection(gateway_name)
2257-
adapter = connection.create_engine_adapter()
2257+
adapter = connection.create_engine_adapter(concurrent_tasks=self.concurrent_tasks)
22582258
self._engine_adapters[gateway_name] = adapter
22592259
return self._engine_adapters
22602260

tests/core/test_config.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from sqlmesh.core.context import Context
2828
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
29+
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
2930
from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter
3031
from sqlmesh.core.notification_target import ConsoleNotificationTarget
3132
from sqlmesh.core.user import User
@@ -709,6 +710,10 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
709710
aws_secret_access_key: accesskey
710711
work_group: group
711712
s3_warehouse_location: s3://location
713+
duckdb:
714+
connection:
715+
type: duckdb
716+
database: db.db
712717
713718
default_gateway: redshift
714719
@@ -725,11 +730,53 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
725730
ctx = Context(paths=tmp_path, config=config)
726731

727732
assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
728-
assert len(ctx.engine_adapters) == 2
733+
assert len(ctx.engine_adapters) == 3
729734
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
730735
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
736+
assert isinstance(ctx.engine_adapters["duckdb"], DuckDBEngineAdapter)
731737
assert ctx.engine_adapter == ctx._get_engine_adapter("redshift")
732738

739+
# The duckdb engine adapter should be have been set as multithreaded as well
740+
assert ctx.engine_adapters["duckdb"]._multithreaded
741+
742+
743+
def test_multi_gateway_single_threaded_config(tmp_path):
744+
config_path = tmp_path / "config_duck_athena.yaml"
745+
with open(config_path, "w", encoding="utf-8") as fd:
746+
fd.write(
747+
"""
748+
gateways:
749+
duckdb:
750+
connection:
751+
type: duckdb
752+
database: db.db
753+
athena:
754+
connection:
755+
type: athena
756+
aws_access_key_id: '1234'
757+
aws_secret_access_key: accesskey
758+
work_group: group
759+
s3_warehouse_location: s3://location
760+
default_gateway: duckdb
761+
model_defaults:
762+
dialect: duckdb
763+
"""
764+
)
765+
766+
config = load_config_from_paths(
767+
Config,
768+
project_paths=[config_path],
769+
)
770+
771+
ctx = Context(paths=tmp_path, config=config)
772+
assert isinstance(ctx._connection_config, DuckDBConnectionConfig)
773+
assert len(ctx.engine_adapters) == 2
774+
assert ctx.engine_adapter == ctx._get_engine_adapter("duckdb")
775+
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
776+
777+
# In this case athena should use 1 concurrent task as the default gateway is duckdb
778+
assert not ctx.engine_adapters["athena"]._multithreaded
779+
733780

734781
def test_trino_schema_location_mapping_syntax(tmp_path):
735782
config_path = tmp_path / "config_trino.yaml"

0 commit comments

Comments
 (0)