Skip to content

Commit 978d2de

Browse files
Fix: When initialising multiple connections pass concurrent tasks
1 parent cf3b0fa commit 978d2de

2 files changed

Lines changed: 49 additions & 1 deletion

File tree

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,6 +2197,7 @@ def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
21972197
for gateway_name in self.config.gateways:
21982198
if gateway_name != self.selected_gateway:
21992199
connection = self.config.get_connection(gateway_name)
2200+
connection.concurrent_tasks = self.concurrent_tasks
22002201
adapter = connection.create_engine_adapter()
22012202
self._engine_adapters[gateway_name] = adapter
22022203
return self._engine_adapters

tests/core/test_config.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from sqlmesh.core.context import Context
2828
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
29+
from sqlmesh.core.engine_adapter.bigquery import BigQueryEngineAdapter
2930
from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter
3031
from sqlmesh.core.notification_target import ConsoleNotificationTarget
3132
from sqlmesh.core.user import User
@@ -709,6 +710,10 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
709710
aws_secret_access_key: accesskey
710711
work_group: group
711712
s3_warehouse_location: s3://location
713+
bigquery:
714+
connection:
715+
type: bigquery
716+
712717
713718
default_gateway: redshift
714719
@@ -725,11 +730,53 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
725730
ctx = Context(paths=tmp_path, config=config)
726731

727732
assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
728-
assert len(ctx.engine_adapters) == 2
733+
assert len(ctx.engine_adapters) == 3
729734
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
730735
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
736+
assert isinstance(ctx.engine_adapters["bigquery"], BigQueryEngineAdapter)
731737
assert ctx.engine_adapter == ctx._get_engine_adapter("redshift")
732738

739+
# The bigquery engine adapter should be have been set as multithreaded as well
740+
assert ctx.engine_adapters["bigquery"]._multithreaded
741+
742+
743+
def test_multi_gateway_single_threaded_config(tmp_path):
744+
config_path = tmp_path / "config_duck_athena.yaml"
745+
with open(config_path, "w", encoding="utf-8") as fd:
746+
fd.write(
747+
"""
748+
gateways:
749+
duckdb:
750+
connection:
751+
type: duckdb
752+
database: db.db
753+
athena:
754+
connection:
755+
type: athena
756+
aws_access_key_id: '1234'
757+
aws_secret_access_key: accesskey
758+
work_group: group
759+
s3_warehouse_location: s3://location
760+
default_gateway: duckdb
761+
model_defaults:
762+
dialect: duckdb
763+
"""
764+
)
765+
766+
config = load_config_from_paths(
767+
Config,
768+
project_paths=[config_path],
769+
)
770+
771+
ctx = Context(paths=tmp_path, config=config)
772+
assert isinstance(ctx._connection_config, DuckDBConnectionConfig)
773+
assert len(ctx.engine_adapters) == 2
774+
assert ctx.engine_adapter == ctx._get_engine_adapter("duckdb")
775+
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
776+
777+
# In this case athena should use 1 concurrent task as the default gateway is duckdb
778+
assert not ctx.engine_adapters["athena"]._multithreaded
779+
733780

734781
def test_trino_schema_location_mapping_syntax(tmp_path):
735782
config_path = tmp_path / "config_trino.yaml"

0 commit comments

Comments
 (0)