Skip to content

Commit 5ff525c

Browse files
committed
Feat: add support for Trino's authorization session property
1 parent a29ad26 commit 5ff525c

2 files changed

Lines changed: 78 additions & 2 deletions

File tree

sqlmesh/core/engine_adapter/trino.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
2+
3+
import contextlib
4+
import pandas as pd
25
import re
36
import typing as t
47
from functools import lru_cache
5-
import pandas as pd
68
from pandas.api.types import is_datetime64_any_dtype # type: ignore
79
from sqlglot import exp
810
from sqlglot.helper import seq_get
@@ -25,11 +27,12 @@
2527
SourceQuery,
2628
set_catalog,
2729
)
30+
from sqlmesh.utils.errors import SQLMeshError
2831
from sqlmesh.core.schema_diff import SchemaDiffer
2932
from sqlmesh.utils.date import TimeLike
3033

3134
if t.TYPE_CHECKING:
32-
from sqlmesh.core._typing import SchemaName, TableName
35+
from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
3336
from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF
3437

3538

@@ -88,6 +91,26 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
8891
)
8992
return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE
9093

94+
@contextlib.contextmanager
95+
def session(self, properties: SessionProperties) -> t.Iterator[None]:
96+
authorization = properties.get("authorization")
97+
if not authorization:
98+
yield
99+
return
100+
101+
if isinstance(authorization, str):
102+
authorization = exp.Literal.string(authorization)
103+
if not (isinstance(authorization, exp.Expression) and authorization.is_string):
104+
raise SQLMeshError(f"Invalid authorization: '{authorization}'")
105+
106+
authorization_sql = authorization.sql(dialect=self.dialect)
107+
108+
self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
109+
try:
110+
yield
111+
finally:
112+
self.execute(f"RESET SESSION AUTHORIZATION")
113+
91114
def _insert_overwrite_by_condition(
92115
self,
93116
table_name: TableName,

tests/core/engine_adapter/test_trino.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.core.model import load_sql_based_model
1212
from sqlmesh.core.model.definition import SqlModel
1313
from sqlmesh.core.dialect import schema_
14+
from sqlmesh.utils.errors import SQLMeshError
1415
from tests.core.engine_adapter import to_sql_calls
1516

1617
pytestmark = [pytest.mark.engine, pytest.mark.trino]
@@ -591,3 +592,55 @@ def test_create_schema_sets_location(make_mocked_engine_adapter: t.Callable, moc
591592
'CREATE SCHEMA IF NOT EXISTS "landing"."transactions" WITH (LOCATION=\'s3://raw-data/landing/transactions\')', # match '^landing\..*$'
592593
]
593594
)
595+
596+
597+
def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter):
598+
adapter = trino_mocked_engine_adapter
599+
600+
# Test 1: No authorization property - should not execute any authorization commands
601+
with adapter.session({}):
602+
pass
603+
604+
assert to_sql_calls(adapter) == []
605+
606+
# Test 2: String authorization
607+
with adapter.session({"authorization": "test_user"}):
608+
adapter.execute("SELECT 1")
609+
610+
assert to_sql_calls(adapter) == [
611+
"SET SESSION AUTHORIZATION 'test_user'",
612+
"SELECT 1",
613+
"RESET SESSION AUTHORIZATION",
614+
]
615+
616+
# Test 3: Expression authorization
617+
adapter.cursor.execute.reset_mock()
618+
with adapter.session({"authorization": exp.Literal.string("another_user")}):
619+
adapter.execute("SELECT 2")
620+
621+
assert to_sql_calls(adapter) == [
622+
"SET SESSION AUTHORIZATION 'another_user'",
623+
"SELECT 2",
624+
"RESET SESSION AUTHORIZATION",
625+
]
626+
627+
# Test 4: Invalid authorization (non-string expression)
628+
adapter.cursor.execute.reset_mock()
629+
with pytest.raises(SQLMeshError, match="Invalid authorization"):
630+
with adapter.session({"authorization": exp.Literal.number(123)}):
631+
pass
632+
633+
# Test 5: RESET is called even if exception occurs during session
634+
adapter.cursor.execute.reset_mock()
635+
try:
636+
with adapter.session({"authorization": "test_user"}):
637+
adapter.execute("SELECT 1")
638+
raise RuntimeError("Test exception")
639+
except RuntimeError:
640+
pass
641+
642+
assert to_sql_calls(adapter) == [
643+
"SET SESSION AUTHORIZATION 'test_user'",
644+
"SELECT 1",
645+
"RESET SESSION AUTHORIZATION",
646+
]

0 commit comments

Comments
 (0)