1+ import pytest
2+ from pydantic import ValidationError
3+
4+ from datacustomcode .einstein_predictions .types import (
5+ PredictionColumn ,
6+ PredictionRequest ,
7+ PredictionResponse ,
8+ PredictionType ,
9+ PredictionColumBuilder ,
10+ PredictionRequestBuilder ,
11+ )
12+
13+ class TestPredictionColumnValidation :
14+ def test_string_values_only (self ):
15+ column = PredictionColumn (
16+ column_name = "test_col" ,
17+ string_values = ["a" , "b" , "c" ]
18+ )
19+ assert column .column_name == "test_col"
20+ assert column .string_values == ["a" , "b" , "c" ]
21+ assert column .double_values is None
22+ assert column .boolean_values is None
23+ assert column .date_values is None
24+ assert column .datetime_values is None
25+
26+ def test_double_values_only (self ):
27+ column = PredictionColumn (
28+ column_name = "test_col" ,
29+ double_values = [1.0 , 2.5 , 3.7 ]
30+ )
31+ assert column .double_values == [1.0 , 2.5 , 3.7 ]
32+ assert column .string_values is None
33+ assert column .boolean_values is None
34+ assert column .date_values is None
35+ assert column .datetime_values is None
36+
37+ def test_boolean_values_only (self ):
38+ column = PredictionColumn (
39+ column_name = "test_col" ,
40+ boolean_values = [True , False , True ]
41+ )
42+ assert column .boolean_values == [True , False , True ]
43+ assert column .string_values is None
44+ assert column .double_values is None
45+ assert column .date_values is None
46+ assert column .datetime_values is None
47+
48+
49+ def test_date_values_only (self ):
50+ column = PredictionColumn (
51+ column_name = "test_col" ,
52+ date_values = ["2024-01-01" , "2024-01-02" ]
53+ )
54+ assert column .date_values == ["2024-01-01" , "2024-01-02" ]
55+ assert column .string_values is None
56+ assert column .double_values is None
57+ assert column .boolean_values is None
58+ assert column .datetime_values is None
59+
60+ def test_datetime_values_only (self ):
61+ column = PredictionColumn (
62+ column_name = "test_col" ,
63+ datetime_values = ["2024-01-01T12:00:00" , "2024-01-02T13:00:00" ]
64+ )
65+ assert column .datetime_values == ["2024-01-01T12:00:00" , "2024-01-02T13:00:00" ]
66+ assert column .string_values is None
67+ assert column .double_values is None
68+ assert column .boolean_values is None
69+ assert column .date_values is None
70+
71+ def test_no_column_name_raises_error (self ):
72+ with pytest .raises (ValidationError ) as exc_info :
73+ PredictionColumn (
74+ column_name = "" ,
75+ string_values = ["a" , "b" ],
76+ double_values = [1.0 , 2.0 ]
77+ )
78+
79+ assert str (exc_info .value ) is not None
80+
81+ def test_no_values_raises_error (self ):
82+ with pytest .raises (ValidationError ) as exc_info :
83+ PredictionColumn (column_name = "test_col" )
84+
85+ assert str (exc_info .value ) is not None
86+
87+ def test_string_and_double_raises_error (self ):
88+ with pytest .raises (ValidationError ) as exc_info :
89+ PredictionColumn (
90+ column_name = "test_col" ,
91+ string_values = ["a" , "b" ],
92+ double_values = [1.0 , 2.0 ]
93+ )
94+
95+ assert str (exc_info .value ) is not None
96+
97+ def test_empty_values_raises_error (self ):
98+ with pytest .raises (ValidationError ) as exc_info :
99+ PredictionColumn (
100+ column_name = "test_col" ,
101+ string_values = []
102+ )
103+
104+ assert str (exc_info .value ) is not None
105+
106+ class TestPredictionColumnBuilder :
107+ def test_builder_with_string_values (self ):
108+ column = (PredictionColumBuilder ()
109+ .set_column_name ("test_col" )
110+ .set_string_values (["a" , "b" ])
111+ .build ())
112+
113+ assert column .column_name == "test_col"
114+ assert column .string_values == ["a" , "b" ]
115+
116+ class TestPredictionRequest :
117+ def test_request_with_multiple_columns (self ):
118+ request = PredictionRequest (
119+ prediction_type = PredictionType .CLASSIFICATION ,
120+ model_api_name = "classifier" ,
121+ prediction_columns = [
122+ PredictionColumn (column_name = "col1" , string_values = ["a" ]),
123+ PredictionColumn (column_name = "col2" , double_values = [1.0 ]),
124+ PredictionColumn (column_name = "col3" , boolean_values = [True ])
125+ ]
126+ )
127+
128+ assert len (request .prediction_columns ) == 3
129+
130+ def test_request_requires_model_api_name (self ):
131+ with pytest .raises (ValidationError ):
132+ PredictionRequest (
133+ prediction_type = PredictionType .REGRESSION ,
134+ model_api_name = "" ,
135+ prediction_columns = [
136+ PredictionColumn (column_name = "col1" , double_values = [1.0 ])
137+ ]
138+ )
139+
140+ def test_request_requires_prediction_columns (self ):
141+ with pytest .raises (ValidationError ):
142+ PredictionRequest (
143+ prediction_type = PredictionType .REGRESSION ,
144+ model_api_name = "model" ,
145+ prediction_columns = []
146+ )
147+
148+ class TestPredictionRequestBuilder :
149+ def test_builder_creates_valid_request (self ):
150+ request = (PredictionRequestBuilder ()
151+ .set_prediction_type (PredictionType .CLUSTERING )
152+ .set_model ("cluster_model" )
153+ .set_prediction_columns ([
154+ PredictionColumn (column_name = "test_col" , double_values = [1.0 ])
155+ ])
156+ .set_settings ({
157+ 'maxTopContributors' : 20
158+ })
159+ .build ())
160+
161+ assert request .prediction_type == PredictionType .CLUSTERING
162+ assert request .model_api_name == "cluster_model"
163+ assert len (request .prediction_columns ) == 1
164+ assert request .settings == { 'maxTopContributors' : 20 }
165+
166+
167+ class TestPredictionResponse :
168+ def test_successful_response (self ):
169+ response = PredictionResponse (
170+ version = "v1" ,
171+ prediction_type = PredictionType .REGRESSION ,
172+ status_code = 200 ,
173+ data = {"results" : [{"prediction" : {"value" : 42.5 }}]}
174+ )
175+
176+ assert response .is_success
177+ assert response .status_code == 200
178+ assert response .data is not None
179+
180+ def test_failed_response (self ):
181+ response = PredictionResponse (
182+ version = "v1" ,
183+ prediction_type = PredictionType .REGRESSION ,
184+ status_code = 500 ,
185+ data = {"error" : "Internal server error" }
186+ )
187+
188+ assert not response .is_success
189+ assert response .status_code == 500
0 commit comments