|
11 | 11 | from sqlmesh.core.model import load_sql_based_model |
12 | 12 | from sqlmesh.core.model.definition import SqlModel |
13 | 13 | from sqlmesh.core.dialect import schema_ |
| 14 | +from sqlmesh.utils.errors import SQLMeshError |
14 | 15 | from tests.core.engine_adapter import to_sql_calls |
15 | 16 |
|
16 | 17 | pytestmark = [pytest.mark.engine, pytest.mark.trino] |
@@ -591,3 +592,55 @@ def test_create_schema_sets_location(make_mocked_engine_adapter: t.Callable, moc |
591 | 592 | 'CREATE SCHEMA IF NOT EXISTS "landing"."transactions" WITH (LOCATION=\'s3://raw-data/landing/transactions\')', # match '^landing\..*$' |
592 | 593 | ] |
593 | 594 | ) |
| 595 | + |
| 596 | + |
| 597 | +def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter): |
| 598 | + adapter = trino_mocked_engine_adapter |
| 599 | + |
| 600 | + # Test 1: No authorization property - should not execute any authorization commands |
| 601 | + with adapter.session({}): |
| 602 | + pass |
| 603 | + |
| 604 | + assert to_sql_calls(adapter) == [] |
| 605 | + |
| 606 | + # Test 2: String authorization |
| 607 | + with adapter.session({"authorization": "test_user"}): |
| 608 | + adapter.execute("SELECT 1") |
| 609 | + |
| 610 | + assert to_sql_calls(adapter) == [ |
| 611 | + "SET SESSION AUTHORIZATION 'test_user'", |
| 612 | + "SELECT 1", |
| 613 | + "RESET SESSION AUTHORIZATION", |
| 614 | + ] |
| 615 | + |
| 616 | + # Test 3: Expression authorization |
| 617 | + adapter.cursor.execute.reset_mock() |
| 618 | + with adapter.session({"authorization": exp.Literal.string("another_user")}): |
| 619 | + adapter.execute("SELECT 2") |
| 620 | + |
| 621 | + assert to_sql_calls(adapter) == [ |
| 622 | + "SET SESSION AUTHORIZATION 'another_user'", |
| 623 | + "SELECT 2", |
| 624 | + "RESET SESSION AUTHORIZATION", |
| 625 | + ] |
| 626 | + |
| 627 | + # Test 4: Invalid authorization (non-string expression) |
| 628 | + adapter.cursor.execute.reset_mock() |
| 629 | + with pytest.raises(SQLMeshError, match="Invalid authorization"): |
| 630 | + with adapter.session({"authorization": exp.Literal.number(123)}): |
| 631 | + pass |
| 632 | + |
| 633 | + # Test 5: RESET is called even if exception occurs during session |
| 634 | + adapter.cursor.execute.reset_mock() |
| 635 | + try: |
| 636 | + with adapter.session({"authorization": "test_user"}): |
| 637 | + adapter.execute("SELECT 1") |
| 638 | + raise RuntimeError("Test exception") |
| 639 | + except RuntimeError: |
| 640 | + pass |
| 641 | + |
| 642 | + assert to_sql_calls(adapter) == [ |
| 643 | + "SET SESSION AUTHORIZATION 'test_user'", |
| 644 | + "SELECT 1", |
| 645 | + "RESET SESSION AUTHORIZATION", |
| 646 | + ] |
0 commit comments