Skip to content

Commit e3968e5

Browse files
committed
Add dialect and display name to ConnectionConfig classes
1 parent b2a9541 commit e3968e5

3 files changed

Lines changed: 97 additions & 31 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlmesh.utils.date import yesterday_ds
99
from sqlmesh.utils.errors import SQLMeshError
1010

11-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
11+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DIALECT_TO_TYPE
1212

1313

1414
PRIMITIVES = (str, int, bool, float)
@@ -27,7 +27,7 @@ class InitCliMode(Enum):
2727

2828

2929
def _gen_config(
30-
dialect: t.Optional[str],
30+
engine_type: t.Optional[str],
3131
settings: t.Optional[str],
3232
start: t.Optional[str],
3333
template: ProjectTemplate,
@@ -39,7 +39,7 @@ def _gen_config(
3939
database: db.db"""
4040
)
4141

42-
engine = "mssql" if dialect == "tsql" else dialect
42+
engine = "mssql" if engine_type == "tsql" else engine_type
4343

4444
if not settings and template != ProjectTemplate.DBT:
4545
doc_link = "https://sqlmesh.readthedocs.io/en/stable/integrations/engines{engine_link}"
@@ -51,6 +51,9 @@ def _gen_config(
5151

5252
for name, field in CONNECTION_CONFIG_TO_TYPE[engine].model_fields.items():
5353
field_name = field.alias or name
54+
if field_name in ("dialect", "display_name"):
55+
continue
56+
5457
default_value = field.get_default()
5558

5659
if isinstance(default_value, Enum):
@@ -77,7 +80,7 @@ def _gen_config(
7780

7881
connection_settings = (
7982
" # For more information on configuring the connection to your execution engine, visit:\n"
80-
" # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connections\n"
83+
" # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection\n"
8184
f" # {doc_link.format(engine_link=engine_link)}\n{connection_settings}"
8285
)
8386

@@ -93,7 +96,7 @@ def _gen_config(
9396
# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults
9497
9598
model_defaults:
96-
dialect: {dialect}
99+
dialect: {DIALECT_TO_TYPE[engine]}
97100
start: {start or yesterday_ds()} # Start date for backfill history
98101
cron: '@daily' # Run models daily at 12am UTC (can override per model)
99102

sqlmesh/cli/main.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sqlmesh.core.console import configure_console, get_console
1616
from sqlmesh.utils import Verbosity
1717
from sqlmesh.core.config import load_configs
18-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE
18+
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DISPLAY_NAME_TO_TYPE
1919
from sqlmesh.core.context import Context
2020
from sqlmesh.utils.date import TimeLike
2121
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
@@ -41,24 +41,24 @@
4141
SKIP_CONTEXT_COMMANDS = ("init", "ui")
4242

4343
# These are ordered for user display - do not reorder
44-
ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE = {
45-
"DuckDB": "duckdb",
46-
"Snowflake": "snowflake",
47-
"Databricks": "databricks",
48-
"BigQuery": "bigquery",
49-
"MotherDuck": "duckdb",
50-
"ClickHouse": "clickhouse",
51-
"Redshift": "redshift",
52-
"Spark": "spark",
53-
"Trino": "trino",
54-
"Azure SQL": "azuresql",
55-
"MSSQL": "tsql",
56-
"Postgres": "postgres",
57-
"GCP Postgres": "gcp_postgres",
58-
"MySQL": "mysql",
59-
"Athena": "athena",
60-
"RisingWave": "risingwave",
61-
}
44+
# Engine types in the exact order they are presented to the user by the
# `sqlmesh init` engine prompt. The prompt's displayed numbers are derived
# from list position, so do not reorder these entries.
ENGINE_TYPE_DISPLAY_ORDER = [
    "duckdb",
    "snowflake",
    "databricks",
    "bigquery",
    "motherduck",
    "clickhouse",
    "redshift",
    "spark",
    "trino",
    "azuresql",
    "mssql",
    "postgres",
    "gcp_postgres",
    "mysql",
    "athena",
    "risingwave",
]
6262

6363

6464
def _sqlmesh_version() -> str:
@@ -231,7 +231,7 @@ def init(
231231
)
232232

233233
next_step_text = {
234-
ProjectTemplate.DEFAULT: f"• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}\nRun command in CLI: sqlmesh plan\n(Optional) Explain a plan: sqlmesh plan --explain",
234+
ProjectTemplate.DEFAULT: f"• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}",
235235
ProjectTemplate.DBT: "",
236236
}
237237
next_step_text[ProjectTemplate.EMPTY] = next_step_text[ProjectTemplate.DEFAULT]
@@ -1376,9 +1376,9 @@ def _init_engine_prompt(console: Console) -> str:
13761376
console.print("Choose your SQL engine:\n")
13771377

13781378
display_num_to_engine = {}
1379-
for i, engine in enumerate(ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE.keys()):
1380-
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{engine}")
1381-
display_num_to_engine[i + 1] = engine
1379+
for i, engine_type in enumerate(ENGINE_TYPE_DISPLAY_ORDER):
1380+
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{DISPLAY_NAME_TO_TYPE[engine_type]}")
1381+
display_num_to_engine[i + 1] = engine_type
13821382
console.print("")
13831383

13841384
# self._print("""Need another engine? See: https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines)
@@ -1388,10 +1388,10 @@ def _init_engine_prompt(console: Console) -> str:
13881388
# """)
13891389

13901390
engine_num = _init_integer_prompt(
1391-
console, "engine", len(ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE), _init_engine_prompt
1391+
console, "engine", len(ENGINE_TYPE_DISPLAY_ORDER), _init_engine_prompt
13921392
)
13931393

1394-
return ENGINE_DISPLAY_NAME_TO_CONNECTION_TYPE[display_num_to_engine[engine_num]]
1394+
return display_num_to_engine[engine_num]
13951395

13961396

13971397
def _init_cli_mode_prompt(console: Console) -> InitCliMode:
@@ -1426,5 +1426,5 @@ def _check_engine_installed(console: Console, engine_type: t.Optional[str] = Non
14261426
except ModuleNotFoundError:
14271427
install_command = f'pip install "sqlmesh[{engine_type}]"'
14281428
raise SQLMeshError(
1429-
f"Unable to load required Python dependencies for the {engine_type.upper()} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
1429+
f"Unable to load required Python dependencies for the {DISPLAY_NAME_TO_TYPE[engine_type]} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
14301430
)

sqlmesh/core/config/connection.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ def validate(cls: t.Any, data: t.Any) -> t.Any:
8888

8989
class ConnectionConfig(abc.ABC, BaseConfig):
9090
type_: str
91+
dialect: str
92+
display_name: str
9193
concurrent_tasks: int
9294
register_comments: bool
9395
pre_ping: bool
@@ -464,6 +466,8 @@ class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
464466
"""Configuration for the MotherDuck connection."""
465467

466468
type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
469+
dialect: t.Literal["duckdb"] = "duckdb"
470+
display_name: t.Literal["MotherDuck"] = "MotherDuck"
467471

468472
@property
469473
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -488,6 +492,8 @@ class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
488492
"""Configuration for the DuckDB connection."""
489493

490494
type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")
495+
dialect: t.Literal["duckdb"] = "duckdb"
496+
display_name: t.Literal["DuckDB"] = "DuckDB"
491497

492498

493499
class SnowflakeConnectionConfig(ConnectionConfig):
@@ -538,6 +544,8 @@ class SnowflakeConnectionConfig(ConnectionConfig):
538544
session_parameters: t.Optional[dict] = None
539545

540546
type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
547+
dialect: t.Literal["snowflake"] = "snowflake"
548+
display_name: t.Literal["Snowflake"] = "Snowflake"
541549

542550
_concurrent_tasks_validator = concurrent_tasks_validator
543551
_engine_import_validator = _get_engine_import_validator("snowflake", "snowflake")
@@ -731,6 +739,8 @@ class DatabricksConnectionConfig(ConnectionConfig):
731739
pre_ping: t.Literal[False] = False
732740

733741
type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
742+
dialect: t.Literal["databricks"] = "databricks"
743+
display_name: t.Literal["Databricks"] = "Databricks"
734744

735745
_concurrent_tasks_validator = concurrent_tasks_validator
736746
_http_headers_validator = http_headers_validator
@@ -986,6 +996,8 @@ class BigQueryConnectionConfig(ConnectionConfig):
986996
pre_ping: t.Literal[False] = False
987997

988998
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
999+
dialect: t.Literal["bigquery"] = "bigquery"
1000+
display_name: t.Literal["BigQuery"] = "BigQuery"
9891001

9901002
_engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")
9911003

@@ -1127,6 +1139,9 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
11271139
scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",)
11281140
driver: str = "pg8000"
11291141
type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
1142+
dialect: t.Literal["postgres"] = "postgres"
1143+
display_name: t.Literal["GCP Postgres"] = "GCP Postgres"
1144+
11301145
concurrent_tasks: int = 4
11311146
register_comments: bool = True
11321147
pre_ping: bool = True
@@ -1261,6 +1276,8 @@ class RedshiftConnectionConfig(ConnectionConfig):
12611276
pre_ping: bool = False
12621277

12631278
type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
1279+
dialect: t.Literal["redshift"] = "redshift"
1280+
display_name: t.Literal["Redshift"] = "Redshift"
12641281

12651282
_engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")
12661283

@@ -1322,6 +1339,8 @@ class PostgresConnectionConfig(ConnectionConfig):
13221339
pre_ping: bool = True
13231340

13241341
type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
1342+
dialect: t.Literal["postgres"] = "postgres"
1343+
display_name: t.Literal["Postgres"] = "Postgres"
13251344

13261345
_engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")
13271346

@@ -1375,6 +1394,8 @@ class MySQLConnectionConfig(ConnectionConfig):
13751394
pre_ping: bool = True
13761395

13771396
type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
1397+
dialect: t.Literal["mysql"] = "mysql"
1398+
display_name: t.Literal["MySQL"] = "MySQL"
13781399

13791400
_engine_import_validator = _get_engine_import_validator("pymysql", "mysql")
13801401

@@ -1427,6 +1448,8 @@ class MSSQLConnectionConfig(ConnectionConfig):
14271448
pre_ping: bool = True
14281449

14291450
type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
1451+
dialect: t.Literal["tsql"] = "tsql"
1452+
display_name: t.Literal["MSSQL"] = "MSSQL"
14301453

14311454
_engine_import_validator = _get_engine_import_validator("pymssql", "mssql")
14321455

@@ -1464,6 +1487,8 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
14641487

14651488
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
14661489
type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore
1490+
dialect: t.Literal["tsql"] = "tsql"
1491+
display_name: t.Literal["Azure SQL"] = "Azure SQL" # type: ignore
14671492

14681493
@property
14691494
def _extra_engine_config(self) -> t.Dict[str, t.Any]:
@@ -1484,6 +1509,8 @@ class SparkConnectionConfig(ConnectionConfig):
14841509
pre_ping: t.Literal[False] = False
14851510

14861511
type_: t.Literal["spark"] = Field(alias="type", default="spark")
1512+
dialect: t.Literal["spark"] = "spark"
1513+
display_name: t.Literal["Spark"] = "Spark"
14871514

14881515
_engine_import_validator = _get_engine_import_validator("pyspark", "spark")
14891516

@@ -1602,6 +1629,8 @@ class TrinoConnectionConfig(ConnectionConfig):
16021629
pre_ping: t.Literal[False] = False
16031630

16041631
type_: t.Literal["trino"] = Field(alias="type", default="trino")
1632+
dialect: t.Literal["trino"] = "trino"
1633+
display_name: t.Literal["Trino"] = "Trino"
16051634

16061635
_engine_import_validator = _get_engine_import_validator("trino", "trino")
16071636

@@ -1762,6 +1791,8 @@ class ClickhouseConnectionConfig(ConnectionConfig):
17621791
connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None
17631792

17641793
type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
1794+
dialect: t.Literal["clickhouse"] = "clickhouse"
1795+
display_name: t.Literal["ClickHouse"] = "ClickHouse"
17651796

17661797
_engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")
17671798

@@ -1886,6 +1917,8 @@ class AthenaConnectionConfig(ConnectionConfig):
18861917
pre_ping: t.Literal[False] = False
18871918

18881919
type_: t.Literal["athena"] = Field(alias="type", default="athena")
1920+
dialect: t.Literal["athena"] = "athena"
1921+
display_name: t.Literal["Athena"] = "Athena"
18891922

18901923
_engine_import_validator = _get_engine_import_validator("pyathena", "athena")
18911924

@@ -1954,6 +1987,8 @@ class RisingwaveConnectionConfig(ConnectionConfig):
19541987
pre_ping: bool = True
19551988

19561989
type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
1990+
dialect: t.Literal["risingwave"] = "risingwave"
1991+
display_name: t.Literal["RisingWave"] = "RisingWave"
19571992

19581993
_engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")
19591994

@@ -1998,6 +2033,34 @@ def init(cursor: t.Any) -> None:
19982033
)
19992034
}
20002035

