Skip to content

Commit 571d57f

Browse files
committed
fix: allow subclassing of config again
1 parent e9d3115 commit 571d57f

1 file changed

Lines changed: 33 additions & 38 deletions

File tree

sqlmesh/core/config/root.py

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import zlib
77

88
from pydantic import Field
9-
from pydantic.functional_validators import BeforeValidator
109
from sqlglot import exp
1110
from sqlglot.helper import first
1211
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
@@ -44,42 +43,11 @@
4443
from sqlmesh.core.user import User
4544
from sqlmesh.utils.date import to_timestamp, now
4645
from sqlmesh.utils.errors import ConfigError
47-
from sqlmesh.utils.pydantic import model_validator
48-
49-
50-
def validate_no_past_ttl(v: str) -> str:
51-
current_time = now()
52-
if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time):
53-
raise ValueError(
54-
f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`."
55-
)
56-
return v
57-
58-
59-
def gateways_ensure_dict(value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
60-
try:
61-
if not isinstance(value, GatewayConfig):
62-
GatewayConfig.parse_obj(value)
63-
return {"": value}
64-
except Exception:
65-
return value
66-
67-
68-
def validate_regex_key_dict(value: t.Dict[str | re.Pattern, t.Any]) -> t.Dict[re.Pattern, t.Any]:
69-
return compile_regex_mapping(value)
70-
46+
from sqlmesh.utils.pydantic import model_validator, field_validator
7147

7248
if t.TYPE_CHECKING:
7349
from sqlmesh.core._typing import Self
7450

75-
NoPastTTLString = str
76-
GatewayDict = t.Dict[str, GatewayConfig]
77-
RegexKeyDict = t.Dict[re.Pattern, str]
78-
else:
79-
NoPastTTLString = t.Annotated[str, BeforeValidator(validate_no_past_ttl)]
80-
GatewayDict = t.Annotated[t.Dict[str, GatewayConfig], BeforeValidator(gateways_ensure_dict)]
81-
RegexKeyDict = t.Annotated[t.Dict[re.Pattern, str], BeforeValidator(validate_regex_key_dict)]
82-
8351

8452
class Config(BaseConfig):
8553
"""An object used by a Context to configure your SQLMesh project.
@@ -121,7 +89,7 @@ class Config(BaseConfig):
12189
after_all: SQL statements or macros to be executed at the end of the `sqlmesh plan` and `sqlmesh run` commands.
12290
"""
12391

124-
gateways: GatewayDict = {"": GatewayConfig()}
92+
gateways: t.Dict[str, GatewayConfig] = {"": GatewayConfig()}
12593
default_connection: SerializableConnectionConfig = DuckDBConnectionConfig()
12694
default_test_connection_: t.Optional[SerializableConnectionConfig] = Field(
12795
default=None, alias="default_test_connection"
@@ -130,8 +98,8 @@ class Config(BaseConfig):
13098
default_gateway: str = ""
13199
notification_targets: t.List[NotificationTarget] = []
132100
project: str = ""
133-
snapshot_ttl: NoPastTTLString = c.DEFAULT_SNAPSHOT_TTL
134-
environment_ttl: t.Optional[NoPastTTLString] = c.DEFAULT_ENVIRONMENT_TTL
101+
snapshot_ttl: str = c.DEFAULT_SNAPSHOT_TTL
102+
environment_ttl: t.Optional[str] = c.DEFAULT_ENVIRONMENT_TTL
135103
ignore_patterns: t.List[str] = c.IGNORE_PATTERNS
136104
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT
137105
users: t.List[User] = []
@@ -141,12 +109,12 @@ class Config(BaseConfig):
141109
loader_kwargs: t.Dict[str, t.Any] = {}
142110
env_vars: t.Dict[str, str] = {}
143111
username: str = ""
144-
physical_schema_mapping: RegexKeyDict = {}
112+
physical_schema_mapping: t.Dict[re.Pattern, str] = {}
145113
environment_suffix_target: EnvironmentSuffixTarget = Field(
146114
default=EnvironmentSuffixTarget.default
147115
)
148116
gateway_managed_virtual_layer: bool = False
149-
environment_catalog_mapping: RegexKeyDict = {}
117+
environment_catalog_mapping: t.Dict[re.Pattern, str] = {}
150118
default_target_environment: str = c.PROD
151119
log_limit: int = c.DEFAULT_LOG_LIMIT
152120
cicd_bot: t.Optional[CICDBotConfig] = None
@@ -187,6 +155,33 @@ class Config(BaseConfig):
187155
_scheduler_config_validator = scheduler_config_validator # type: ignore
188156
_variables_validator = variables_validator
189157

158+
@field_validator("gateways", mode="before")
159+
@classmethod
160+
def _gateways_ensure_dict(cls, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
161+
try:
162+
if not isinstance(value, GatewayConfig):
163+
GatewayConfig.parse_obj(value)
164+
return {"": value}
165+
except Exception:
166+
return value
167+
168+
@field_validator("environment_catalog_mapping", "physical_schema_mapping", mode="before")
169+
@classmethod
170+
def _validate_regex_keys(
171+
cls, value: t.Dict[str | re.Pattern, t.Any]
172+
) -> t.Dict[re.Pattern, t.Any]:
173+
return compile_regex_mapping(value)
174+
175+
@field_validator("snapshot_ttl", "environment_ttl", mode="before")
176+
@classmethod
177+
def validate_no_past_ttl(cls, v: str) -> str:
178+
current_time = now()
179+
if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time):
180+
raise ValueError(
181+
f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`."
182+
)
183+
return v
184+
190185
@model_validator(mode="before")
191186
def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any:
192187
if not isinstance(data, dict):

0 commit comments

Comments
 (0)