Skip to content

Commit abab037

Browse files
committed
Move init functions to example_project, add pyspark install
1 parent 591369f commit abab037

2 files changed

Lines changed: 128 additions & 119 deletions

File tree

sqlmesh/cli/example_project.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22
from enum import Enum
33
from pathlib import Path
44
from dataclasses import dataclass
5-
5+
from rich.prompt import Prompt
6+
from rich.console import Console
67
from sqlmesh.integrations.dlt import generate_dlt_models_and_settings
78
from sqlmesh.utils.date import yesterday_ds
89
from sqlmesh.utils.errors import SQLMeshError
910

10-
from sqlmesh.core.config.connection import CONNECTION_CONFIG_TO_TYPE, DIALECT_TO_TYPE
11+
from sqlmesh.core.config.connection import (
12+
CONNECTION_CONFIG_TO_TYPE,
13+
DIALECT_TO_TYPE,
14+
INIT_DISPLAY_INFO_TO_TYPE,
15+
)
1116

1217

1318
PRIMITIVES = (str, int, bool, float)
@@ -390,3 +395,112 @@ def _create_tests(tests_path: Path, example_objects: ExampleObjects) -> None:
390395
def _write_file(path: Path, payload: str) -> None:
391396
with open(path, "w", encoding="utf-8") as fd:
392397
fd.write(payload)
398+
399+
400+
def interactive_init(
401+
path: Path,
402+
console: Console,
403+
project_template: t.Optional[ProjectTemplate] = None,
404+
) -> t.Tuple[ProjectTemplate, t.Optional[str], t.Optional[InitCliMode]]:
405+
console.print("──────────────────────────────")
406+
console.print("Welcome to SQLMesh!")
407+
408+
project_template = _init_template_prompt(console) if not project_template else project_template
409+
410+
if project_template == ProjectTemplate.DBT:
411+
return (project_template, None, None)
412+
413+
engine_type = _init_engine_prompt(console)
414+
cli_mode = _init_cli_mode_prompt(console)
415+
416+
return (project_template, engine_type, cli_mode)
417+
418+
419+
def _init_integer_prompt(
420+
console: Console, err_msg_entity: str, num_options: int, retry_func: t.Callable[[t.Any], t.Any]
421+
) -> int:
422+
err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c\n"
423+
while True:
424+
option_str = Prompt.ask("Enter a number", console=console)
425+
426+
value_error = False
427+
try:
428+
option_num = int(option_str)
429+
except ValueError:
430+
value_error = True
431+
432+
if value_error or option_num < 1 or option_num > num_options:
433+
console.print(
434+
err_msg.format(
435+
option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options
436+
),
437+
style="red",
438+
)
439+
continue
440+
console.print("")
441+
return option_num
442+
443+
444+
def _init_display_choices(values_dict: t.Dict[str, str], console: Console) -> t.Dict[int, str]:
445+
display_num_to_value = {}
446+
for i, value_str in enumerate(values_dict.keys()):
447+
console.print(f" \[{i + 1}] {value_str} {values_dict[value_str]}")
448+
display_num_to_value[i + 1] = value_str
449+
console.print("")
450+
return display_num_to_value
451+
452+
453+
def _init_template_prompt(console: Console) -> ProjectTemplate:
454+
console.print("──────────────────────────────\n")
455+
console.print("What type of project do you want to set up?\n")
456+
457+
# These are ordered for user display - do not reorder
458+
template_descriptions = {
459+
ProjectTemplate.DEFAULT.name: "- Create SQLMesh example project models and files",
460+
ProjectTemplate.DBT.value: " - You have an existing dbt project and want to run it with SQLMesh",
461+
ProjectTemplate.EMPTY.name: " - Create a SQLMesh configuration file and project directories only",
462+
}
463+
464+
display_num_to_template = _init_display_choices(template_descriptions, console)
465+
466+
template_num = _init_integer_prompt(
467+
console, "project type", len(template_descriptions), _init_template_prompt
468+
)
469+
470+
return ProjectTemplate(display_num_to_template[template_num].lower())
471+
472+
473+
def _init_engine_prompt(console: Console) -> str:
474+
console.print("──────────────────────────────\n")
475+
console.print("Choose your SQL engine:\n")
476+
477+
# INIT_DISPLAY_INFO_TO_TYPE is a dict of {engine_type: (display_order, display_name)}
478+
DISPLAY_NAME_TO_TYPE = {v[1]: k for k, v in INIT_DISPLAY_INFO_TO_TYPE.items()}
479+
ordered_engine_display_names = {
480+
info[1]: "" for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0])
481+
}
482+
display_num_to_display_name = _init_display_choices(ordered_engine_display_names, console)
483+
484+
engine_num = _init_integer_prompt(
485+
console, "engine", len(ordered_engine_display_names), _init_engine_prompt
486+
)
487+
488+
return DISPLAY_NAME_TO_TYPE[display_num_to_display_name[engine_num]]
489+
490+
491+
def _init_cli_mode_prompt(console: Console) -> InitCliMode:
492+
console.print("──────────────────────────────\n")
493+
console.print("Choose your SQLMesh CLI experience:\n")
494+
495+
cli_mode_descriptions = {
496+
InitCliMode.DEFAULT.name: "- See and control every detail",
497+
InitCliMode.SIMPLE.name: " - Automatically run changes and show summary output",
498+
}
499+
500+
display_num_to_cli_mode = _init_display_choices(cli_mode_descriptions, console)
501+
502+
cli_mode_num = _init_integer_prompt(
503+
console, "config", len(cli_mode_descriptions), _init_cli_mode_prompt
504+
)
505+
506+
return InitCliMode(display_num_to_cli_mode[cli_mode_num].lower())

