-
Notifications
You must be signed in to change notification settings - Fork 380
Expand file tree
/
Copy pathdecorator.py
More file actions
248 lines (215 loc) · 9.25 KB
/
decorator.py
File metadata and controls
248 lines (215 loc) · 9.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from __future__ import annotations
import typing as t
from pathlib import Path
import inspect
import re
from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlmesh.core.config.common import VirtualEnvironmentMode
from sqlmesh.core.macros import MacroRegistry
from sqlmesh.core.signal import SignalRegistry
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.core import constants as c
from sqlmesh.core.dialect import MacroFunc, parse_one
from sqlmesh.core.model.definition import (
Model,
create_python_model,
create_sql_model,
create_models_from_blueprints,
get_model_name,
parse_defaults_properties,
render_meta_fields,
render_model_defaults,
)
from sqlmesh.core.model.kind import ModelKindName, _ModelKind
from sqlmesh.utils import registry_decorator, DECORATOR_RETURN_TYPE
from sqlmesh.utils.errors import ConfigError, raise_config_error
from sqlmesh.utils.metaprogramming import build_env, serialize_env
if t.TYPE_CHECKING:
from sqlmesh.core.audit import ModelAudit
class model(registry_decorator):
"""Specifies a function is a python based model."""
registry_name = "python_models"
_dialect: DialectType = None
def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs: t.Any) -> None:
if not is_sql and "columns" not in kwargs:
raise ConfigError("Python model must define column schema.")
self.name_provided = bool(name)
self.name = name or ""
self.is_sql = is_sql
self.kwargs = kwargs
# Make sure that argument values are expressions in order to pass validation in ModelMeta.
for function_call_attribute in ("audits", "signals"):
calls = self.kwargs.pop(function_call_attribute, [])
self.kwargs[function_call_attribute] = [
(
(call, {})
if isinstance(call, str)
else (
call[0],
{
arg_key: exp.convert(
tuple(arg_value) if isinstance(arg_value, list) else arg_value
)
for arg_key, arg_value in call[1].items()
},
)
)
for call in calls
]
if "default_catalog" in kwargs:
raise ConfigError("`default_catalog` cannot be set on a per-model basis.")
self.columns = {
column_name: (
column_type
if isinstance(column_type, exp.DataType)
else exp.DataType.build(
str(column_type), dialect=self.kwargs.get("dialect", self._dialect)
)
)
for column_name, column_type in self.kwargs.pop("columns", {}).items()
}
def __call__(
self, func: t.Callable[..., DECORATOR_RETURN_TYPE]
) -> t.Callable[..., DECORATOR_RETURN_TYPE]:
if not self.name_provided:
self.name = get_model_name(Path(inspect.getfile(func)))
return super().__call__(func)
def models(
self,
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
path: Path,
module_path: Path,
dialect: t.Optional[str] = None,
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None,
**loader_kwargs: t.Any,
) -> t.List[Model]:
blueprints = self.kwargs.pop("blueprints", None)
if isinstance(blueprints, str):
blueprints = parse_one(blueprints, dialect=dialect)
if isinstance(blueprints, MacroFunc):
from sqlmesh.core.model.definition import render_expression
blueprints = render_expression(
expression=blueprints,
module_path=module_path,
macros=loader_kwargs.get("macros"),
jinja_macros=loader_kwargs.get("jinja_macros"),
variables=get_variables(None),
path=path,
dialect=dialect,
default_catalog=loader_kwargs.get("default_catalog"),
)
if not blueprints:
raise_config_error("Failed to render blueprints property", path)
if len(blueprints) > 1:
blueprints = [exp.Tuple(expressions=blueprints)]
blueprints = blueprints[0]
return create_models_from_blueprints(
gateway=self.kwargs.get("gateway"),
blueprints=blueprints,
get_variables=get_variables,
loader=self.model,
path=path,
module_path=module_path,
dialect=dialect,
default_catalog_per_gateway=default_catalog_per_gateway,
**loader_kwargs,
)
def model(
self,
*,
module_path: Path,
path: Path,
defaults: t.Optional[t.Dict[str, t.Any]] = None,
macros: t.Optional[MacroRegistry] = None,
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
signal_definitions: t.Optional[SignalRegistry] = None,
audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None,
dialect: t.Optional[str] = None,
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT,
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
project: str = "",
default_catalog: t.Optional[str] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
infer_names: t.Optional[bool] = False,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default,
) -> Model:
"""Get the model registered by this function."""
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
entrypoint = self.func.__name__
if not self.name_provided and not infer_names:
raise ConfigError("Python model must have a name.")
kind = self.kwargs.get("kind", None)
if kind is not None:
if isinstance(kind, _ModelKind):
from sqlmesh.core.console import get_console
get_console().log_warning(
f"""Python model "{self.name}"'s `kind` argument was passed a SQLMesh `{type(kind).__name__}` object. This may result in unexpected behavior - provide a dictionary instead."""
)
elif isinstance(kind, dict):
if "name" not in kind or not isinstance(kind.get("name"), ModelKindName):
raise ConfigError(
f"""Python model "{self.name}"'s `kind` dictionary must contain a `name` key with a valid ModelKindName enum value."""
)
build_env(self.func, env=env, name=entrypoint, path=module_path)
rendered_fields = render_meta_fields(
fields={"name": self.name, **self.kwargs},
module_path=module_path,
macros=macros,
jinja_macros=jinja_macros,
variables=variables,
path=path,
dialect=dialect,
default_catalog=default_catalog,
blueprint_variables=blueprint_variables,
)
rendered_name = rendered_fields["name"]
if isinstance(rendered_name, exp.Expression):
rendered_fields["name"] = rendered_name.sql(dialect=dialect)
rendered_defaults = (
render_model_defaults(
defaults=defaults,
module_path=module_path,
macros=macros,
jinja_macros=jinja_macros,
variables=variables,
path=path,
dialect=dialect,
default_catalog=default_catalog,
)
if defaults
else {}
)
rendered_defaults = parse_defaults_properties(rendered_defaults, dialect=dialect)
common_kwargs = {
"defaults": rendered_defaults,
"path": path,
"time_column_format": time_column_format,
"python_env": serialize_env(env, path=module_path),
"physical_schema_mapping": physical_schema_mapping,
"project": project,
"default_catalog": default_catalog,
"variables": variables,
"dialect": dialect,
"columns": self.columns if self.columns else None,
"module_path": module_path,
"macros": macros,
"jinja_macros": jinja_macros,
"audit_definitions": audit_definitions,
"signal_definitions": signal_definitions,
"blueprint_variables": blueprint_variables,
"virtual_environment_mode": virtual_environment_mode,
**rendered_fields,
}
for key in ("pre_statements", "post_statements", "on_virtual_update"):
statements = common_kwargs.get(key)
if statements:
common_kwargs[key] = [
parse_one(s, dialect=common_kwargs.get("dialect")) if isinstance(s, str) else s
for s in statements
]
if self.is_sql:
query = MacroFunc(this=exp.Anonymous(this=entrypoint))
return create_sql_model(query=query, **common_kwargs)
return create_python_model(entrypoint=entrypoint, **common_kwargs)