Skip to content

Commit 8510ad1

Browse files
committed
Move engine order to ConnectionConfig classes, remove engine check
1 parent e54b43e commit 8510ad1

7 files changed

Lines changed: 127 additions & 134 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
class ProjectTemplate(Enum):
1717
DEFAULT = "default"
18-
EMPTY = "empty"
1918
DBT = "dbt"
19+
EMPTY = "empty"
2020
DLT = "dlt"
2121

2222

@@ -48,8 +48,6 @@ def _gen_config(
4848

4949
for name, field in CONNECTION_CONFIG_TO_TYPE[engine_type].model_fields.items():
5050
field_name = field.alias or name
51-
if field_name in ("dialect", "display_name"):
52-
continue
5351

5452
default_value = field.get_default()
5553

@@ -279,9 +277,7 @@ def init_example_project(
279277
dlt_path: t.Optional[str] = None,
280278
schema_name: str = "sqlmesh_example",
281279
cli_mode: InitCliMode = InitCliMode.DEFAULT,
282-
) -> t.Union[str, Path]:
283-
from sqlmesh.cli.main import ENGINE_TYPE_DISPLAY_ORDER
284-
280+
) -> Path:
285281
root_path = Path(path)
286282
config_extension = "py" if template == ProjectTemplate.DBT else "yaml"
287283
config_path = root_path / f"config.{config_extension}"
@@ -296,32 +292,35 @@ def init_example_project(
296292
f"Found an existing config file '{config_path}'.\n\nPlease change to another directory or remove the existing file."
297293
)
298294

299-
engine_types = "'" + "', '".join(ENGINE_TYPE_DISPLAY_ORDER) + "'"
295+
if template == ProjectTemplate.DBT and not Path(root_path, "dbt_project.yml").exists():
296+
raise SQLMeshError(
297+
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\n Please add it or change directories before running `sqlmesh init` to set up your project."
298+
)
299+
300+
engine_types = "', '".join(CONNECTION_CONFIG_TO_TYPE)
300301
if engine_type is None and template != ProjectTemplate.DBT:
301302
raise SQLMeshError(
302303
f"Missing `engine` argument to `sqlmesh init` - please specify a SQL engine for your project. Options: '{engine_types}'."
303304
)
304305

305-
if engine_type and engine_type not in ENGINE_TYPE_DISPLAY_ORDER:
306+
if engine_type and engine_type not in CONNECTION_CONFIG_TO_TYPE:
306307
raise SQLMeshError(
307308
f"Invalid engine '{engine_type}'. Please specify one of '{engine_types}'."
308309
)
309310

310311
models: t.Set[t.Tuple[str, str]] = set()
311312
settings = None
312313
start = None
313-
if engine_type:
314-
dialect = DIALECT_TO_TYPE[engine_type]
315-
316-
if template == ProjectTemplate.DLT:
317-
if pipeline and dialect:
318-
models, settings, start = generate_dlt_models_and_settings(
319-
pipeline_name=pipeline, dialect=dialect, dlt_path=dlt_path
320-
)
321-
else:
322-
raise SQLMeshError(
323-
"Please provide a DLT pipeline with the `--dlt-pipeline` flag to generate a SQLMesh project from DLT."
324-
)
314+
if engine_type and template == ProjectTemplate.DLT:
315+
dialect = DIALECT_TO_TYPE.get(engine_type)
316+
if pipeline and dialect:
317+
models, settings, start = generate_dlt_models_and_settings(
318+
pipeline_name=pipeline, dialect=dialect, dlt_path=dlt_path
319+
)
320+
else:
321+
raise SQLMeshError(
322+
"Please provide a DLT pipeline with the `--dlt-pipeline` flag to generate a SQLMesh project from DLT."
323+
)
325324

326325
_create_config(config_path, engine_type, settings, start, template, cli_mode)
327326
if template == ProjectTemplate.DBT:

sqlmesh/cli/main.py

Lines changed: 34 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sqlmesh.core.console import configure_console, get_console
1616
from sqlmesh.utils import Verbosity
1717
from sqlmesh.core.config import load_configs
18-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DISPLAY_NAME_TO_TYPE
18+
from sqlmesh.core.config.connection import INIT_DISPLAY_INFO_TO_TYPE
1919
from sqlmesh.core.context import Context
2020
from sqlmesh.utils.date import TimeLike
2121
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
@@ -41,26 +41,6 @@
4141
)
4242
SKIP_CONTEXT_COMMANDS = ("init", "ui")
4343

44-
# These are ordered for user display - do not reorder
45-
ENGINE_TYPE_DISPLAY_ORDER = [
46-
"duckdb",
47-
"snowflake",
48-
"databricks",
49-
"bigquery",
50-
"motherduck",
51-
"clickhouse",
52-
"redshift",
53-
"spark",
54-
"trino",
55-
"azuresql",
56-
"mssql",
57-
"postgres",
58-
"gcp_postgres",
59-
"mysql",
60-
"athena",
61-
"risingwave",
62-
]
63-
6444

