Skip to content

Commit aae5341

Browse files
committed
Allow for spark session provider override
1 parent bf74399 commit aae5341

6 files changed

Lines changed: 119 additions & 45 deletions

File tree

src/datacustomcode/client.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from pyspark.sql import SparkSession
2525

2626
from datacustomcode.config import SparkConfig, config
27+
from datacustomcode.spark.base import BaseSparkSessionProvider
28+
from datacustomcode.spark.default import DefaultSparkSessionProvider
2729
from datacustomcode.file.path.default import DefaultFindFilePath
2830
from datacustomcode.io.reader.base import BaseDataCloudReader
2931

@@ -36,18 +38,6 @@
3638
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
3739

3840

39-
def _setup_spark(spark_config: SparkConfig) -> SparkSession:
40-
"""Setup Spark session from config."""
41-
builder = SparkSession.builder
42-
if spark_config.master is not None:
43-
builder = builder.master(spark_config.master)
44-
45-
builder = builder.appName(spark_config.app_name)
46-
for key, value in spark_config.options.items():
47-
builder = builder.config(key, value)
48-
return builder.getOrCreate()
49-
50-
5141
class DataCloudObjectType(Enum):
5242
DLO = "dlo"
5343
DMO = "dmo"
@@ -124,6 +114,7 @@ def __new__(
124114
cls,
125115
reader: Optional[BaseDataCloudReader] = None,
126116
writer: Optional[BaseDataCloudWriter] = None,
117+
spark_provider: Optional[BaseSparkSessionProvider] = None,
127118
) -> Client:
128119
if cls._instance is None:
129120
cls._instance = super().__new__(cls)
@@ -136,7 +127,15 @@ def __new__(
136127
raise ValueError(
137128
"Spark config is required when reader/writer is not provided"
138129
)
139-
spark = _setup_spark(config.spark_config)
130+
provider: BaseSparkSessionProvider
131+
if spark_provider is not None:
132+
provider = spark_provider
133+
elif config.spark_provider_config is not None:
134+
provider = config.spark_provider_config.to_object()
135+
else:
136+
provider = DefaultSparkSessionProvider()
137+
138+
spark = provider.get_session(config.spark_config)
140139

141140
if config.reader_config is None and reader is None:
142141
raise ValueError(

src/datacustomcode/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
# This lets all readers and writers be findable via config
3737
from datacustomcode.io import * # noqa: F403
3838
from datacustomcode.io.base import BaseDataAccessLayer
39+
from datacustomcode.spark.base import BaseSparkSessionProvider
40+
from datacustomcode.spark.default import DefaultSparkSessionProvider
3941
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
4042
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
4143

@@ -89,10 +91,25 @@ class SparkConfig(ForceableConfig):
8991
)
9092

9193

94+
# Type variable for the concrete provider type produced by ``to_object``.
_P = TypeVar("_P", bound=BaseSparkSessionProvider)


class SparkProviderConfig(ForceableConfig, Generic[_P]):
    """Configuration entry that selects and builds a Spark session provider.

    The provider class is looked up by name through
    ``type_base.subclass_from_config_name`` and instantiated with
    ``options`` as keyword arguments. Unknown fields are rejected
    (``extra="forbid"``).
    """

    model_config = ConfigDict(validate_default=True, extra="forbid")
    # Root class whose subclasses are searched when resolving the name below.
    type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
    type_config_name: str = Field(
        description="CONFIG_NAME of the Spark session provider."
    )
    # Keyword arguments forwarded to the provider's constructor.
    options: dict[str, Any] = Field(default_factory=dict)

    def to_object(self) -> _P:
        """Instantiate the configured Spark session provider.

        Returns:
            A new instance of the class resolved from ``type_config_name``,
            constructed with ``**self.options``.
        """
        provider_cls = self.type_base.subclass_from_config_name(
            self.type_config_name
        )
        provider = provider_cls(**self.options)
        return cast(_P, provider)
106+
107+
92108
class ClientConfig(BaseModel):
93109
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
94110
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
95111
spark_config: Union[SparkConfig, None] = None
112+
spark_provider_config: Union[SparkProviderConfig[BaseSparkSessionProvider], None] = None
96113

97114
def update(self, other: ClientConfig) -> ClientConfig:
98115
"""Merge this ClientConfig with another, respecting force flags.
@@ -117,6 +134,7 @@ def merge(
117134
self.reader_config = merge(self.reader_config, other.reader_config)
118135
self.writer_config = merge(self.writer_config, other.writer_config)
119136
self.spark_config = merge(self.spark_config, other.spark_config)
137+
self.spark_provider_config = merge(self.spark_provider_config, other.spark_provider_config)
120138
return self
121139

122140
def load(self, config_path: str) -> ClientConfig:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from datacustomcode.spark.base import BaseSparkSessionProvider
5+
from datacustomcode.spark.default import DefaultSparkSessionProvider
6+
7+
__all__ = ["BaseSparkSessionProvider", "DefaultSparkSessionProvider"]
8+

src/datacustomcode/spark/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from datacustomcode.mixin import UserExtendableNamedConfigMixin
6+
7+
if TYPE_CHECKING:
8+
from pyspark.sql import SparkSession
9+
from datacustomcode.config import SparkConfig
10+
11+
12+
class BaseSparkSessionProvider(UserExtendableNamedConfigMixin):
    """Base class for objects that supply a Spark session.

    Subclasses override :meth:`get_session`; the implementation here only
    raises ``NotImplementedError``.
    """

    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
        """Return a Spark session for the given Spark configuration.

        Args:
            spark_config: Spark settings supplied by the caller.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
15+
16+
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from datacustomcode.spark.base import BaseSparkSessionProvider
6+
7+
if TYPE_CHECKING:
8+
from pyspark.sql import SparkSession
9+
from datacustomcode.config import SparkConfig
10+
11+
12+
class DefaultSparkSessionProvider(BaseSparkSessionProvider):
    """Provider that builds a session through ``SparkSession.builder``."""

    CONFIG_NAME = "DefaultSparkSessionProvider"

    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
        """Build (or fetch) a Spark session from ``spark_config``.

        The builder is configured in a fixed order: master (only when one is
        set), then application name, then each entry of
        ``spark_config.options``.

        Args:
            spark_config: Provides ``master``, ``app_name`` and ``options``.

        Returns:
            The result of ``getOrCreate()`` on the configured builder.
        """
        # Imported here rather than at module level, where pyspark is only
        # imported under TYPE_CHECKING.
        from pyspark.sql import SparkSession

        session_builder = SparkSession.builder

        master_url = spark_config.master
        if master_url is not None:
            session_builder = session_builder.master(master_url)

        session_builder = session_builder.appName(spark_config.app_name)

        for option_name, option_value in spark_config.options.items():
            session_builder = session_builder.config(option_name, option_value)

        return session_builder.getOrCreate()
25+
26+

tests/test_client.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
Client,
1010
DataCloudAccessLayerException,
1111
DataCloudObjectType,
12-
_setup_spark,
1312
)
1413
from datacustomcode.config import (
1514
AccessLayerObjectConfig,
@@ -100,37 +99,41 @@ def test_singleton_pattern(self, reset_client, mock_spark):
10099
Client(reader=MagicMock(spec=BaseDataCloudReader))
101100

102101
@patch("datacustomcode.client.config")
103-
@patch("datacustomcode.client._setup_spark")
104102
def test_initialization_with_config(
105-
self, mock_setup_spark, mock_config, reset_client, mock_spark
103+
self, mock_config, reset_client, mock_spark
106104
):
107105
"""Test client initialization using configuration."""
108-
mock_setup_spark.return_value = mock_spark
106+
from datacustomcode.spark.default import DefaultSparkSessionProvider
107+
from unittest.mock import patch as mock_patch
109108

110-
mock_reader = MagicMock(spec=BaseDataCloudReader)
111-
mock_reader_config = MagicMock()
112-
mock_reader_config.to_object.return_value = mock_reader
113-
mock_reader_config.force = False
109+
with mock_patch.object(DefaultSparkSessionProvider, "get_session") as mock_get_session:
110+
mock_get_session.return_value = mock_spark
114111

115-
mock_writer = MagicMock(spec=BaseDataCloudWriter)
116-
mock_writer_config = MagicMock()
117-
mock_writer_config.to_object.return_value = mock_writer
118-
mock_writer_config.force = False
112+
mock_reader = MagicMock(spec=BaseDataCloudReader)
113+
mock_reader_config = MagicMock()
114+
mock_reader_config.to_object.return_value = mock_reader
115+
mock_reader_config.force = False
119116

120-
mock_spark_config = MagicMock(spec=SparkConfig)
117+
mock_writer = MagicMock(spec=BaseDataCloudWriter)
118+
mock_writer_config = MagicMock()
119+
mock_writer_config.to_object.return_value = mock_writer
120+
mock_writer_config.force = False
121121

122-
mock_config.reader_config = mock_reader_config
123-
mock_config.writer_config = mock_writer_config
124-
mock_config.spark_config = mock_spark_config
122+
mock_spark_config = MagicMock(spec=SparkConfig)
123+
mock_config.spark_provider_config = None
125124

126-
client = Client()
125+
mock_config.reader_config = mock_reader_config
126+
mock_config.writer_config = mock_writer_config
127+
mock_config.spark_config = mock_spark_config
127128

128-
mock_setup_spark.assert_called_once_with(mock_spark_config)
129-
mock_reader_config.to_object.assert_called_once_with(mock_spark)
130-
mock_writer_config.to_object.assert_called_once_with(mock_spark)
129+
client = Client()
131130

132-
assert client._reader is mock_reader
133-
assert client._writer is mock_writer
131+
mock_get_session.assert_called_once_with(mock_spark_config)
132+
mock_reader_config.to_object.assert_called_once_with(mock_spark)
133+
mock_writer_config.to_object.assert_called_once_with(mock_spark)
134+
135+
assert client._reader is mock_reader
136+
assert client._writer is mock_writer
134137

135138
def test_read_dlo(self, reset_client, mock_spark):
136139
reader = MagicMock(spec=BaseDataCloudReader)
@@ -249,12 +252,12 @@ def test_read_pattern_flow(self, reset_client, mock_spark):
249252
assert "source_dmo" in client._data_layer_history[DataCloudObjectType.DMO]
250253

251254

252-
# Add tests for _setup_spark function
253-
class TestSetupSpark:
255+
# Add tests for DefaultSparkSessionProvider
256+
class TestDefaultSparkSessionProvider:
254257

255-
@patch("datacustomcode.client.SparkSession")
256-
def test_setup_spark_with_master(self, mock_spark_session):
257-
"""Test _setup_spark with master specified"""
258+
@patch("pyspark.sql.SparkSession")
259+
def test_get_session_with_master(self, mock_spark_session):
260+
"""Test DefaultSparkSessionProvider with master specified"""
258261
mock_builder = MagicMock()
259262
mock_master_builder = MagicMock()
260263
mock_app_name_builder = MagicMock()
@@ -273,7 +276,9 @@ def test_setup_spark_with_master(self, mock_spark_session):
273276
options={"spark.executor.memory": "1g"},
274277
)
275278

