Skip to content

Commit 180601d

Browse files
committed
add some unit tests
1 parent 2a60ca9 commit 180601d

2 files changed

Lines changed: 47 additions & 19 deletions

File tree

src/datacustomcode/client.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __new__(
120120
if cls._instance is None:
121121
cls._instance = super().__new__(cls)
122122

123+
spark = None
123124
# Initialize Readers and Writers from config
124125
# and/or provided reader and writer
125126
if reader is None or writer is None:
@@ -138,6 +139,22 @@ def __new__(
138139
provider = DefaultSparkSessionProvider()
139140

140141
spark = provider.get_session(config.spark_config)
142+
elif (
143+
proxy is None
144+
and config.proxy_config is not None
145+
and config.spark_config is not None
146+
):
147+
# Both reader and writer provided; we still need spark for proxy init
148+
provider = (
149+
spark_provider
150+
if spark_provider is not None
151+
else (
152+
config.spark_provider_config.to_object()
153+
if config.spark_provider_config is not None
154+
else DefaultSparkSessionProvider()
155+
)
156+
)
157+
spark = provider.get_session(config.spark_config)
141158

142159
if config.reader_config is None and reader is None:
143160
raise ValueError(
@@ -155,9 +172,12 @@ def __new__(
155172
reader_init = config.reader_config.to_object(spark) # type: ignore
156173
else:
157174
reader_init = reader
158-
if config.proxy_config is None:
175+
if proxy is not None:
176+
proxy_init = proxy
177+
elif config.proxy_config is None:
159178
raise ValueError("Proxy config is required when reader is provided")
160-
proxy_init = config.proxy_config.to_object(spark)
179+
else:
180+
proxy_init = config.proxy_config.to_object(spark)
161181
if config.writer_config is None and writer is None:
162182
raise ValueError(
163183
"Writer config is required when writer is not provided"

tests/test_client.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from datacustomcode.io.reader.base import BaseDataCloudReader
1919
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
20+
from datacustomcode.proxy.client.base import BaseProxyClient
2021

2122

2223
class MockDataCloudReader(BaseDataCloudReader):
@@ -75,6 +76,13 @@ def mock_config(mock_spark):
7576
)
7677

7778

79+
@pytest.fixture
80+
def mock_proxy():
81+
"""Mock proxy client to avoid starting Spark when reader/writer are provided."""
82+
proxy = MagicMock(spec=BaseProxyClient)
83+
return proxy
84+
85+
7886
@pytest.fixture
7987
def reset_client():
8088
"""Reset the Client singleton between tests."""
@@ -85,12 +93,12 @@ def reset_client():
8593

8694
class TestClient:
8795

88-
def test_singleton_pattern(self, reset_client, mock_spark):
96+
def test_singleton_pattern(self, reset_client, mock_spark, mock_proxy):
8997
"""Test that Client behaves as a singleton."""
9098
reader = MockDataCloudReader(mock_spark)
9199
writer = MockDataCloudWriter(mock_spark)
92100

93-
client1 = Client(reader=reader, writer=writer)
101+
client1 = Client(reader=reader, writer=writer, proxy=mock_proxy)
94102
client2 = Client()
95103

96104
assert client1 is client2
@@ -136,38 +144,38 @@ def test_initialization_with_config(self, mock_config, reset_client, mock_spark)
136144
assert client._reader is mock_reader
137145
assert client._writer is mock_writer
138146

139-
def test_read_dlo(self, reset_client, mock_spark):
147+
def test_read_dlo(self, reset_client, mock_spark, mock_proxy):
140148
reader = MagicMock(spec=BaseDataCloudReader)
141149
writer = MagicMock(spec=BaseDataCloudWriter)
142150
mock_df = MagicMock(spec=DataFrame)
143151
reader.read_dlo.return_value = mock_df
144152

145-
client = Client(reader=reader, writer=writer)
153+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
146154
result = client.read_dlo("test_dlo")
147155

148156
reader.read_dlo.assert_called_once_with("test_dlo")
149157
assert result is mock_df
150158
assert "test_dlo" in client._data_layer_history[DataCloudObjectType.DLO]
151159

152-
def test_read_dmo(self, reset_client, mock_spark):
160+
def test_read_dmo(self, reset_client, mock_spark, mock_proxy):
153161
reader = MagicMock(spec=BaseDataCloudReader)
154162
writer = MagicMock(spec=BaseDataCloudWriter)
155163
mock_df = MagicMock(spec=DataFrame)
156164
reader.read_dmo.return_value = mock_df
157165

158-
client = Client(reader=reader, writer=writer)
166+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
159167
result = client.read_dmo("test_dmo")
160168

161169
reader.read_dmo.assert_called_once_with("test_dmo")
162170
assert result is mock_df
163171
assert "test_dmo" in client._data_layer_history[DataCloudObjectType.DMO]
164172

165-
def test_write_to_dlo(self, reset_client, mock_spark):
173+
def test_write_to_dlo(self, reset_client, mock_spark, mock_proxy):
166174
reader = MagicMock(spec=BaseDataCloudReader)
167175
writer = MagicMock(spec=BaseDataCloudWriter)
168176
mock_df = MagicMock(spec=DataFrame)
169177

170-
client = Client(reader=reader, writer=writer)
178+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
171179
client._record_dlo_access("some_dlo")
172180

173181
client.write_to_dlo("test_dlo", mock_df, WriteMode.APPEND, extra_param=True)
@@ -176,12 +184,12 @@ def test_write_to_dlo(self, reset_client, mock_spark):
176184
"test_dlo", mock_df, WriteMode.APPEND, extra_param=True
177185
)
178186

179-
def test_write_to_dmo(self, reset_client, mock_spark):
187+
def test_write_to_dmo(self, reset_client, mock_spark, mock_proxy):
180188
reader = MagicMock(spec=BaseDataCloudReader)
181189
writer = MagicMock(spec=BaseDataCloudWriter)
182190
mock_df = MagicMock(spec=DataFrame)
183191

184-
client = Client(reader=reader, writer=writer)
192+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
185193
client._record_dmo_access("some_dmo")
186194

187195
client.write_to_dmo("test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True)
@@ -190,42 +198,42 @@ def test_write_to_dmo(self, reset_client, mock_spark):
190198
"test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True
191199
)
192200

193-
def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark):
201+
def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark, mock_proxy):
194202
"""Test that mixing DLOs and DMOs raises an exception."""
195203
reader = MagicMock(spec=BaseDataCloudReader)
196204
writer = MagicMock(spec=BaseDataCloudWriter)
197205
mock_df = MagicMock(spec=DataFrame)
198206

199-
client = Client(reader=reader, writer=writer)
207+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
200208
client._record_dlo_access("test_dlo")
201209

202210
with pytest.raises(DataCloudAccessLayerException) as exc_info:
203211
client.write_to_dmo("test_dmo", mock_df, WriteMode.APPEND)
204212

205213
assert "test_dlo" in str(exc_info.value)
206214

207-
def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark):
215+
def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark, mock_proxy):
208216
"""Test that mixing DMOs and DLOs raises an exception (converse case)."""
209217
reader = MagicMock(spec=BaseDataCloudReader)
210218
writer = MagicMock(spec=BaseDataCloudWriter)
211219
mock_df = MagicMock(spec=DataFrame)
212220