2036+
# Registries derived from the concrete ConnectionConfig subclasses. Each is
# keyed by the engine "type" discriminator — the default of the subclass's
# `type_` field (e.g. "duckdb", "mssql").
#
# Enumerate the subclasses once instead of repeating the reflection scan for
# each registry below (the original built this same `subclasses(...)` result
# three times).
_CONNECTION_CONFIG_TYPES = tuple(
    subclasses(
        __name__,
        ConnectionConfig,
        exclude=(ConnectionConfig, BaseDuckDBConnectionConfig),
    )
)

# Engine type -> ConnectionConfig subclass.
CONNECTION_CONFIG_TO_TYPE = {
    tpe.all_field_infos()["type_"].default: tpe for tpe in _CONNECTION_CONFIG_TYPES
}

# NOTE(review): despite the name, this maps engine type -> dialect default
# (e.g. "mssql" -> "tsql"); callers index it by engine type. Renaming would
# break importers, so the name is kept.
DIALECT_TO_TYPE = {
    tpe.all_field_infos()["type_"].default: tpe.all_field_infos()["dialect"].default
    for tpe in _CONNECTION_CONFIG_TYPES
}

# NOTE(review): despite the name, this maps engine type -> display name
# (e.g. "tsql"/"mssql" -> "MSSQL"); callers index it by engine type. Renaming
# would break importers, so the name is kept.
DISPLAY_NAME_TO_TYPE = {
    tpe.all_field_infos()["type_"].default: tpe.all_field_infos()["display_name"].default
    for tpe in _CONNECTION_CONFIG_TYPES
}
2063+
20012064

20022065
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
20032066
if "type" not in v:

0 commit comments

Comments
 (0)