|
2 | 2 | import pytest |
3 | 3 | from pytest import FixtureRequest |
4 | 4 | from sqlglot import exp |
| 5 | +from pathlib import Path |
5 | 6 | from sqlglot.optimizer.qualify_columns import quote_identifiers |
6 | 7 | from sqlglot.helper import seq_get |
7 | 8 | from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter |
|
10 | 11 | from sqlmesh.core.model import SqlModel, load_sql_based_model |
11 | 12 | from sqlmesh.core.plan import Plan |
12 | 13 | from tests.core.engine_adapter.integration import TestContext |
| 14 | +from sqlmesh import model, ExecutionContext |
| 15 | +from sqlmesh.core.model import ModelKindName |
| 16 | +from datetime import datetime |
13 | 17 |
|
14 | 18 | from tests.core.engine_adapter.integration import ( |
15 | 19 | TestContext, |
|
19 | 23 | ) |
20 | 24 |
|
21 | 25 |
|
22 | | -@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"]))) |
| 26 | +@pytest.fixture( |
| 27 | + params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"], show_variant_in_test_id=False)) |
| 28 | +) |
23 | 29 | def ctx( |
24 | 30 | request: FixtureRequest, |
25 | 31 | create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], |
@@ -220,3 +226,48 @@ def test_create_iceberg_table(ctx: TestContext, engine_adapter: SnowflakeEngineA |
220 | 226 | result = sqlmesh.plan(auto_apply=True) |
221 | 227 |
|
222 | 228 | assert len(result.new_snapshots) == 2 |
| 229 | + |
| 230 | + |
def test_snowpark_concurrency(ctx: TestContext) -> None:
    """Run a Snowpark-backed Python model whose batches execute concurrently.

    The model below is incremental-by-time-range over 10 days with
    ``batch_size=1`` and ``batch_concurrency=4``, so sqlmesh schedules
    10 single-day batches with up to 4 running at once. Each batch inserts
    exactly one row, so the physical table must end up with 10 rows.
    """
    from snowflake.snowpark import DataFrame

    fqn = ctx.table("my_model")

    # 10 one-day batches (2020-01-01 .. 2020-01-10), up to 4 in flight at a time.
    @model(
        name=fqn.sql(),
        kind=dict(
            name=ModelKindName.INCREMENTAL_BY_TIME_RANGE,
            time_column="ds",
            batch_size=1,
            batch_concurrency=4,
        ),
        columns={"id": "int", "ds": "date"},
        start="2020-01-01",
        end="2020-01-10",
    )
    def execute(context: ExecutionContext, start: datetime, **kwargs) -> DataFrame:
        session = context.snowpark
        if not session:
            raise ValueError("Snowpark not present!")
        # One row per batch; `id` doubles as the day-of-month for easy inspection.
        return session.create_dataframe([(start.day, start.date())], schema=["id", "ds"])

    registered = model.get_registry()[fqn.sql().lower()].model(
        module_path=Path("."), path=Path("."), dialect="snowflake"
    )

    sqlmesh = ctx.create_context()

    # Sanity-check that the scheduler really is running multithreaded,
    # otherwise this test would not exercise concurrency at all.
    assert sqlmesh.concurrent_tasks > 1
    assert ctx.engine_adapter._multithreaded

    sqlmesh.upsert_model(registered)

    plan = sqlmesh.plan(auto_apply=True)
    assert len(plan.new_snapshots) == 1

    # All 10 daily batches should have landed in the physical table.
    df = ctx.engine_adapter.fetchdf(exp.select("*").from_(fqn), quote_identifiers=True)
    assert len(df) == 10
0 commit comments