sqlmesh/cli/main.py

Lines changed: 12 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,26 @@
66
import typing as t
77

88
import click
9-
from rich.prompt import Prompt
109
from sqlmesh import configure_logging, remove_excess_logs
1110
from sqlmesh.cli import error_handler
1211
from sqlmesh.cli import options as opt
13-
from sqlmesh.cli.example_project import ProjectTemplate, init_example_project, InitCliMode
12+
from sqlmesh.cli.example_project import (
13+
ProjectTemplate,
14+
init_example_project,
15+
InitCliMode,
16+
interactive_init,
17+
)
1418
from sqlmesh.core.analytics import cli_analytics
1519
from sqlmesh.core.console import configure_console, get_console
1620
from sqlmesh.utils import Verbosity
1721
from sqlmesh.core.config import load_configs
18-
from sqlmesh.core.config.connection import INIT_DISPLAY_INFO_TO_TYPE
1922
from sqlmesh.core.context import Context
2023
from sqlmesh.utils.date import TimeLike
2124
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
2225
from pathlib import Path
2326

2427
logger = logging.getLogger(__name__)
2528

26-
if t.TYPE_CHECKING:
27-
from rich.console import Console
2829

2930
SKIP_LOAD_COMMANDS = (
3031
"clean",
@@ -194,7 +195,7 @@ def init(
194195

195196
console = srich.console
196197

197-
project_template, engine_type, cli_mode = _interactive_init(ctx.obj, console, project_template)
198+
project_template, engine_type, cli_mode = interactive_init(ctx.obj, console, project_template)
198199

199200
config_path = init_example_project(
200201
path=ctx.obj,
@@ -206,8 +207,11 @@ def init(
206207
)
207208

208209
engine_install_text = ""
209-
if engine_type and engine_type not in ("duckdb", "motherduck", "spark"):
210-
engine_install_text = f'• Run command in CLI to install your SQL engine\'s Python dependencies: pip install "sqlmesh\\[{engine_type.replace("_", "")}]"\n'
210+
if engine_type and engine_type not in ("duckdb", "motherduck"):
211+
install_text = (
212+
"pyspark" if engine_type == "spark" else f"sqlmesh\\[{engine_type.replace('_', '')}]"
213+
)
214+
engine_install_text = f'• Run command in CLI to install your SQL engine\'s Python dependencies: pip install "{install_text}"\n'
211215
# interactive init does not support DLT template
212216
next_step_text = {
213217
ProjectTemplate.DEFAULT: f"{engine_install_text}• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}",
@@ -1282,115 +1286,6 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool
12821286
obj.import_state(input_file=input_file, clear=replace, confirm=confirm)
12831287

12841288

1285-
def _interactive_init(
1286-
path: Path,
1287-
console: Console,
1288-
project_template: t.Optional[ProjectTemplate] = None,
1289-
) -> t.Tuple[ProjectTemplate, t.Optional[str], t.Optional[InitCliMode]]:
1290-
console.print("──────────────────────────────")
1291-
console.print("Welcome to SQLMesh!")
1292-
1293-
project_template = _init_template_prompt(console) if not project_template else project_template
1294-
1295-
if project_template == ProjectTemplate.DBT:
1296-
return (project_template, None, None)
1297-
1298-
engine_type = _init_engine_prompt(console)
1299-
cli_mode = _init_cli_mode_prompt(console)
1300-
1301-
return (project_template, engine_type, cli_mode)
1302-
1303-
1304-
def _init_integer_prompt(
1305-
console: Console, err_msg_entity: str, num_options: int, retry_func: t.Callable[[t.Any], t.Any]
1306-
) -> int:
1307-
err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c\n"
1308-
while True:
1309-
option_str = Prompt.ask("Enter a number", console=console)
1310-
1311-
value_error = False
1312-
try:
1313-
option_num = int(option_str)
1314-
except ValueError:
1315-
value_error = True
1316-
1317-
if value_error or option_num < 1 or option_num > num_options:
1318-
console.print(
1319-
err_msg.format(
1320-
option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options
1321-
),
1322-
style="red",
1323-
)
1324-
continue
1325-
console.print("")
1326-
return option_num
1327-
1328-
1329-
def _init_display_choices(values_dict: t.Dict[str, str], console: Console) -> t.Dict[int, str]:
1330-
display_num_to_value = {}
1331-
for i, value_str in enumerate(values_dict.keys()):
1332-
console.print(f" \[{i + 1}] {value_str} {values_dict[value_str]}")
1333-
display_num_to_value[i + 1] = value_str
1334-
console.print("")
1335-
return display_num_to_value
1336-
1337-
1338-
def _init_template_prompt(console: Console) -> ProjectTemplate:
1339-
console.print("──────────────────────────────\n")
1340-
console.print("What type of project do you want to set up?\n")
1341-
1342-
# These are ordered for user display - do not reorder
1343-
template_descriptions = {
1344-
ProjectTemplate.DEFAULT.name: "- Create SQLMesh example project models and files",
1345-
ProjectTemplate.DBT.value: " - You have an existing dbt project and want to run it with SQLMesh",
1346-
ProjectTemplate.EMPTY.name: " - Create a SQLMesh configuration file and project directories only",
1347-
}
1348-
1349-
display_num_to_template = _init_display_choices(template_descriptions, console)
1350-
1351-
template_num = _init_integer_prompt(
1352-
console, "project type", len(template_descriptions), _init_template_prompt
1353-
)
1354-
1355-
return ProjectTemplate(display_num_to_template[template_num].lower())
1356-
1357-
1358-
def _init_engine_prompt(console: Console) -> str:
1359-
console.print("──────────────────────────────\n")
1360-
console.print("Choose your SQL engine:\n")
1361-
1362-
# INIT_DISPLAY_INFO_TO_TYPE is a dict of {engine_type: (display_order, display_name)}
1363-
DISPLAY_NAME_TO_TYPE = {v[1]: k for k, v in INIT_DISPLAY_INFO_TO_TYPE.items()}
1364-
ordered_engine_display_names = {
1365-
info[1]: "" for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0])
1366-
}
1367-
display_num_to_display_name = _init_display_choices(ordered_engine_display_names, console)
1368-
1369-
engine_num = _init_integer_prompt(
1370-
console, "engine", len(ordered_engine_display_names), _init_engine_prompt
1371-
)
1372-
1373-
return DISPLAY_NAME_TO_TYPE[display_num_to_display_name[engine_num]]
1374-
1375-
1376-
def _init_cli_mode_prompt(console: Console) -> InitCliMode:
1377-
console.print("──────────────────────────────\n")
1378-
console.print("Choose your SQLMesh CLI experience:\n")
1379-
1380-
cli_mode_descriptions = {
1381-
InitCliMode.DEFAULT.name: "- See and control every detail",
1382-
InitCliMode.SIMPLE.name: " - Automatically run changes and show summary output",
1383-
}
1384-
1385-
display_num_to_cli_mode = _init_display_choices(cli_mode_descriptions, console)
1386-
1387-
cli_mode_num = _init_integer_prompt(
1388-
console, "config", len(cli_mode_descriptions), _init_cli_mode_prompt
1389-
)
1390-
1391-
return InitCliMode(display_num_to_cli_mode[cli_mode_num].lower())
1392-
1393-
13941289
def _check_engine_installed(console: Console, engine_type: t.Optional[str] = None) -> None:
13951290
if not engine_type:
13961291
return

0 commit comments

Comments
 (0)