Skip to content

Commit 705ee11

Browse files
Fix: When initialising multiple connections pass concurrent tasks
1 parent 46eaf49 commit 705ee11

2 files changed

Lines changed: 49 additions & 1 deletion

File tree

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,6 +2197,7 @@ def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
21972197
for gateway_name in self.config.gateways:
21982198
if gateway_name != self.selected_gateway:
21992199
connection = self.config.get_connection(gateway_name)
2200+
connection.concurrent_tasks = self.concurrent_tasks
22002201
adapter = connection.create_engine_adapter()
22012202
self._engine_adapters[gateway_name] = adapter
22022203
return self._engine_adapters

tests/core/test_config.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from sqlmesh.core.context import Context
3030
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
31+
from sqlmesh.core.engine_adapter.bigquery import BigQueryEngineAdapter
3132
from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter
3233
from sqlmesh.core.notification_target import ConsoleNotificationTarget
3334
from sqlmesh.core.user import User
@@ -747,6 +748,10 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
747748
aws_secret_access_key: accesskey
748749
work_group: group
749750
s3_warehouse_location: s3://location
751+
bigquery:
752+
connection:
753+
type: bigquery
754+
750755
751756
default_gateway: redshift
752757
@@ -763,11 +768,53 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
763768
ctx = Context(paths=tmp_path, config=config)
764769

765770
assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
766-
assert len(ctx.engine_adapters) == 2
771+
assert len(ctx.engine_adapters) == 3
767772
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
768773
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
774+
assert isinstance(ctx.engine_adapters["bigquery"], BigQueryEngineAdapter)
769775
assert ctx.engine_adapter == ctx._get_engine_adapter("redshift")
770776

777+
# The bigquery engine adapter should be have been set as multithreaded as well
778+
assert ctx.engine_adapters["bigquery"]._multithreaded
779+
780+
781+
def test_multi_gateway_single_threaded_config(tmp_path):
782+
config_path = tmp_path / "config_duck_athena.yaml"
783+
with open(config_path, "w", encoding="utf-8") as fd:
784+
fd.write(
785+
"""
786+
gateways:
787+
duckdb:
788+
connection:
789+
type: duckdb
790+
database: db.db
791+
athena:
792+
connection:
793+
type: athena
794+
aws_access_key_id: '1234'
795+
aws_secret_access_key: accesskey
796+
work_group: group
797+
s3_warehouse_location: s3://location
798+
default_gateway: duckdb
799+
model_defaults:
800+
dialect: duckdb
801+
"""
802+
)
803+
804+
config = load_config_from_paths(
805+
Config,
806+
project_paths=[config_path],
807+
)
808+
809+
ctx = Context(paths=tmp_path, config=config)
810+
assert isinstance(ctx._connection_config, DuckDBConnectionConfig)
811+
assert len(ctx.engine_adapters) == 2
812+
assert ctx.engine_adapter == ctx._get_engine_adapter("duckdb")
813+
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
814+
815+
# In this case athena should use 1 concurrent task as the default gateway is duckdb
816+
assert not ctx.engine_adapters["athena"]._multithreaded
817+
771818

772819
def test_trino_schema_location_mapping_syntax(tmp_path):
773820
config_path = tmp_path / "config_trino.yaml"

0 commit comments

Comments
 (0)