213-
client = Client(reader=reader, writer=writer)
221+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
214222
client._record_dmo_access("test_dmo")
215223

216224
with pytest.raises(DataCloudAccessLayerException) as exc_info:
217225
client.write_to_dlo("test_dlo", mock_df, WriteMode.APPEND)
218226

219227
assert "test_dmo" in str(exc_info.value)
220228

221-
def test_read_pattern_flow(self, reset_client, mock_spark):
229+
def test_read_pattern_flow(self, reset_client, mock_spark, mock_proxy):
222230
"""Test a complete flow of reading and writing within the same object type."""
223231
reader = MagicMock(spec=BaseDataCloudReader)
224232
writer = MagicMock(spec=BaseDataCloudWriter)
225233
mock_df = MagicMock(spec=DataFrame)
226234
reader.read_dlo.return_value = mock_df
227235

228-
client = Client(reader=reader, writer=writer)
236+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
229237

230238
df = client.read_dlo("source_dlo")
231239
client.write_to_dlo("target_dlo", df, WriteMode.APPEND)
@@ -239,7 +247,7 @@ def test_read_pattern_flow(self, reset_client, mock_spark):
239247

240248
# Reset for DMO test
241249
Client._instance = None
242-
client = Client(reader=reader, writer=writer)
250+
client = Client(reader=reader, writer=writer, proxy=mock_proxy)
243251
reader.read_dmo.return_value = mock_df
244252

245253
df = client.read_dmo("source_dmo")

0 commit comments

Comments (0)