Skip to content

Commit 724d9ac

Browse files
committed
Move engine order to ConnectionConfig classes, remove engine check
1 parent cc09847 commit 724d9ac

7 files changed

Lines changed: 127 additions & 148 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
class ProjectTemplate(Enum):
1717
DEFAULT = "default"
18-
EMPTY = "empty"
1918
DBT = "dbt"
19+
EMPTY = "empty"
2020
DLT = "dlt"
2121

2222

@@ -48,8 +48,6 @@ def _gen_config(
4848

4949
for name, field in CONNECTION_CONFIG_TO_TYPE[engine_type].model_fields.items():
5050
field_name = field.alias or name
51-
if field_name in ("dialect", "display_name"):
52-
continue
5351

5452
default_value = field.get_default()
5553

@@ -279,9 +277,7 @@ def init_example_project(
279277
dlt_path: t.Optional[str] = None,
280278
schema_name: str = "sqlmesh_example",
281279
cli_mode: InitCliMode = InitCliMode.DEFAULT,
282-
) -> t.Union[str, Path]:
283-
from sqlmesh.cli.main import ENGINE_TYPE_DISPLAY_ORDER
284-
280+
) -> Path:
285281
root_path = Path(path)
286282
config_extension = "py" if template == ProjectTemplate.DBT else "yaml"
287283
config_path = root_path / f"config.{config_extension}"
@@ -296,32 +292,35 @@ def init_example_project(
296292
f"Found an existing config file '{config_path}'.\n\nPlease change to another directory or remove the existing file."
297293
)
298294

299-
engine_types = "'" + "', '".join(ENGINE_TYPE_DISPLAY_ORDER) + "'"
295+
if template == ProjectTemplate.DBT and not Path(root_path, "dbt_project.yml").exists():
296+
raise SQLMeshError(
297+
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\n Please add it or change directories before running `sqlmesh init` to set up your project."
298+
)
299+
300+
engine_types = "', '".join(CONNECTION_CONFIG_TO_TYPE)
300301
if engine_type is None and template != ProjectTemplate.DBT:
301302
raise SQLMeshError(
302303
f"Missing `engine` argument to `sqlmesh init` - please specify a SQL engine for your project. Options: '{engine_types}'."
303304
)
304305

305-
if engine_type and engine_type not in ENGINE_TYPE_DISPLAY_ORDER:
306+
if engine_type and engine_type not in CONNECTION_CONFIG_TO_TYPE:
306307
raise SQLMeshError(
307308
f"Invalid engine '{engine_type}'. Please specify one of '{engine_types}'."
308309
)
309310

310311
models: t.Set[t.Tuple[str, str]] = set()
311312
settings = None
312313
start = None
313-
if engine_type:
314-
dialect = DIALECT_TO_TYPE[engine_type]
315-
316-
if template == ProjectTemplate.DLT:
317-
if pipeline and dialect:
318-
models, settings, start = generate_dlt_models_and_settings(
319-
pipeline_name=pipeline, dialect=dialect, dlt_path=dlt_path
320-
)
321-
else:
322-
raise SQLMeshError(
323-
"Please provide a DLT pipeline with the `--dlt-pipeline` flag to generate a SQLMesh project from DLT."
324-
)
314+
if engine_type and template == ProjectTemplate.DLT:
315+
dialect = DIALECT_TO_TYPE.get(engine_type)
316+
if pipeline and dialect:
317+
models, settings, start = generate_dlt_models_and_settings(
318+
pipeline_name=pipeline, dialect=dialect, dlt_path=dlt_path
319+
)
320+
else:
321+
raise SQLMeshError(
322+
"Please provide a DLT pipeline with the `--dlt-pipeline` flag to generate a SQLMesh project from DLT."
323+
)
325324

326325
_create_config(config_path, engine_type, settings, start, template, cli_mode)
327326
if template == ProjectTemplate.DBT:

sqlmesh/cli/main.py

Lines changed: 34 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sqlmesh.core.console import configure_console, get_console
1616
from sqlmesh.utils import Verbosity
1717
from sqlmesh.core.config import load_configs
18-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DISPLAY_NAME_TO_TYPE
18+
from sqlmesh.core.config.connection import INIT_DISPLAY_INFO_TO_TYPE
1919
from sqlmesh.core.context import Context
2020
from sqlmesh.utils.date import TimeLike
2121
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
@@ -40,26 +40,6 @@
4040
)
4141
SKIP_CONTEXT_COMMANDS = ("init", "ui")
4242

43-
# These are ordered for user display - do not reorder
44-
ENGINE_TYPE_DISPLAY_ORDER = [
45-
"duckdb",
46-
"snowflake",
47-
"databricks",
48-
"bigquery",
49-
"motherduck",
50-
"clickhouse",
51-
"redshift",
52-
"spark",
53-
"trino",
54-
"azuresql",
55-
"mssql",
56-
"postgres",
57-
"gcp_postgres",
58-
"mysql",
59-
"athena",
60-
"risingwave",
61-
]
62-
6343

