Skip to content

Commit 538bed5

Browse files
committed
Add dialect and display name to ConnectionConfig classes
1 parent 64b70cd commit 538bed5

3 files changed

Lines changed: 97 additions & 31 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 8 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -8,7 +8,7 @@
88
from sqlmesh.utils.date import yesterday_ds
99
from sqlmesh.utils.errors import SQLMeshError
1010

11-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
11+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DIALECT_TO_TYPE
1212

1313

1414
PRIMITIVES = (str, int, bool, float)
@@ -27,7 +27,7 @@ class InitCliMode(Enum):
2727

2828

2929
def _gen_config(
30-
dialect: t.Optional[str],
30+
engine_type: t.Optional[str],
3131
settings: t.Optional[str],
3232
start: t.Optional[str],
3333
template: ProjectTemplate,
@@ -39,7 +39,7 @@ def _gen_config(
3939
database: db.db"""
4040
)
4141

42-
engine = "mssql" if dialect == "tsql" else dialect
42+
engine = "mssql" if engine_type == "tsql" else engine_type
4343

4444
if not settings and template != ProjectTemplate.DBT:
4545
doc_link = "https://sqlmesh.readthedocs.io/en/stable/integrations/engines{engine_link}"
@@ -51,6 +51,9 @@ def _gen_config(
5151

5252
for name, field in CONNECTION_CONFIG_TO_TYPE[engine].model_fields.items():
5353
field_name = field.alias or name
54+
if field_name in ("dialect", "display_name"):
55+
continue
56+
5457
default_value = field.get_default()
5558

5659
if isinstance(default_value, Enum):
@@ -77,7 +80,7 @@ def _gen_config(
7780

7881
connection_settings = (
7982
" # For more information on configuring the connection to your execution engine, visit:\n"
80-
" # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connections\n"
83+
" # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection\n"
8184
f" # {doc_link.format(engine_link=engine_link)}\n{connection_settings}"
8285
)
8386

@@ -93,7 +96,7 @@ def _gen_config(
9396
# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults
9497
9598
model_defaults:
96-
dialect: {dialect}
99+
dialect: {DIALECT_TO_TYPE[engine]}
97100
start: {start or yesterday_ds()} # Start date for backfill history
98101
cron: '@daily' # Run models daily at 12am UTC (can override per model)
99102

sqlmesh/cli/main.py

Lines changed: 26 additions & 26 deletions
Original file line number · Diff line number · Diff line change
@@ -15,7 +15,7 @@
1515
from sqlmesh.core.console import configure_console, get_console
1616
from sqlmesh.utils import Verbosity
1717
from sqlmesh.core.config import load_configs
18-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
18+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DISPLAY_NAME_TO_TYPE
1919
from sqlmesh.core.context import Context
2020
from sqlmesh.utils.date import TimeLike
2121
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
@@ -41,24 +41,24 @@
4141
SKIP_CONTEXT_COMMANDS = ("init", "ui")
4242

4343
# These are ordered for user display - do not reorder
44-
ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE = {
45-
"DuckDB": "duckdb",
46-
"Snowflake": "snowflake",
47-
"Databricks": "databricks",
48-
"BigQuery": "bigquery",
49-
"MotherDuck": "duckdb",
50-
"ClickHouse": "clickhouse",
51-
"Redshift": "redshift",
52-
"Spark": "spark",
53-
"Trino": "trino",
54-
"Azure SQL": "azuresql",
55-
"MSSQL": "tsql",
56-
"Postgres": "postgres",
57-
"GCP Postgres": "gcp_postgres",
58-
"MySQL": "mysql",
59-
"Athena": "athena",
60-
"RisingWave": "risingwave",
61-
}
44+
ENGINE_TYPE_DISPLAY_ORDER = [
45+
"duckdb",
46+
"snowflake",
47+
"databricks",
48+
"bigquery",
49+
"motherduck",
50+
"clickhouse",
51+
"redshift",
52+
"spark",
53+
"trino",
54+
"azuresql",
55+
"mssql",
56+
"postgres",
57+
"gcp_postgres",
58+
"mysql",
59+
"athena",
60+
"risingwave",
61+
]
6262

6363

6464
def _sqlmesh_version() -> str:
@@ -231,7 +231,7 @@ def init(
231231
)
232232

233233
next_step_text = {
234-
ProjectTemplate.DEFAULT: f"• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}\nRun command in CLI: sqlmesh plan\n(Optional) Explain a plan: sqlmesh plan --explain",
234+
ProjectTemplate.DEFAULT: f"• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}",
235235
ProjectTemplate.DBT: "",
236236
}
237237
next_step_text[ProjectTemplate.EMPTY] = next_step_text[ProjectTemplate.DEFAULT]
@@ -1376,9 +1376,9 @@ def _init_engine_prompt(console: Console) -> str:
13761376
console.print("Choose your SQL engine:\n")
13771377

13781378
display_num_to_engine = {}
1379-
for i, engine in enumerate(ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE.keys()):
1380-
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{engine}")
1381-
display_num_to_engine[i + 1] = engine
1379+
for i, engine_type in enumerate(ENGINE_TYPE_DISPLAY_ORDER):
1380+
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{DISPLAY_NAME_TO_TYPE[engine_type]}")
1381+
display_num_to_engine[i + 1] = engine_type
13821382
console.print("")
13831383

13841384
# self._print("""Need another engine? See: https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines)
@@ -1388,10 +1388,10 @@ def _init_engine_prompt(console: Console) -> str:
13881388
# """)
13891389

13901390
engine_num = _init_integer_prompt(
1391-
console, "engine", len(ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE), _init_engine_prompt
1391+
console, "engine", len(ENGINE_TYPE_DISPLAY_ORDER), _init_engine_prompt
13921392
)
13931393

1394-
return ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE[display_num_to_engine[engine_num]]
1394+
return display_num_to_engine[engine_num]
13951395

13961396

13971397
def _init_cli_mode_prompt(console: Console) -> InitCliMode:
@@ -1426,5 +1426,5 @@ def _check_engine_installed(console: Console, engine_type: t.Optional[str] = Non
14261426
except ModuleNotFoundError:
14271427
install_command = f'pip install "sqlmesh[{engine_type}]"'
14281428
raise SQLMeshError(
1429-
f"Unable to load required Python dependencies for the {engine_type.upper()} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
1429+
f"Unable to load required Python dependencies for the {DISPLAY_NAME_TO_TYPE[engine_type]} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
14301430
)

sqlmesh/core/config/connection.py

Lines changed: 63 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -87,6 +87,8 @@ def validate(cls: t.Any, data: t.Any) -> t.Any:
8787

8888
class ConnectionConfig(abc.ABC, BaseConfig):
8989
type_: str
90+
dialect: str
91+
display_name: str
9092
concurrent_tasks: int
9193
register_comments: bool
9294
pre_ping: bool
@@ -463,6 +465,8 @@ class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
463465
"""Configuration for the MotherDuck connection."""
464466

465467
type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
468+
dialect: t.Literal["duckdb"] = "duckdb"
469+
display_name: t.Literal["MotherDuck"] = "MotherDuck"
466470

467471
@property
468472
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -487,6 +491,8 @@ class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
487491
"""Configuration for the DuckDB connection."""
488492

489493
type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")
494+
dialect: t.Literal["duckdb"] = "duckdb"
495+
display_name: t.Literal["DuckDB"] = "DuckDB"
490496

491497

492498
class SnowflakeConnectionConfig(ConnectionConfig):
@@ -537,6 +543,8 @@ class SnowflakeConnectionConfig(ConnectionConfig):
537543
session_parameters: t.Optional[dict] = None
538544

539545
type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
546+
dialect: t.Literal["snowflake"] = "snowflake"
547+
display_name: t.Literal["Snowflake"] = "Snowflake"
540548

541549
_concurrent_tasks_validator = concurrent_tasks_validator
542550

@@ -733,6 +741,8 @@ class DatabricksConnectionConfig(ConnectionConfig):
733741
pre_ping: t.Literal[False] = False
734742

735743
type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
744+
dialect: t.Literal["databricks"] = "databricks"
745+
display_name: t.Literal["Databricks"] = "Databricks"
736746

737747
_concurrent_tasks_validator = concurrent_tasks_validator
738748
_http_headers_validator = http_headers_validator
@@ -989,6 +999,8 @@ class BigQueryConnectionConfig(ConnectionConfig):
989999
pre_ping: t.Literal[False] = False
9901000

9911001
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
1002+
dialect: t.Literal["bigquery"] = "bigquery"
1003+
display_name: t.Literal["BigQuery"] = "BigQuery"
9921004

9931005
_engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")
9941006

@@ -1130,6 +1142,9 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
11301142
scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",)
11311143
driver: str = "pg8000"
11321144
type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
1145+
dialect: t.Literal["postgres"] = "postgres"
1146+
display_name: t.Literal["GCP Postgres"] = "GCP Postgres"
1147+
11331148
concurrent_tasks: int = 4
11341149
register_comments: bool = True
11351150
pre_ping: bool = True
@@ -1264,6 +1279,8 @@ class RedshiftConnectionConfig(ConnectionConfig):
12641279
pre_ping: bool = False
12651280

12661281
type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
1282+
dialect: t.Literal["redshift"] = "redshift"
1283+
display_name: t.Literal["Redshift"] = "Redshift"
12671284

12681285
_engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")
12691286

@@ -1325,6 +1342,8 @@ class PostgresConnectionConfig(ConnectionConfig):
13251342
pre_ping: bool = True
13261343

13271344
type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
1345+
dialect: t.Literal["postgres"] = "postgres"
1346+
display_name: t.Literal["Postgres"] = "Postgres"
13281347

13291348
_engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")
13301349

@@ -1378,6 +1397,8 @@ class MySQLConnectionConfig(ConnectionConfig):
13781397
pre_ping: bool = True
13791398

13801399
type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
1400+
dialect: t.Literal["mysql"] = "mysql"
1401+
display_name: t.Literal["MySQL"] = "MySQL"
13811402

13821403
_engine_import_validator = _get_engine_import_validator("pymysql", "mysql")
13831404

@@ -1440,6 +1461,8 @@ class MSSQLConnectionConfig(ConnectionConfig):
14401461
pre_ping: bool = True
14411462

14421463
type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
1464+
dialect: t.Literal["tsql"] = "tsql"
1465+
display_name: t.Literal["MSSQL"] = "MSSQL"
14431466

14441467
@model_validator(mode="before")
14451468
@classmethod
@@ -1581,6 +1604,8 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
15811604

15821605
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
15831606
type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore
1607+
dialect: t.Literal["tsql"] = "tsql"
1608+
display_name: t.Literal["Azure SQL"] = "Azure SQL" # type: ignore
15841609

15851610
@property
15861611
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
@@ -1601,6 +1626,8 @@ class SparkConnectionConfig(ConnectionConfig):
16011626
pre_ping: t.Literal[False] = False
16021627

16031628
type_: t.Literal["spark"] = Field(alias="type", default="spark")
1629+
dialect: t.Literal["spark"] = "spark"
1630+
display_name: t.Literal["Spark"] = "Spark"
16041631

16051632
_engine_import_validator = _get_engine_import_validator("pyspark", "spark")
16061633

@@ -1719,6 +1746,8 @@ class TrinoConnectionConfig(ConnectionConfig):
17191746
pre_ping: t.Literal[False] = False
17201747

17211748
type_: t.Literal["trino"] = Field(alias="type", default="trino")
1749+
dialect: t.Literal["trino"] = "trino"
1750+
display_name: t.Literal["Trino"] = "Trino"
17221751

17231752
_engine_import_validator = _get_engine_import_validator("trino", "trino")
17241753

@@ -1879,6 +1908,8 @@ class ClickhouseConnectionConfig(ConnectionConfig):
18791908
connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None
18801909

18811910
type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
1911+
dialect: t.Literal["clickhouse"] = "clickhouse"
1912+
display_name: t.Literal["ClickHouse"] = "ClickHouse"
18821913

18831914
_engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")
18841915

@@ -2003,6 +2034,8 @@ class AthenaConnectionConfig(ConnectionConfig):
20032034
pre_ping: t.Literal[False] = False
20042035

20052036
type_: t.Literal["athena"] = Field(alias="type", default="athena")
2037+
dialect: t.Literal["athena"] = "athena"
2038+
display_name: t.Literal["Athena"] = "Athena"
20062039

20072040
_engine_import_validator = _get_engine_import_validator("pyathena", "athena")
20082041

@@ -2071,6 +2104,8 @@ class RisingwaveConnectionConfig(ConnectionConfig):
20712104
pre_ping: bool = True
20722105

20732106
type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
2107+
dialect: t.Literal["risingwave"] = "risingwave"
2108+
display_name: t.Literal["RisingWave"] = "RisingWave"
20742109

20752110
_engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")
20762111

@@ -2115,6 +2150,34 @@ def init(cursor: t.Any) -> None:
21152150
)
21162151
}
21172152

2153+
CONNECTION_CONFIG_TO_TYPE = {
2154+
# Map all subclasses of ConnectionConfig to the value of their `type_` field.
2155+
tpe.all_field_infos()["type_"].default: tpe
2156+
for tpe in subclasses(
2157+
__name__,
2158+
ConnectionConfig,
2159+
exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
2160+
)
2161+
}
2162+
2163+
DIALECT_TO_TYPE = {
2164+
tpe.all_field_infos()["type_"].default: tpe.all_field_infos()["dialect"].default
2165+
for tpe in subclasses(
2166+
__name__,
2167+
ConnectionConfig,
2168+
exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
2169+
)
2170+
}
2171+
2172+
DISPLAY_NAME_TO_TYPE = {
2173+
tpe.all_field_infos()["type_"].default: tpe.all_field_infos()["display_name"].default
2174+
for tpe in subclasses(
2175+
__name__,
2176+
ConnectionConfig,
2177+
exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
2178+
)
2179+
}
2180+
21182181

21192182
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
21202183
if "type" not in v:

0 commit comments

Comments (0)