Skip to content

Commit 8411554

Browse files
committed
Make init project object creation generic
1 parent 6b45fc7 commit 8411554

2 files changed

Lines changed: 58 additions & 70 deletions

File tree

sqlmesh/cli/project_init.py

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -148,31 +148,29 @@ def _gen_config(
148148

149149
@dataclass
150150
class ExampleObjects:
151-
schema_name: str
152-
full_model_name: str
153-
full_model_def: str
154-
incremental_model_name: str
155-
incremental_model_def: str
156-
seed_model_name: str
157-
seed_model_def: str
158-
seed_data: str
159-
audit_def: str
160-
test_def: str
161-
162-
def models(self) -> t.Set[t.Tuple[str, str]]:
163-
return {
164-
(self.full_model_name, self.full_model_def),
165-
(self.incremental_model_name, self.incremental_model_def),
166-
(self.seed_model_name, self.seed_model_def),
167-
}
151+
sql_models: t.Dict[str, str]
152+
python_models: t.Dict[str, str]
153+
seeds: t.Dict[str, str]
154+
audits: t.Dict[str, str]
155+
tests: t.Dict[str, str]
156+
sql_macros: t.Dict[str, str]
157+
python_macros: t.Dict[str, str]
168158

169159

170160
def _gen_example_objects(schema_name: str) -> ExampleObjects:
161+
sql_models: t.Dict[str, str] = {}
162+
python_models: t.Dict[str, str] = {}
163+
seeds: t.Dict[str, str] = {}
164+
audits: t.Dict[str, str] = {}
165+
tests: t.Dict[str, str] = {}
166+
sql_macros: t.Dict[str, str] = {}
167+
python_macros: t.Dict[str, str] = {"__init__": ""}
168+
171169
full_model_name = f"{schema_name}.full_model"
172170
incremental_model_name = f"{schema_name}.incremental_model"
173171
seed_model_name = f"{schema_name}.seed_model"
174172

175-
full_model_def = f"""MODEL (
173+
sql_models[full_model_name] = f"""MODEL (
176174
name {full_model_name},
177175
kind FULL,
178176
cron '@daily',
@@ -188,7 +186,7 @@ def _gen_example_objects(schema_name: str) -> ExampleObjects:
188186
GROUP BY item_id
189187
"""
190188

191-
incremental_model_def = f"""MODEL (
189+
sql_models[incremental_model_name] = f"""MODEL (
192190
name {incremental_model_name},
193191
kind INCREMENTAL_BY_TIME_RANGE (
194192
time_column event_date
@@ -208,7 +206,7 @@ def _gen_example_objects(schema_name: str) -> ExampleObjects:
208206
event_date BETWEEN @start_date AND @end_date
209207
"""
210208

211-
seed_model_def = f"""MODEL (
209+
sql_models[seed_model_name] = f"""MODEL (
212210
name {seed_model_name},
213211
kind SEED (
214212
path '../seeds/seed_data.csv'
@@ -222,17 +220,7 @@ def _gen_example_objects(schema_name: str) -> ExampleObjects:
222220
);
223221
"""
224222

225-
audit_def = """AUDIT (
226-
name assert_positive_order_ids,
227-
);
228-
229-
SELECT *
230-
FROM @this_model
231-
WHERE
232-
item_id < 0
233-
"""
234-
235-
seed_data = """id,item_id,event_date
223+
seeds["seed_data"] = """id,item_id,event_date
236224
1,2,2020-01-01
237225
2,1,2020-01-01
238226
3,3,2020-01-03
@@ -242,7 +230,17 @@ def _gen_example_objects(schema_name: str) -> ExampleObjects:
242230
7,1,2020-01-07
243231
"""
244232

245-
test_def = f"""test_example_full_model:
233+
audits["assert_positive_order_ids"] = """AUDIT (
234+
name assert_positive_order_ids,
235+
);
236+
237+
SELECT *
238+
FROM @this_model
239+
WHERE
240+
item_id < 0
241+
"""
242+
243+
tests["test_example_full_model"] = f"""test_example_full_model:
246244
model: {full_model_name}
247245
inputs:
248246
{incremental_model_name}:
@@ -263,16 +261,13 @@ def _gen_example_objects(schema_name: str) -> ExampleObjects:
263261
"""
264262

265263
return ExampleObjects(
266-
schema_name=schema_name,
267-
full_model_name=full_model_name,
268-
full_model_def=full_model_def,
269-
incremental_model_name=incremental_model_name,
270-
incremental_model_def=incremental_model_def,
271-
seed_model_name=seed_model_name,
272-
seed_model_def=seed_model_def,
273-
seed_data=seed_data,
274-
audit_def=audit_def,
275-
test_def=test_def,
264+
sql_models=sql_models,
265+
python_models=python_models,
266+
seeds=seeds,
267+
audits=audits,
268+
tests=tests,
269+
python_macros=python_macros,
270+
sql_macros=sql_macros,
276271
)
277272

278273

@@ -321,7 +316,7 @@ def init_example_project(
321316
if engine_type and template == ProjectTemplate.DLT:
322317
dialect = DIALECT_TO_TYPE.get(engine_type)
323318
if pipeline and dialect:
324-
models, settings, start = generate_dlt_models_and_settings(
319+
dlt_models, settings, start = generate_dlt_models_and_settings(
325320
pipeline_name=pipeline, dialect=dialect, dlt_path=dlt_path
326321
)
327322
else:
@@ -336,17 +331,21 @@ def init_example_project(
336331
_create_folders([audits_path, macros_path, models_path, seeds_path, tests_path])
337332

338333
if template == ProjectTemplate.DLT:
339-
_create_models(models_path, models)
334+
_create_object_files(
335+
models_path, {model[0].split(".")[-1]: model[1] for model in dlt_models}, "sql"
336+
)
340337
return config_path
341338

342339
example_objects = _gen_example_objects(schema_name=schema_name)
343340

344341
if template != ProjectTemplate.EMPTY:
345-
_create_macros(macros_path)
346-
_create_audits(audits_path, example_objects)
347-
_create_models(models_path, example_objects.models())
348-
_create_seeds(seeds_path, example_objects)
349-
_create_tests(tests_path, example_objects)
342+
_create_object_files(models_path, example_objects.sql_models, "sql")
343+
_create_object_files(models_path, example_objects.python_models, "py")
344+
_create_object_files(seeds_path, example_objects.seeds, "csv")
345+
_create_object_files(audits_path, example_objects.audits, "sql")
346+
_create_object_files(tests_path, example_objects.tests, "yaml")
347+
_create_object_files(macros_path, example_objects.python_macros, "py")
348+
_create_object_files(macros_path, example_objects.sql_macros, "sql")
350349

351350
return config_path
352351

@@ -373,25 +372,10 @@ def _create_config(
373372
)
374373

375374

376-
def _create_macros(macros_path: Path) -> None:
377-
(macros_path / "__init__.py").touch()
378-
379-
380-
def _create_audits(audits_path: Path, example_objects: ExampleObjects) -> None:
381-
_write_file(audits_path / "assert_positive_order_ids.sql", example_objects.audit_def)
382-
383-
384-
def _create_models(models_path: Path, models: t.Set[t.Tuple[str, str]]) -> None:
385-
for model_name, model_def in models:
386-
_write_file(models_path / f"{model_name.split('.')[-1]}.sql", model_def)
387-
388-
389-
def _create_seeds(seeds_path: Path, example_objects: ExampleObjects) -> None:
390-
_write_file(seeds_path / "seed_data.csv", example_objects.seed_data)
391-
392-
393-
def _create_tests(tests_path: Path, example_objects: ExampleObjects) -> None:
394-
_write_file(tests_path / "test_full_model.yaml", example_objects.test_def)
375+
def _create_object_files(path: Path, object_dict: t.Dict[str, str], file_extension: str) -> None:
    """Hand one file per entry of ``object_dict`` to ``_write_file``.

    Args:
        path: Directory that each generated file name is joined onto.
        object_dict: Maps an object name to the text passed along as that file's
            payload. Only the part of the name after the last "." is used for the
            file name.
        file_extension: Extension appended after a "." separator; callers pass it
            without the dot (e.g. "sql").
    """
    for qualified_name, contents in object_dict.items():
        # Keep only the last dotted component, i.e. the table in catalog.schema.table.
        _, _, file_stem = qualified_name.rpartition(".")
        target = path / f"{file_stem}.{file_extension}"
        _write_file(target, contents)
395379

396380

397381
def _write_file(path: Path, payload: str) -> None:

sqlmesh/integrations/dlt.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def generate_dlt_models(
138138
force: bool,
139139
dlt_path: t.Optional[str] = None,
140140
) -> t.List[str]:
141-
from sqlmesh.cli.project_init import _create_models
141+
from sqlmesh.cli.project_init import _create_object_files
142142

143143
sqlmesh_models, _, _ = generate_dlt_models_and_settings(
144144
pipeline_name=pipeline_name,
@@ -152,7 +152,11 @@ def generate_dlt_models(
152152
sqlmesh_models = {model for model in sqlmesh_models if model[0] not in existing_models}
153153

154154
if sqlmesh_models:
155-
_create_models(models_path=context.path / "models", models=sqlmesh_models)
155+
_create_object_files(
156+
context.path / "models",
157+
{model[0].split(".")[-1]: model[1] for model in sqlmesh_models},
158+
"sql",
159+
)
156160
return [model[0] for model in sqlmesh_models]
157161
return []
158162

0 commit comments

Comments
 (0)