6545
def _sqlmesh_version() -> str:
6646
try:
@@ -195,23 +175,16 @@ def init(
195175
try:
196176
project_template = ProjectTemplate(template.lower())
197177
except ValueError:
198-
template_strings = (
199-
"'" + "', '".join([template.value for template in ProjectTemplate]) + "'"
200-
)
178+
template_strings = "', '".join([template.value for template in ProjectTemplate])
201179
raise click.ClickException(
202-
f"Invalid project template '{template}'. Please specify one of {template_strings}."
180+
f"Invalid project template '{template}'. Please specify one of '{template_strings}'."
203181
)
204182

205-
if project_template == ProjectTemplate.DBT and not Path(ctx.obj, "dbt_project.yml").exists():
206-
raise click.ClickException(
207-
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\n Please add it or change directories before running `sqlmesh init` to set up your dbt project with SQLMesh."
208-
)
209-
210183
if engine or project_template == ProjectTemplate.DBT:
211184
init_example_project(
212185
path=ctx.obj,
213-
engine_type=engine,
214186
template=project_template or ProjectTemplate.DEFAULT,
187+
engine_type=engine,
215188
pipeline=dlt_pipeline,
216189
dlt_path=dlt_path,
217190
)
@@ -222,8 +195,6 @@ def init(
222195
console = srich.console
223196

224197
project_template, engine_type, cli_mode = _interactive_init(ctx.obj, console, project_template)
225-
if project_template != ProjectTemplate.DBT:
226-
_check_engine_installed(console, engine_type)
227198

228199
config_path = init_example_project(
229200
path=ctx.obj,
@@ -1318,10 +1289,6 @@ def _interactive_init(
13181289
project_template = _init_template_prompt(console) if not project_template else project_template
13191290

13201291
if project_template == ProjectTemplate.DBT:
1321-
if not Path(path, "dbt_project.yml").exists():
1322-
raise SQLMeshError(
1323-
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\n Please add it or change directories before running `sqlmesh init` to set up your dbt project with SQLMesh."
1324-
)
13251292
return (project_template, None, None)
13261293

13271294
engine_type = _init_engine_prompt(console)
@@ -1333,24 +1300,26 @@ def _interactive_init(
13331300
def _init_integer_prompt(
13341301
console: Console, err_msg_entity: str, num_options: int, retry_func: t.Callable[[t.Any], t.Any]
13351302
) -> int:
1336-
err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c"
1337-
option_str = Prompt.ask("Enter a number", console=console)
1338-
try:
1339-
option_num = int(option_str)
1340-
if option_num < 1 or option_num > num_options:
1341-
raise ValueError
1342-
except ValueError:
1343-
console.print(
1344-
err_msg.format(
1345-
option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options
1346-
),
1347-
style="red",
1348-
)
1349-
return retry_func(console)
1350-
finally:
1351-
console.print("")
1303+
err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c\n"
1304+
while True:
1305+
option_str = Prompt.ask("Enter a number", console=console)
13521306

1353-
return option_num
1307+
value_error = False
1308+
try:
1309+
option_num = int(option_str)
1310+
except ValueError:
1311+
value_error = True
1312+
1313+
if value_error or option_num < 1 or option_num > num_options:
1314+
console.print(
1315+
err_msg.format(
1316+
option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options
1317+
),
1318+
style="red",
1319+
)
1320+
continue
1321+
console.print("")
1322+
return option_num
13541323

13551324

13561325
def _init_template_prompt(console: Console) -> ProjectTemplate:
@@ -1381,17 +1350,22 @@ def _init_engine_prompt(console: Console) -> str:
13811350
console.print("──────────────────────────────\n")
13821351
console.print("Choose your SQL engine:\n")
13831352

1384-
display_num_to_engine = {}
1385-
for i, engine_type in enumerate(ENGINE_TYPE_DISPLAY_ORDER):
1386-
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{DISPLAY_NAME_TO_TYPE[engine_type]}")
1387-
display_num_to_engine[i + 1] = engine_type
1353+
# INIT_DISPLAY_INFO_TO_TYPE is a dict of {engine_type: (display_order, display_name)}
1354+
ordered_engine_display_names = [
1355+
info[1] for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0])
1356+
]
1357+
display_num_to_display_name = {}
1358+
for i, display_name in enumerate(ordered_engine_display_names):
1359+
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{display_name}")
1360+
display_num_to_display_name[i + 1] = display_name
13881361
console.print("")
13891362

13901363
engine_num = _init_integer_prompt(
1391-
console, "engine", len(ENGINE_TYPE_DISPLAY_ORDER), _init_engine_prompt
1364+
console, "engine", len(ordered_engine_display_names), _init_engine_prompt
13921365
)
13931366

1394-
return display_num_to_engine[engine_num]
1367+
DISPLAY_NAME_TO_TYPE = {v[1]: k for k, v in INIT_DISPLAY_INFO_TO_TYPE.items()}
1368+
return DISPLAY_NAME_TO_TYPE[display_num_to_display_name[engine_num]]
13951369

13961370

13971371
def _init_cli_mode_prompt(console: Console) -> InitCliMode:

0 commit comments

Comments
 (0)