Skip to content

Commit 5a1be92

Browse files
authored
Feat: add support for Trino's authorization session property (#4639)
1 parent 113325c commit 5a1be92

4 files changed

Lines changed: 151 additions & 2 deletions

File tree

sqlmesh/core/engine_adapter/trino.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from __future__ import annotations
2+
3+
import contextlib
24
import re
35
import typing as t
46
from functools import lru_cache
@@ -28,7 +30,7 @@
2830
from sqlmesh.utils.date import TimeLike
2931

3032
if t.TYPE_CHECKING:
31-
from sqlmesh.core._typing import SchemaName, TableName
33+
from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
3234
from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF
3335

3436

@@ -87,6 +89,24 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
8789
)
8890
return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE
8991

92+
@contextlib.contextmanager
93+
def session(self, properties: SessionProperties) -> t.Iterator[None]:
94+
authorization = properties.get("authorization")
95+
if not authorization:
96+
yield
97+
return
98+
99+
if not isinstance(authorization, exp.Expression):
100+
authorization = exp.Literal.string(authorization)
101+
102+
authorization_sql = authorization.sql(dialect=self.dialect)
103+
104+
self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
105+
try:
106+
yield
107+
finally:
108+
self.execute(f"RESET SESSION AUTHORIZATION")
109+
90110
def _insert_overwrite_by_condition(
91111
self,
92112
table_name: TableName,

sqlmesh/core/model/meta.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,9 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
319319
return parsed_session_properties
320320

321321
for eq in parsed_session_properties:
322-
if eq.name == "query_label":
322+
prop_name = eq.left.name
323+
324+
if prop_name == "query_label":
323325
query_label = eq.right
324326
if not (
325327
isinstance(query_label, exp.Array)
@@ -345,6 +347,12 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
345347
raise ConfigError(
346348
"Invalid entry in `session_properties.query_label`. Must be tuples of string literals with length 2."
347349
)
350+
elif prop_name == "authorization":
351+
authorization = eq.right
352+
if not (isinstance(authorization, exp.Literal) and authorization.is_string):
353+
raise ConfigError(
354+
"Invalid value for `session_properties.authorization`. Must be a string literal."
355+
)
348356

349357
return parsed_session_properties
350358

tests/core/engine_adapter/test_trino.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,49 @@ def test_create_schema_sets_location(make_mocked_engine_adapter: t.Callable, moc
591591
'CREATE SCHEMA IF NOT EXISTS "landing"."transactions" WITH (LOCATION=\'s3://raw-data/landing/transactions\')', # match '^landing\..*$'
592592
]
593593
)
594+
595+
596+
def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter):
597+
adapter = trino_mocked_engine_adapter
598+
599+
# Test 1: No authorization property - should not execute any authorization commands
600+
with adapter.session({}):
601+
pass
602+
603+
assert to_sql_calls(adapter) == []
604+
605+
# Test 2: String authorization
606+
with adapter.session({"authorization": "test_user"}):
607+
adapter.execute("SELECT 1")
608+
609+
assert to_sql_calls(adapter) == [
610+
"SET SESSION AUTHORIZATION 'test_user'",
611+
"SELECT 1",
612+
"RESET SESSION AUTHORIZATION",
613+
]
614+
615+
# Test 3: Expression authorization
616+
adapter.cursor.execute.reset_mock()
617+
with adapter.session({"authorization": exp.Literal.string("another_user")}):
618+
adapter.execute("SELECT 2")
619+
620+
assert to_sql_calls(adapter) == [
621+
"SET SESSION AUTHORIZATION 'another_user'",
622+
"SELECT 2",
623+
"RESET SESSION AUTHORIZATION",
624+
]
625+
626+
# Test 4: RESET is called even if exception occurs during session
627+
adapter.cursor.execute.reset_mock()
628+
try:
629+
with adapter.session({"authorization": "test_user"}):
630+
adapter.execute("SELECT 1")
631+
raise RuntimeError("Test exception")
632+
except RuntimeError:
633+
pass
634+
635+
assert to_sql_calls(adapter) == [
636+
"SET SESSION AUTHORIZATION 'test_user'",
637+
"SELECT 1",
638+
"RESET SESSION AUTHORIZATION",
639+
]

tests/core/test_model.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4531,6 +4531,81 @@ def test_model_session_properties(sushi_context):
45314531
)
45324532

45334533

4534+
def test_session_properties_authorization_validation():
4535+
model = load_sql_based_model(
4536+
d.parse(
4537+
"""
4538+
MODEL (
4539+
name test_schema.test_model,
4540+
session_properties (
4541+
authorization = 'test_user'
4542+
)
4543+
);
4544+
SELECT a FROM tbl;
4545+
""",
4546+
default_dialect="trino",
4547+
)
4548+
)
4549+
assert model.session_properties == {"authorization": "test_user"}
4550+
4551+
with pytest.raises(
4552+
ConfigError,
4553+
match=r"Invalid value for `session_properties.authorization`. Must be a string literal.",
4554+
):
4555+
load_sql_based_model(
4556+
d.parse(
4557+
"""
4558+
MODEL (
4559+
name test_schema.test_model,
4560+
session_properties (
4561+
authorization = 123
4562+
)
4563+
);
4564+
SELECT a FROM tbl;
4565+
""",
4566+
default_dialect="trino",
4567+
)
4568+
)
4569+
4570+
with pytest.raises(
4571+
ConfigError,
4572+
match=r"Invalid value for `session_properties.authorization`. Must be a string literal.",
4573+
):
4574+
load_sql_based_model(
4575+
d.parse(
4576+
"""
4577+
MODEL (
4578+
name test_schema.test_model,
4579+
session_properties (
4580+
authorization = some_column
4581+
)
4582+
);
4583+
SELECT a FROM tbl;
4584+
""",
4585+
default_dialect="trino",
4586+
)
4587+
)
4588+
4589+
with pytest.raises(
4590+
ConfigError,
4591+
match=r"Invalid value for `session_properties.authorization`. Must be a string literal.",
4592+
):
4593+
load_sql_based_model(
4594+
d.parse(
4595+
"""
4596+
MODEL (
4597+
name test_schema.test_model,
4598+
session_properties (
4599+
authorization = CONCAT('user', '_suffix')
4600+
)
4601+
);
4602+
SELECT a FROM tbl;
4603+
""",
4604+
default_dialect="trino",
4605+
)
4606+
)
4607+
4608+
45344609
def test_model_jinja_macro_rendering():
45354610
expressions = d.parse(
45364611
"""

0 commit comments

Comments
 (0)