1414# limitations under the License.
1515from __future__ import annotations
1616
17- import os
1817from typing import (
1918 TYPE_CHECKING ,
2019 Any ,
2726)
2827
2928from pydantic import (
30- BaseModel ,
31- ConfigDict ,
3229 Field ,
3330)
34- import yaml
3531
3632# This lets all readers and writers to be findable via config
3733from datacustomcode .io import * # noqa: F403
4137from datacustomcode .proxy .base import BaseProxyAccessLayer
4238from datacustomcode .proxy .client .base import BaseProxyClient # noqa: TCH002
4339from datacustomcode .spark .base import BaseSparkSessionProvider
44-
45- DEFAULT_CONFIG_NAME = "config.yaml"
40+ from datacustomcode .common_config import ForceableConfig , BaseObjectConfig , BaseConfig , default_config_file
4641
4742
4843if TYPE_CHECKING :
4944 from pyspark .sql import SparkSession
5045
5146
52- class ForceableConfig (BaseModel ):
53- force : bool = Field (
54- default = False ,
55- description = "If True, this takes precedence over parameters passed to the "
56- "initializer of the client." ,
57- )
58-
59-
6047_T = TypeVar ("_T" , bound = "BaseDataAccessLayer" )
6148
6249
63- class AccessLayerObjectConfig (ForceableConfig , Generic [_T ]):
64- model_config = ConfigDict (validate_default = True , extra = "forbid" )
50+ class AccessLayerObjectConfig (BaseObjectConfig , Generic [_T ]):
6551 type_base : ClassVar [Type [BaseDataAccessLayer ]] = BaseDataAccessLayer
66- type_config_name : str = Field (
67- description = "The config name of the object to create. "
68- "For metrics, this would might be 'ipmnormal'. For custom classes, you can "
69- "assign a name to a class variable `CONFIG_NAME` and reference it here." ,
70- )
71- options : dict [str , Any ] = Field (
72- default_factory = dict ,
73- description = "Options passed to the constructor." ,
74- )
75-
7652 def to_object (self , spark : SparkSession ) -> _T :
7753 type_ = self .type_base .subclass_from_config_name (self .type_config_name )
7854 return cast (_T , type_ (spark = spark , ** self .options ))
@@ -97,35 +73,22 @@ class SparkConfig(ForceableConfig):
9773_PX = TypeVar ("_PX" , bound = BaseProxyAccessLayer )
9874
9975
100- class ProxyAccessLayerObjectConfig (ForceableConfig , Generic [_PX ]):
76+ class ProxyAccessLayerObjectConfig (BaseObjectConfig , Generic [_PX ]):
10177 """Config for proxy clients that take no constructor args (e.g. no spark)."""
102-
103- model_config = ConfigDict (validate_default = True , extra = "forbid" )
10478 type_base : ClassVar [Type [BaseProxyAccessLayer ]] = BaseProxyAccessLayer
105- type_config_name : str = Field (
106- description = "CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient')." ,
107- )
108- options : dict [str , Any ] = Field (default_factory = dict )
109-
11079 def to_object (self ) -> _PX :
11180 type_ = self .type_base .subclass_from_config_name (self .type_config_name )
11281 return cast (_PX , type_ (** self .options ))
11382
11483
115- class SparkProviderConfig (ForceableConfig , Generic [_P ]):
116- model_config = ConfigDict (validate_default = True , extra = "forbid" )
84+ class SparkProviderConfig (BaseObjectConfig , Generic [_P ]):
11785 type_base : ClassVar [Type [BaseSparkSessionProvider ]] = BaseSparkSessionProvider
118- type_config_name : str = Field (
119- description = "CONFIG_NAME of the Spark session provider."
120- )
121- options : dict [str , Any ] = Field (default_factory = dict )
122-
12386 def to_object (self ) -> _P :
12487 type_ = self .type_base .subclass_from_config_name (self .type_config_name )
12588 return cast (_P , type_ (** self .options ))
12689
12790
128- class ClientConfig (BaseModel ):
91+ class ClientConfig (BaseConfig ):
12992 reader_config : Union [AccessLayerObjectConfig [BaseDataCloudReader ], None ] = None
13093 writer_config : Union [AccessLayerObjectConfig [BaseDataCloudWriter ], None ] = None
13194 proxy_config : Union [ProxyAccessLayerObjectConfig [BaseProxyClient ], None ] = None
@@ -163,31 +126,9 @@ def merge(
163126 )
164127 return self
165128
166- def load (self , config_path : str ) -> ClientConfig :
167- """Load a config from a file and update this config with it.
168-
169- Args:
170- config_path: The path to the config file
171-
172- Returns:
173- Self, with updated values from the loaded config.
174- """
175- with open (config_path , "r" ) as f :
176- config_data = yaml .safe_load (f )
177- loaded_config = ClientConfig .model_validate (config_data )
178-
179- return self .update (loaded_config )
180-
181-
182- config = ClientConfig ()
183129"""Global config object.
184130
185131This is the object that makes config accessible globally and globally mutable.
186132"""
187-
188-
189- def _defaults () -> str :
190- return os .path .join (os .path .dirname (__file__ ), DEFAULT_CONFIG_NAME )
191-
192-
193- config .load (_defaults ())
133+ config = ClientConfig ()
134+ config .load (default_config_file ())
0 commit comments