Skip to content

Commit 76c3f3f

Browse files
committed
added llm_complete
1 parent f914a25 commit 76c3f3f

6 files changed

Lines changed: 202 additions & 0 deletions

File tree

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,23 @@ client.write_to_dlo('output_DLO')
170170
> [!WARNING]
171171
> Currently we only support reading from DMOs and writing to DMOs or reading from DLOs and writing to DLOs, but they cannot mix.
172172
173+
## LLM Capabilities
174+
175+
* `llm_complete(prompt_col, model_id, max_tokens)` – Generate AI completions from a prompt column
176+
177+
For example:
178+
```python
179+
from datacustomcode.ai import llm_complete
180+
from pyspark.sql.functions import concat_ws, lit, col
181+
182+
# Generate summaries
183+
prompt = concat_ws(" ",
184+
lit("Summarize:"),
185+
col("Name__c"),
186+
col("Description__c")
187+
)
188+
df = df.withColumn("summary", llm_complete(prompt))
189+
```
173190

174191
## CLI
175192

src/datacustomcode/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from datacustomcode.ai import llm_complete
1617
from datacustomcode.client import Client
1718
from datacustomcode.credentials import AuthType, Credentials
1819
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
@@ -28,4 +29,5 @@
2829
"LocalProxyClientProvider",
2930
"PrintDataCloudWriter",
3031
"QueryAPIDataCloudReader",
32+
"llm_complete",
3133
]

src/datacustomcode/ai/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""LLM capabilities for Data Cloud Custom Code.
17+
18+
This module provides LLM functions
19+
20+
Available functions:
21+
llm_complete: Generate completions from a prompt column
22+
23+
Example:
24+
from datacustomcode.ai import llm_complete
25+
from pyspark.sql.functions import col
26+
27+
df = spark.read.table("Account_std__dll")
28+
df = df.withColumn("summary", llm_complete("Name__c"))
29+
"""
30+
31+
from datacustomcode.ai.llm import llm_complete
32+
33+
__all__ = [
34+
"llm_complete",
35+
]

src/datacustomcode/ai/llm.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""LLM completion functions for Data Cloud Custom Code.
17+
"""
18+
19+
from typing import Union
20+
21+
from pyspark.sql import Column
22+
from pyspark.sql.functions import call_function, lit
23+
24+
# Default values for llm_complete function
25+
# TODO: Validate these defaults
26+
_DEFAULT_MODEL_ID = "sfdc_ai__DefaultGPT4Omni"
27+
_DEFAULT_MAX_TOKENS = 200
28+
_LLM_GATEWAY_UDF_NAME = "llm_gateway_generate"
29+
30+
31+
def llm_complete(
32+
prompt_col: Union[Column, str],
33+
*,
34+
model_id: str = _DEFAULT_MODEL_ID,
35+
max_tokens: int = _DEFAULT_MAX_TOKENS,
36+
) -> Column:
37+
"""Returns the AI-generated text response as a string column.
38+
39+
Args:
40+
prompt_col: Column or column name containing the prompt text.
41+
The prompt should be a string value (max 32KB recommended).
42+
Use string functions like concat_ws(), format_string(), etc.
43+
to construct complex prompts from multiple columns.
44+
model_id: Defaults to "sfdc_ai__DefaultGPT4Omni".
45+
Available models depend on your org's configuration.
46+
max_tokens: Maximum tokens in the response. Defaults to 200.
47+
Higher values allow longer responses but increase latency and cost.
48+
49+
Returns:
50+
Column of StringType with AI-generated response.
51+
Returns null if the input prompt is null.
52+
53+
Raises:
54+
TypeError: If prompt_col is not a Column or string.
55+
ValueError: If max_tokens is not positive.
56+
"""
57+
# Input validation
58+
if not isinstance(prompt_col, (Column, str)):
59+
raise TypeError(
60+
f"prompt_col must be a Column or str, got {type(prompt_col).__name__}"
61+
)
62+
63+
if not isinstance(max_tokens, int) or max_tokens <= 0:
64+
raise ValueError(f"max_tokens must be a positive integer, got {max_tokens}")
65+
66+
# Convert string column name to Column
67+
if isinstance(prompt_col, str):
68+
from pyspark.sql.functions import col
69+
70+
prompt_col = col(prompt_col)
71+
72+
from pyspark.sql.functions import named_struct
73+
74+
template = "{prompt}"
75+
values_struct = named_struct(lit("prompt"), prompt_col)
76+
77+
return call_function(
78+
_LLM_GATEWAY_UDF_NAME,
79+
lit(template),
80+
values_struct,
81+
lit(model_id),
82+
lit(max_tokens),
83+
)

tests/ai/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.

tests/ai/test_llm.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for datacustomcode.ai.llm module."""
17+
18+
import pytest
19+
20+
from datacustomcode.ai import llm_complete
21+
22+
23+
class TestLlmComplete:
24+
"""Tests for llm_complete function."""
25+
26+
def test_invalid_prompt_col_type_int(self):
27+
"""Test that invalid prompt_col type raises TypeError."""
28+
with pytest.raises(TypeError, match="prompt_col must be a Column or str"):
29+
llm_complete(123)
30+
31+
def test_invalid_max_tokens_type_string(self):
32+
"""Test that string max_tokens raises ValueError."""
33+
with pytest.raises(ValueError, match="max_tokens must be a positive integer"):
34+
llm_complete("test_col", max_tokens="invalid")
35+
36+
def test_invalid_max_tokens_type_float(self):
37+
"""Test that float max_tokens raises ValueError."""
38+
with pytest.raises(ValueError, match="max_tokens must be a positive integer"):
39+
llm_complete("test_col", max_tokens=100.5)
40+
41+
def test_negative_max_tokens(self):
42+
"""Test that negative max_tokens raises ValueError."""
43+
with pytest.raises(ValueError, match="max_tokens must be a positive integer"):
44+
llm_complete("test_col", max_tokens=-1)
45+
46+
def test_zero_max_tokens(self):
47+
"""Test that zero max_tokens raises ValueError."""
48+
with pytest.raises(ValueError, match="max_tokens must be a positive integer"):
49+
llm_complete("test_col", max_tokens=0)
50+
51+

0 commit comments

Comments
 (0)