276-
result = _setup_spark(spark_config)
279+
from datacustomcode.spark.default import DefaultSparkSessionProvider
280+
provider = DefaultSparkSessionProvider()
281+
result = provider.get_session(spark_config)
277282

278283
mock_builder.master.assert_called_once_with("local[1]")
279284
mock_master_builder.appName.assert_called_once_with("test-app")
@@ -283,9 +288,9 @@ def test_setup_spark_with_master(self, mock_spark_session):
283288
mock_config_builder.getOrCreate.assert_called_once()
284289
assert result is mock_session
285290

286-
@patch("datacustomcode.client.SparkSession")
287-
def test_setup_spark_with_multiple_options(self, mock_spark_session):
288-
"""Test _setup_spark with multiple config options"""
291+
@patch("pyspark.sql.SparkSession")
292+
def test_get_session_with_multiple_options(self, mock_spark_session):
293+
"""Test DefaultSparkSessionProvider with multiple config options"""
289294
mock_builder = MagicMock()
290295
mock_app_name_builder = MagicMock()
291296
mock_config_builder1 = MagicMock()
@@ -310,7 +315,9 @@ def test_setup_spark_with_multiple_options(self, mock_spark_session):
310315
},
311316
)
312317

313-
result = _setup_spark(spark_config)
318+
from datacustomcode.spark.default import DefaultSparkSessionProvider
319+
provider = DefaultSparkSessionProvider()
320+
result = provider.get_session(spark_config)
314321

315322
mock_builder.appName.assert_called_once_with("test-app")
316323

0 commit comments

Comments
 (0)