1414# limitations under the License.
1515
1616from enum import Enum , unique
17+ from typing import (
18+ Any ,
19+ Dict ,
20+ Literal ,
21+ Optional ,
22+ )
23+
1724from pydantic import (
1825 BaseModel ,
1926 Field ,
2027 model_validator ,
2128)
2229
23- from typing import (
24- Literal ,
25- Optional ,
26- Dict ,
27- Any
28- )
2930
3031@unique
3132class PredictionType (Enum ):
@@ -35,37 +36,51 @@ class PredictionType(Enum):
3536 MULTI_OUTCOME = 4
3637 BINARY_CLASSIFICATION = 5
3738
39+
3840class PredictionColumn (BaseModel ):
3941 column_name : str = Field (min_length = 1 , description = "Column name" )
40- string_values : Optional [list [str ]] = Field (default = None , min_length = 1 , description = "Column string values" )
41- double_values : Optional [list [float ]] = Field (default = None , min_length = 1 , description = "Column double values" )
42- boolean_values : Optional [list [bool ]] = Field (default = None , min_length = 1 , description = "Column boolean values" )
43- date_values : Optional [list [str ]] = Field (default = None , min_length = 1 , description = "Column date values" )
44- datetime_values : Optional [list [str ]] = Field (default = None , min_length = 1 , description = "Column datetime values" )
42+ string_values : Optional [list [str ]] = Field (
43+ default = None , min_length = 1 , description = "Column string values"
44+ )
45+ double_values : Optional [list [float ]] = Field (
46+ default = None , min_length = 1 , description = "Column double values"
47+ )
48+ boolean_values : Optional [list [bool ]] = Field (
49+ default = None , min_length = 1 , description = "Column boolean values"
50+ )
51+ date_values : Optional [list [str ]] = Field (
52+ default = None , min_length = 1 , description = "Column date values"
53+ )
54+ datetime_values : Optional [list [str ]] = Field (
55+ default = None , min_length = 1 , description = "Column datetime values"
56+ )
4557
46- @model_validator (mode = ' after' )
58+ @model_validator (mode = " after" )
4759 def validate_exactly_one_value_type (self ):
48- set_count = sum ([
49- self .string_values is not None ,
50- self .double_values is not None ,
51- self .boolean_values is not None ,
52- self .date_values is not None ,
53- self .datetime_values is not None
54- ])
60+ set_count = sum (
61+ [
62+ self .string_values is not None ,
63+ self .double_values is not None ,
64+ self .boolean_values is not None ,
65+ self .date_values is not None ,
66+ self .datetime_values is not None ,
67+ ]
68+ )
5569
5670 if set_count != 1 :
5771 raise ValueError ("Exactly one value type must be set" )
5872
5973 return self
6074
75+
6176class PredictionColumBuilder :
6277 def __init__ (self ) -> None :
63- self ._column_name : str = None
64- self ._string_values : list [str ] = None
65- self ._double_values : list [float ] = None
66- self ._boolean_values : list [bool ] = None
67- self ._date_values : list [str ] = None
68- self ._datetime_values : list [str ] = None
78+ self ._column_name : Optional [ str ] = None
79+ self ._string_values : Optional [ list [str ] ] = None
80+ self ._double_values : Optional [ list [float ] ] = None
81+ self ._boolean_values : Optional [ list [bool ] ] = None
82+ self ._date_values : Optional [ list [str ] ] = None
83+ self ._datetime_values : Optional [ list [str ] ] = None
6984
7085 def set_column_name (self , column_name : str ) -> "PredictionColumBuilder" :
7186 self ._column_name = column_name
@@ -79,45 +94,59 @@ def set_double_values(self, double_values: list[float]) -> "PredictionColumBuild
7994 self ._double_values = double_values
8095 return self
8196
82- def set_boolean_values (self , boolean_values : list [bool ]) -> "PredictionColumBuilder" :
97+ def set_boolean_values (
98+ self , boolean_values : list [bool ]
99+ ) -> "PredictionColumBuilder" :
83100 self ._boolean_values = boolean_values
84101 return self
85102
86103 def set_date_values (self , date_values : list [str ]) -> "PredictionColumBuilder" :
87104 self ._date_values = date_values
88105 return self
89106
90- def set_datetime_values (self , datetime_values : list [str ]) -> "PredictionColumBuilder" :
107+ def set_datetime_values (
108+ self , datetime_values : list [str ]
109+ ) -> "PredictionColumBuilder" :
91110 self ._datetime_values = datetime_values
92111 return self
93112
94113 def build (self ) -> PredictionColumn :
95114 return PredictionColumn (
96- column_name = self ._column_name ,
97- string_values = self ._string_values ,
98- double_values = self ._double_values ,
99- boolean_values = self ._boolean_values ,
100- date_values = self ._date_values ,
101- datetime_values = self ._datetime_values
115+ column_name = self ._column_name ,
116+ string_values = self ._string_values ,
117+ double_values = self ._double_values ,
118+ boolean_values = self ._boolean_values ,
119+ date_values = self ._date_values ,
120+ datetime_values = self ._datetime_values ,
102121 )
103122
123+
104124class PredictionRequest (BaseModel ):
105125 version : Literal ["v1" ] = Field (
106126 default = "v1" , description = "API version, must be 'v1'"
107127 )
108128 prediction_type : PredictionType = Field (description = "Prediction type" )
109- model_api_name : str = Field (min_length = 1 , description = "API name of the model to use" )
110- prediction_columns : list [PredictionColumn ] = Field (min_length = 1 , description = "List of prediction columns" )
111- settings : Optional [Dict [str , Any ]] = Field (default = None , description = "Settings for the prediction request" )
129+ model_api_name : str = Field (
130+ min_length = 1 , description = "API name of the model to use"
131+ )
132+ prediction_columns : list [PredictionColumn ] = Field (
133+ min_length = 1 , description = "List of prediction columns"
134+ )
135+ settings : Optional [Dict [str , Any ]] = Field (
136+ default = None , description = "Settings for the prediction request"
137+ )
138+
112139
113140class PredictionRequestBuilder :
114141 def __init__ (self ) -> None :
115- self ._prediction_type : PredictionType = None
116- self ._model_api_name : str = None
142+ self ._prediction_type : Optional [ PredictionType ] = None
143+ self ._model_api_name : Optional [ str ] = None
117144 self ._prediction_columns : list [PredictionColumn ] = []
118- self ._settings : Dict [str , Any ] = None
145+ self ._settings : Optional [ Dict [str , Any ] ] = None
119146
120- def set_prediction_type (self , prediction_type : PredictionType ) -> "PredictionRequestBuilder" :
147+ def set_prediction_type (
148+ self , prediction_type : PredictionType
149+ ) -> "PredictionRequestBuilder" :
121150 self ._prediction_type = prediction_type
122151 return self
123152
@@ -126,8 +155,7 @@ def set_model(self, model_api_name: str) -> "PredictionRequestBuilder":
126155 return self
127156
128157 def set_prediction_columns (
129- self ,
130- prediction_columns : list [PredictionColumn ]
158+ self , prediction_columns : list [PredictionColumn ]
131159 ) -> "PredictionRequestBuilder" :
132160 self ._prediction_columns = prediction_columns
133161 return self
@@ -141,9 +169,10 @@ def build(self) -> PredictionRequest:
141169 prediction_type = self ._prediction_type ,
142170 model_api_name = self ._model_api_name ,
143171 prediction_columns = self ._prediction_columns ,
144- settings = self ._settings
172+ settings = self ._settings ,
145173 )
146-
174+
175+
147176class PredictionResponse (BaseModel ):
148177 version : Literal ["v1" ] = Field (default = "v1" , description = "API version" )
149178 prediction_type : PredictionType = Field (description = "Prediction type" )
@@ -153,4 +182,3 @@ class PredictionResponse(BaseModel):
153182 @property
154183 def is_success (self ) -> bool :
155184 return self .status_code == 200
156-
0 commit comments