Skip to content

Commit 4ecde95

Browse files
committed
Add dialect and display name to ConnectionConfig classes
1 parent 402291d commit 4ecde95

3 files changed

Lines changed: 97 additions & 31 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlmesh.utils.date import yesterday_ds
99
from sqlmesh.utils.errors import SQLMeshError
1010

11-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
11+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DIALECT_TO_TYPE
1212

1313

1414
PRIMITIVES = (str, int, bool, float)
@@ -27,7 +27,7 @@ class InitCliMode(Enum):
2727

2828

2929
def _gen_config(
30-
dialect: t.Optional[str],
30+
engine_type: t.Optional[str],
3131
settings: t.Optional[str],
3232
start: t.Optional[str],
3333
template: ProjectTemplate,
@@ -39,7 +39,7 @@ def _gen_config(
3939
database: db.db"""
4040
)
4141

42-
engine = "mssql" if dialect == "tsql" else dialect
42+
engine = "mssql" if engine_type == "tsql" else engine_type
4343

4444
if not settings and template != ProjectTemplate.DBT:
4545
doc_link = "https://sqlmesh.readthedocs.io/en/stable/integrations/engines{engine_link}"
@@ -51,6 +51,9 @@ def _gen_config(
5151

5252
for name, field in CONNECTION_CONFIG_TO_TYPE[engine].model_fields.items():
5353
field_name = field.alias or name
54+
if field_name in ("dialect", "display_name"):
55+
continue
56+
5457
default_value = field.get_default()
5558

5659
if isinstance(default_value, Enum):
@@ -77,7 +80,7 @@ def _gen_config(
7780

7881
connection_settings = (
7982
" # For more information on configuring the connection to your execution engine, visit:\n"
80-
" # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connections\n"
83+
" # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection\n"
8184
f" # {doc_link.format(engine_link=engine_link)}\n{connection_settings}"
8285
)
8386

@@ -93,7 +96,7 @@ def _gen_config(
9396
# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults
9497
9598
model_defaults:
96-
dialect: {dialect}
99+
dialect: {DIALECT_TO_TYPE[engine]}
97100
start: {start or yesterday_ds()} # Start date for backfill history
98101
cron: '@daily' # Run models daily at 12am UTC (can override per model)
99102

sqlmesh/cli/main.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sqlmesh.core.console import configure_console, get_console
1616
from sqlmesh.utils import Verbosity
1717
from sqlmesh.core.config import load_configs
18-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
18+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DISPLAY_NAME_TO_TYPE
1919
from sqlmesh.core.context import Context
2020
from sqlmesh.utils.date import TimeLike
2121
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
@@ -41,24 +41,24 @@
4141
SKIP_CONTEXT_COMMANDS = ("init", "ui")
4242

4343
# These are ordered for user display - do not reorder
44-
ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE = {
45-
"DuckDB": "duckdb",
46-
"Snowflake": "snowflake",
47-
"Databricks": "databricks",
48-
"BigQuery": "bigquery",
49-
"MotherDuck": "duckdb",
50-
"ClickHouse": "clickhouse",
51-
"Redshift": "redshift",
52-
"Spark": "spark",
53-
"Trino": "trino",
54-
"Azure SQL": "azuresql",
55-
"MSSQL": "tsql",
56-
"Postgres": "postgres",
57-
"GCP Postgres": "gcp_postgres",
58-
"MySQL": "mysql",
59-
"Athena": "athena",
60-
"RisingWave": "risingwave",
61-
}
44+
ENGINE_TYPE_DISPLAY_ORDER = [
45+
"duckdb",
46+
"snowflake",
47+
"databricks",
48+
"bigquery",
49+
"motherduck",
50+
"clickhouse",
51+
"redshift",
52+
"spark",
53+
"trino",
54+
"azuresql",
55+
"mssql",
56+
"postgres",
57+
"gcp_postgres",
58+
"mysql",
59+
"athena",
60+
"risingwave",
61+
]
6262

6363

6464
def _sqlmesh_version() -> str:
@@ -231,7 +231,7 @@ def init(
231231
)
232232

233233
next_step_text = {
234-
ProjectTemplate.DEFAULT: f"• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}\nRun command in CLI: sqlmesh plan\n(Optional) Explain a plan: sqlmesh plan --explain",
234+
ProjectTemplate.DEFAULT: f"• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}",
235235
ProjectTemplate.DBT: "",
236236
}
237237
next_step_text[ProjectTemplate.EMPTY] = next_step_text[ProjectTemplate.DEFAULT]
@@ -1376,9 +1376,9 @@ def _init_engine_prompt(console: Console) -> str:
13761376
console.print("Choose your SQL engine:\n")
13771377

13781378
display_num_to_engine = {}
1379-
for i, engine in enumerate(ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE.keys()):
1380-
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{engine}")
1381-
display_num_to_engine[i + 1] = engine
1379+
for i, engine_type in enumerate(ENGINE_TYPE_DISPLAY_ORDER):
1380+
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{DISPLAY_NAME_TO_TYPE[engine_type]}")
1381+
display_num_to_engine[i + 1] = engine_type
13821382
console.print("")
13831383

13841384
# self._print("""Need another engine? See: https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines)
@@ -1388,10 +1388,10 @@ def _init_engine_prompt(console: Console) -> str:
13881388
# """)
13891389

13901390
engine_num = _init_integer_prompt(
1391-
console, "engine", len(ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE), _init_engine_prompt
1391+
console, "engine", len(ENGINE_TYPE_DISPLAY_ORDER), _init_engine_prompt
13921392
)
13931393

1394-
return ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE[display_num_to_engine[engine_num]]
1394+
return display_num_to_engine[engine_num]
13951395

13961396

13971397
def _init_cli_mode_prompt(console: Console) -> InitCliMode:
@@ -1426,5 +1426,5 @@ def _check_engine_installed(console: Console, engine_type: t.Optional[str] = Non
14261426
except ModuleNotFoundError:
14271427
install_command = f'pip install "sqlmesh[{engine_type}]"'
14281428
raise SQLMeshError(
1429-
f"Unable to load required Python dependencies for the {engine_type.upper()} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
1429+
f"Unable to load required Python dependencies for the {DISPLAY_NAME_TO_TYPE[engine_type]} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
14301430
)

sqlmesh/core/config/connection.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def validate(cls: t.Any, data: t.Any) -> t.Any:
8787

8888
class ConnectionConfig(abc.ABC, BaseConfig):
8989
type_: str
90+
dialect: str
91+
display_name: str
9092
concurrent_tasks: int
9193
register_comments: bool
9294
pre_ping: bool
@@ -463,6 +465,8 @@ class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
463465
"""Configuration for the MotherDuck connection."""
464466

465467
type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
468+
dialect: t.Literal["duckdb"] = "duckdb"
469+
display_name: t.Literal["MotherDuck"] = "MotherDuck"
466470

467471
@property
468472
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -487,6 +491,8 @@ class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
487491
"""Configuration for the DuckDB connection."""
488492

489493
type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")
494+
dialect: t.Literal["duckdb"] = "duckdb"
495+
display_name: t.Literal["DuckDB"] = "DuckDB"
490496

491497

492498
class SnowflakeConnectionConfig(ConnectionConfig):
@@ -537,6 +543,8 @@ class SnowflakeConnectionConfig(ConnectionConfig):
537543
session_parameters: t.Optional[dict] = None
538544

539545
type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
546+
dialect: t.Literal["snowflake"] = "snowflake"
547+
display_name: t.Literal["Snowflake"] = "Snowflake"
540548

541549
_concurrent_tasks_validator = concurrent_tasks_validator
542550
_engine_import_validator = _get_engine_import_validator("snowflake", "snowflake")
@@ -730,6 +738,8 @@ class DatabricksConnectionConfig(ConnectionConfig):
730738
pre_ping: t.Literal[False] = False
731739

732740
type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
741+
dialect: t.Literal["databricks"] = "databricks"
742+
display_name: t.Literal["Databricks"] = "Databricks"
733743

734744
_concurrent_tasks_validator = concurrent_tasks_validator
735745
_http_headers_validator = http_headers_validator
@@ -985,6 +995,8 @@ class BigQueryConnectionConfig(ConnectionConfig):
985995
pre_ping: t.Literal[False] = False
986996

987997
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
998+
dialect: t.Literal["bigquery"] = "bigquery"
999+
display_name: t.Literal["BigQuery"] = "BigQuery"
9881000

9891001
_engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")
9901002

@@ -1126,6 +1138,9 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
11261138
scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",)
11271139
driver: str = "pg8000"
11281140
type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
1141+
dialect: t.Literal["postgres"] = "postgres"
1142+
display_name: t.Literal["GCP Postgres"] = "GCP Postgres"
1143+
11291144
concurrent_tasks: int = 4
11301145
register_comments: bool = True
11311146
pre_ping: bool = True
@@ -1260,6 +1275,8 @@ class RedshiftConnectionConfig(ConnectionConfig):
12601275
pre_ping: bool = False
12611276

12621277
type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
1278+
dialect: t.Literal["redshift"] = "redshift"
1279+
display_name: t.Literal["Redshift"] = "Redshift"
12631280

12641281
_engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")
12651282

@@ -1321,6 +1338,8 @@ class PostgresConnectionConfig(ConnectionConfig):
13211338
pre_ping: bool = True
13221339

13231340
type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
1341+
dialect: t.Literal["postgres"] = "postgres"
1342+
display_name: t.Literal["Postgres"] = "Postgres"
13241343

13251344
_engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")
13261345

@@ -1374,6 +1393,8 @@ class MySQLConnectionConfig(ConnectionConfig):
13741393
pre_ping: bool = True
13751394

13761395
type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
1396+
dialect: t.Literal["mysql"] = "mysql"
1397+
display_name: t.Literal["MySQL"] = "MySQL"
13771398

13781399
_engine_import_validator = _get_engine_import_validator("pymysql", "mysql")
13791400

@@ -1436,6 +1457,8 @@ class MSSQLConnectionConfig(ConnectionConfig):
14361457
pre_ping: bool = True
14371458

14381459
type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
1460+
dialect: t.Literal["tsql"] = "tsql"
1461+
display_name: t.Literal["MSSQL"] = "MSSQL"
14391462

14401463
@model_validator(mode="before")
14411464
@classmethod
@@ -1577,6 +1600,8 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
15771600

15781601
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
15791602
type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore
1603+
dialect: t.Literal["tsql"] = "tsql"
1604+
display_name: t.Literal["Azure SQL"] = "Azure SQL" # type: ignore
15801605

15811606
@property
15821607
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
@@ -1597,6 +1622,8 @@ class SparkConnectionConfig(ConnectionConfig):
15971622
pre_ping: t.Literal[False] = False
15981623

15991624
type_: t.Literal["spark"] = Field(alias="type", default="spark")
1625+
dialect: t.Literal["spark"] = "spark"
1626+
display_name: t.Literal["Spark"] = "Spark"
16001627

16011628
_engine_import_validator = _get_engine_import_validator("pyspark", "spark")
16021629

@@ -1715,6 +1742,8 @@ class TrinoConnectionConfig(ConnectionConfig):
17151742
pre_ping: t.Literal[False] = False
17161743

17171744
type_: t.Literal["trino"] = Field(alias="type", default="trino")
1745+
dialect: t.Literal["trino"] = "trino"
1746+
display_name: t.Literal["Trino"] = "Trino"
17181747

17191748
_engine_import_validator = _get_engine_import_validator("trino", "trino")
17201749

@@ -1875,6 +1904,8 @@ class ClickhouseConnectionConfig(ConnectionConfig):
18751904
connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None
18761905

18771906
type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
1907+
dialect: t.Literal["clickhouse"] = "clickhouse"
1908+
display_name: t.Literal["ClickHouse"] = "ClickHouse"
18781909

18791910
_engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")
18801911

@@ -1999,6 +2030,8 @@ class AthenaConnectionConfig(ConnectionConfig):
19992030
pre_ping: t.Literal[False] = False
20002031

20012032
type_: t.Literal["athena"] = Field(alias="type", default="athena")
2033+
dialect: t.Literal["athena"] = "athena"
2034+
display_name: t.Literal["Athena"] = "Athena"
20022035

20032036
_engine_import_validator = _get_engine_import_validator("pyathena", "athena")
20042037

@@ -2067,6 +2100,8 @@ class RisingwaveConnectionConfig(ConnectionConfig):
20672100
pre_ping: bool = True
20682101

20692102
type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
2103+
dialect: t.Literal["risingwave"] = "risingwave"
2104+
display_name: t.Literal["RisingWave"] = "RisingWave"
20702105

20712106
_engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")
20722107

@@ -2111,6 +2146,34 @@ def init(cursor: t.Any) -> None:
21112146
)
21122147
}
21132148

2149+
CONNECTION_CONFIG_TO_TYPE = {
2150+
# Map all subclasses of ConnectionConfig to the value of their `type_` field.
2151+
tpe.all_field_infos()["type_"].default: tpe
2152+
for tpe in subclasses(
2153+
__name__,
2154+
ConnectionConfig,
2155+
exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
2156+
)
2157+
}
2158+
2159+
DIALECT_TO_TYPE = {
2160+
tpe.all_field_infos()["type_"].default: tpe.all_field_infos()["dialect"].default
2161+
for tpe in subclasses(
2162+
__name__,
2163+
ConnectionConfig,
2164+
exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
2165+
)
2166+
}
2167+
2168+
DISPLAY_NAME_TO_TYPE = {
2169+
tpe.all_field_infos()["type_"].default: tpe.all_field_infos()["display_name"].default
2170+
for tpe in subclasses(
2171+
__name__,
2172+
ConnectionConfig,
2173+
exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
2174+
)
2175+
}
2176+
21142177

21152178
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
21162179
if "type" not in v:

0 commit comments

Comments (0)