Skip to content

Commit be4148b

Browse files
Using pydantic for llm_gateway models
1 parent 7076dfa commit be4148b

8 files changed

Lines changed: 94 additions & 211 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ click = "^8.1.8"
9999
loguru = "^0.7.3"
100100
numpy = "*"
101101
pandas = "*"
102-
pydantic = "^1.8.2 || ^2.0.0"
102+
pydantic = "2.13.1"
103103
pyspark = "3.5.1"
104104
python = ">=3.10,<3.12"
105105
pyyaml = "^6.0"

src/datacustomcode/llm_gateway/default.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
1818
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
1919

20+
from datacustomcode.llm_gateway.types.generate_text_response_builder import GenerateTextResponseBuilder
2021

2122
class DefaultLLMGateway(LLMGateway):
2223
def generate_text(
@@ -26,7 +27,11 @@ def generate_text(
2627

2728

2829
response_data = {
29-
'generation' : {'generatedText' : "I am dreaming!!"},
30+
'version': 'v1',
31+
'status_code': 200,
32+
'data' : {
33+
'generation': {'generatedText': 'Hello World'}
34+
}
3035
}
3136

32-
return GenerateTextResponse(200, {"data": response_data})
37+
return GenerateTextResponseBuilder.build(response_data)
Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
1-
from dataclasses import dataclass
2-
3-
import betterproto
4-
5-
from .google import protobuf
6-
7-
8-
@dataclass
9-
class GenerateTextRequest(betterproto.Message):
10-
version: str = betterproto.string_field(1)
11-
model_name: str = betterproto.string_field(2)
12-
prompt: str = betterproto.string_field(3)
13-
localization: protobuf.Struct = betterproto.message_field(4)
14-
tags: protobuf.Struct = betterproto.message_field(5)
1+
from typing import Optional, Dict, Any, Literal
2+
from pydantic import BaseModel, Field
153

164

5+
class GenerateTextRequest(BaseModel):
6+
"""Request for LLM text generation"""
177

8+
version: Literal["v1"] = Field(default="v1", description="API version, must be 'v1'")
9+
model_name: str = Field(..., min_length=1, description="Name of the model to use")
10+
prompt: str = Field(..., min_length=1, max_length=1000, description="Input prompt")
11+
localization: Optional[Dict[str, Any]] = Field(default=None, description="Localization settings")
12+
tags: Optional[Dict[str, Any]] = Field(default=None, description="Additional tags")
Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,12 @@
1-
from datacustomcode.validator.base import Validator
21
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
32

43

5-
class GenerateTextRequestValidator:
6-
@staticmethod
7-
def create_validator() -> Validator:
8-
"""Create a validator with all CEL rules for GenerateTextRequest"""
9-
validator = Validator()
10-
11-
# Rule 1: version == "v1" (CEL: const)
12-
validator.add_rule(
13-
id="request.version_v1",
14-
message="Platform currently only supports version 'v1'",
15-
expression=lambda this: this.version == "v1"
16-
)
17-
18-
# Rule 2: modelName.size() >= 1 (CEL: min_len)
19-
validator.add_rule(
20-
id="request.model_name_required",
21-
message="modelName must not be empty (min_len: 1)",
22-
expression=lambda this: len(this.model_name) >= 1
23-
)
24-
25-
return validator
26-
274
class GenerateTextRequestBuilder:
285
def __init__(self):
29-
self._validator = GenerateTextRequestValidator.create_validator()
30-
self._version = "v1" # Hardcoded default for your SDK
316
self._prompt = ""
327
self._model_name = ""
33-
8+
self._localization = None
9+
self._tags = None
3410

3511
def set_prompt(self, prompt: str):
3612
self._prompt = prompt
@@ -40,18 +16,39 @@ def set_model(self, model_name: str):
4016
self._model_name = model_name
4117
return self
4218

19+
def set_localization(self, localization: dict = None, locale: str = None):
20+
"""
21+
Set localization either from a dict or a simple locale string.
22+
23+
Args:
24+
localization: Full localization dict (if provided, locale is ignored)
25+
locale: Simple locale string for defaultLocale only
26+
27+
Returns:
28+
self for method chaining
29+
"""
30+
31+
if localization is not None:
32+
self._localization = localization
33+
elif locale is not None:
34+
self._localization = {"defaultLocale": locale}
35+
else:
36+
raise ValueError("Must provide either localization or locale")
37+
38+
self._localization = localization
39+
return self
40+
41+
def set_tags(self, tags: dict):
42+
self._tags = tags
43+
return self
44+
4345
def build(self) -> GenerateTextRequest:
4446

4547
request = GenerateTextRequest(
46-
version=self._version,
4748
prompt=self._prompt,
48-
model_name=self._model_name
49+
model_name=self._model_name,
50+
localization=self._localization,
51+
tags=self._tags
4952
)
5053

51-
# 2. Run the Protovalidate check
52-
# This reads the 'max_len: 1000' rule from the .proto metadata
53-
violations = self._validator.validate(request)
54-
if violations:
55-
raise ValueError(f"Validation Error: {violations}")
56-
57-
return request
54+
return request
Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
1+
from typing import Optional, Dict, Any
12

2-
from dataclasses import dataclass
3+
from pydantic import BaseModel, Field
34

4-
import betterproto
5+
class GenerateTextResponse(BaseModel):
6+
"""Response from LLM text generation"""
57

6-
from .google import protobuf
8+
version: str = Field(default="v1", description="API version")
9+
status_code: int = Field(..., description="HTTP status code", ge=0)
10+
data: Optional[Dict[str, Any]] = Field(default=None, description="Response data")
711

12+
@property
13+
def is_success(self) -> bool:
14+
"""Check if request succeeded."""
15+
return self.status_code == 200
816

17+
@property
18+
def is_error(self) -> bool:
19+
"""Check if request failed."""
20+
return not self.is_success
921

10-
@dataclass
11-
class GenerateTextResponse(betterproto.Message):
12-
version: str = betterproto.string_field(1)
13-
status_code: int = betterproto.uint32_field(2)
14-
data: protobuf.Struct = betterproto.message_field(3)
22+
@property
23+
def text(self) -> str:
24+
"""Generated text (convenience property)."""
25+
if self.is_success:
26+
return self.data.get('generation', {}).get('generatedText', '')
27+
return ''
28+
29+
@property
30+
def error_code(self) -> str:
31+
"""Generated text (convenience property)."""
32+
if self.is_error:
33+
return self.data.get('errorCode', self.status_code)
34+
return ''
Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,21 @@
1-
from .google import protobuf
1+
from typing import Dict, Any
2+
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
23

34

45
class GenerateTextResponseBuilder:
56
def __init__(self):
6-
self._validator = Validator()
7-
self._rules = validate_pb2.MessageConstraints()
8-
9-
# Rule 1: Prompt Length
10-
prompt_rule = self.rules.cel.add()
11-
prompt_rule.id = "request.prompt_limit"
12-
prompt_rule.message = "Prompt must be 1-1000 characters."
13-
prompt_rule.expression = "this.prompt.size() > 0 && this.prompt.size() <= 1000"
14-
15-
# Rule 3: ModelName Constraint
16-
model_name_rule = self.rules.cel.add()
17-
version_rule.id = "request.version_v1"
18-
version_rule.message = "Platform currently only supports version 'v1'."
19-
version_rule.expression = "this.version == 'v1'"
20-
217
self._version = "v1" # Hardcoded default for your SDK
22-
self._prompt = ""
23-
self._model_name = ""
24-
8+
self._status_code = None
9+
self._data = None
2510

26-
def validate(self, request: GenerateTextRequest):
27-
violations = self.validator.validate(request, constraints=self.rules)
28-
if violations:
29-
# protovalidate returns a structured 'Violations' object
30-
error_msg = "; ".join([v.message for v in violations.violations])
31-
raise ValueError(f"GenerateTextRequest Validation Failed: {error_msg}")
32-
33-
def set_prompt(self, prompt: str):
34-
self._prompt = prompt
11+
def set_status_code(self, status_code: int):
12+
self._status_code = status_code
3513
return self
3614

37-
def set_model(self, model_name: str):
38-
self._model_name = model_name
15+
def set_data(self, data: dict):
16+
self._data = data
3917
return self
4018

41-
def build(self) -> GenerateTextRequest:
42-
43-
request = GenerateTextRequest(
44-
version=self._version,
45-
prompt=self._prompt,
46-
model_name=self._model_name
47-
)
48-
49-
# 2. Run the Protovalidate check
50-
# This reads the 'max_len: 1000' rule from the .proto metadata
51-
violations = _validator.validate(request)
52-
if violations:
53-
raise ValueError(f"Validation Error: {violations}")
54-
55-
return request
19+
@staticmethod
20+
def build(response_dict: Dict[str, Any]) -> GenerateTextResponse:
21+
return GenerateTextResponse.model_validate(response_dict)

src/datacustomcode/templates/function/payload/entrypoint.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,13 @@ def function(request: dict, runTime: Runtime) -> dict:
4848

4949

5050
builder = GenerateTextRequestBuilder()
51-
request = builder.set_prompt("Hello").set_model("gpt-4").build()
52-
response = runTime.llm_gateway.generate_text(request)
51+
llm_request = builder.set_prompt("Hello").set_model("").build()
52+
llm_response = runTime.llm_gateway.generate_text(llm_request)
5353

54-
if response.is_success:
55-
print(response.text)
54+
if llm_response.is_success:
55+
print(llm_response.text)
5656
else:
57-
print(response.error_code)
58-
59-
file_path = runTime.file.find_file_path("data.csv")
60-
content = open(file_path, 'r').read()
61-
logger.info(content)
57+
print(llm_response.error_code)
6258

6359
for item in items:
6460
# Item is DocElement as dict
@@ -126,7 +122,7 @@ def function(request: dict, runTime: Runtime) -> dict:
126122
}
127123

128124
# Run the function
129-
result = function(test_request)
125+
result = function(test_request, Runtime())
130126

131127
# Print the results in a more readable format
132128
print("\nChunking Results:")

src/datacustomcode/validator/base.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

0 commit comments

Comments
 (0)