6444
def _sqlmesh_version() -> str:
6545
try:
@@ -194,23 +174,16 @@ def init(
194174
try:
195175
project_template = ProjectTemplate(template.lower())
196176
except ValueError:
197-
template_strings = (
198-
"'" + "', '".join([template.value for template in ProjectTemplate]) + "'"
199-
)
177+
template_strings = "', '".join([template.value for template in ProjectTemplate])
200178
raise click.ClickException(
201-
f"Invalid project template '{template}'. Please specify one of {template_strings}."
179+
f"Invalid project template '{template}'. Please specify one of '{template_strings}'."
202180
)
203181

204-
if project_template == ProjectTemplate.DBT and not Path(ctx.obj, "dbt_project.yml").exists():
205-
raise click.ClickException(
206-
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\n Please add it or change directories before running `sqlmesh init` to set up your dbt project with SQLMesh."
207-
)
208-
209182
if engine or project_template == ProjectTemplate.DBT:
210183
init_example_project(
211184
path=ctx.obj,
212-
engine_type=engine,
213185
template=project_template or ProjectTemplate.DEFAULT,
186+
engine_type=engine,
214187
pipeline=dlt_pipeline,
215188
dlt_path=dlt_path,
216189
)
@@ -221,8 +194,6 @@ def init(
221194
console = srich.console
222195

223196
project_template, engine_type, cli_mode = _interactive_init(ctx.obj, console, project_template)
224-
if project_template != ProjectTemplate.DBT:
225-
_check_engine_installed(console, engine_type)
226197

227198
config_path = init_example_project(
228199
path=ctx.obj,
@@ -1317,10 +1288,6 @@ def _interactive_init(
13171288
project_template = _init_template_prompt(console) if not project_template else project_template
13181289

13191290
if project_template == ProjectTemplate.DBT:
1320-
if not Path(path, "dbt_project.yml").exists():
1321-
raise SQLMeshError(
1322-
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\n Please add it or change directories before running `sqlmesh init` to set up your dbt project with SQLMesh."
1323-
)
13241291
return (project_template, None, None)
13251292

13261293
engine_type = _init_engine_prompt(console)
@@ -1332,24 +1299,26 @@ def _interactive_init(
13321299
def _init_integer_prompt(
13331300
console: Console, err_msg_entity: str, num_options: int, retry_func: t.Callable[[t.Any], t.Any]
13341301
) -> int:
1335-
err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c"
1336-
option_str = Prompt.ask("Enter a number", console=console)
1337-
try:
1338-
option_num = int(option_str)
1339-
if option_num < 1 or option_num > num_options:
1340-
raise ValueError
1341-
except ValueError:
1342-
console.print(
1343-
err_msg.format(
1344-
option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options
1345-
),
1346-
style="red",
1347-
)
1348-
return retry_func(console)
1349-
finally:
1350-
console.print("")
1302+
err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c\n"
1303+
while True:
1304+
option_str = Prompt.ask("Enter a number", console=console)
13511305

1352-
return option_num
1306+
value_error = False
1307+
try:
1308+
option_num = int(option_str)
1309+
except ValueError:
1310+
value_error = True
1311+
1312+
if value_error or option_num < 1 or option_num > num_options:
1313+
console.print(
1314+
err_msg.format(
1315+
option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options
1316+
),
1317+
style="red",
1318+
)
1319+
continue
1320+
console.print("")
1321+
return option_num
13531322

13541323

13551324
def _init_template_prompt(console: Console) -> ProjectTemplate:
@@ -1380,17 +1349,22 @@ def _init_engine_prompt(console: Console) -> str:
13801349
console.print("──────────────────────────────\n")
13811350
console.print("Choose your SQL engine:\n")
13821351

1383-
display_num_to_engine = {}
1384-
for i, engine_type in enumerate(ENGINE_TYPE_DISPLAY_ORDER):
1385-
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{DISPLAY_NAME_TO_TYPE[engine_type]}")
1386-
display_num_to_engine[i + 1] = engine_type
1352+
# INIT_DISPLAY_INFO_TO_TYPE is a dict of {engine_type: (display_order, display_name)}
1353+
ordered_engine_display_names = [
1354+
info[1] for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0])
1355+
]
1356+
display_num_to_display_name = {}
1357+
for i, display_name in enumerate(ordered_engine_display_names):
1358+
console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{display_name}")
1359+
display_num_to_display_name[i + 1] = display_name
13871360
console.print("")
13881361

13891362
engine_num = _init_integer_prompt(
1390-
console, "engine", len(ENGINE_TYPE_DISPLAY_ORDER), _init_engine_prompt
1363+
console, "engine", len(ordered_engine_display_names), _init_engine_prompt
13911364
)
13921365

1393-
return display_num_to_engine[engine_num]
1366+
DISPLAY_NAME_TO_TYPE = {v[1]: k for k, v in INIT_DISPLAY_INFO_TO_TYPE.items()}
1367+
return DISPLAY_NAME_TO_TYPE[display_num_to_display_name[engine_num]]
13941368

13951369

13961370
def _init_cli_mode_prompt(console: Console) -> InitCliMode:
@@ -1413,17 +1387,3 @@ def _init_cli_mode_prompt(console: Console) -> InitCliMode:
14131387
)
14141388

14151389
return InitCliMode(display_num_to_cli_mode[cli_mode_num].lower())
1416-
1417-
1418-
def _check_engine_installed(console: Console, engine_type: t.Optional[str] = None) -> None:
1419-
if not engine_type:
1420-
return
1421-
connection_config = CONNECTION_CONFIG_TO_TYPE[engine_type]
1422-
1423-
try:
1424-
connection_config._connection_factory.fget(None)
1425-
except ModuleNotFoundError:
1426-
install_command = f'pip install "sqlmesh[{engine_type}]"'
1427-
raise SQLMeshError(
1428-
f"Unable to load required Python dependencies for the {DISPLAY_NAME_TO_TYPE[engine_type]} engine.\n\nPlease run `{install_command}` to install them before running `sqlmesh init` again."
1429-
)

0 commit comments

Comments
 (0)