Skip to content

Commit 900526d

Browse files
Merge pull request #90 from forcedotcom/predict
Use production API for llm gateway and predictive API
2 parents cf86d27 + 28187f2 commit 900526d

32 files changed

Lines changed: 736 additions & 604 deletions

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ pydantic = "2.13.1"
105105
pyspark = "3.5.1"
106106
python = ">=3.10,<3.12"
107107
pyyaml = "^6.0"
108+
requests = "2.33.1"
108109
salesforce-cdp-connector = ">=1.0.19"
109110
setuptools_scm = "^7.1.0"
110111

src/datacustomcode/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,11 @@
1717
from datacustomcode.credentials import AuthType, Credentials
1818
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
1919
from datacustomcode.io.writer.print import PrintDataCloudWriter
20-
from datacustomcode.proxy.client.LocalProxyClientProvider import (
21-
LocalProxyClientProvider,
22-
)
2320

2421
__all__ = [
2522
"AuthType",
2623
"Client",
2724
"Credentials",
28-
"LocalProxyClientProvider",
2925
"PrintDataCloudWriter",
3026
"QueryAPIDataCloudReader",
3127
]

src/datacustomcode/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def do_oauth_browser_flow(
170170

171171
# Start callback server
172172
click.echo(f"\nStarting local callback server on {redirect_uri}...")
173-
server, actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue)
173+
server, _actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue)
174174

175175
# Build authorization URL with final redirect_uri
176176
auth_url = (

src/datacustomcode/cli.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,15 @@ def deploy(
179179
function_invoke_opt: str,
180180
sf_cli_org: Optional[str],
181181
):
182-
from datacustomcode.credentials import Credentials
183182
from datacustomcode.deploy import (
184183
COMPUTE_TYPES,
185-
AccessTokenResponse,
186184
CodeExtensionMetadata,
187-
_retrieve_access_token_from_sf_cli,
188185
deploy_full,
189186
)
187+
from datacustomcode.token_provider import (
188+
CredentialsTokenProvider,
189+
SFCLITokenProvider,
190+
)
190191

191192
logger.debug("Deploying project")
192193

@@ -220,22 +221,15 @@ def deploy(
220221
function_invoke_options = function_invoke_opt.split(",")
221222
metadata.functionInvokeOptions = function_invoke_options
222223

223-
auth: Union[Credentials, AccessTokenResponse]
224-
if sf_cli_org:
225-
try:
226-
auth = _retrieve_access_token_from_sf_cli(sf_cli_org)
227-
except RuntimeError as e:
228-
click.secho(f"Error: {e}", fg="red")
229-
raise click.Abort() from None
230-
else:
231-
try:
232-
auth = Credentials.from_available(profile=profile)
233-
except ValueError as e:
234-
click.secho(
235-
f"Error: {e}",
236-
fg="red",
237-
)
238-
raise click.Abort() from None
224+
try:
225+
if sf_cli_org:
226+
auth = SFCLITokenProvider(sf_cli_org).get_token()
227+
else:
228+
auth = CredentialsTokenProvider(profile).get_token()
229+
except RuntimeError as e:
230+
click.secho(f"Error: {e}", fg="red")
231+
raise click.Abort() from None
232+
239233
deploy_full(path, metadata, auth, network)
240234

241235

src/datacustomcode/client.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
from datacustomcode.io.reader.base import BaseDataCloudReader
3535
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
36-
from datacustomcode.proxy.client.base import BaseProxyClient
3736
from datacustomcode.spark.base import BaseSparkSessionProvider
3837

3938

@@ -107,15 +106,13 @@ class Client:
107106
_reader: BaseDataCloudReader
108107
_writer: BaseDataCloudWriter
109108
_file: DefaultFindFilePath
110-
_proxy: Optional[BaseProxyClient]
111109
_data_layer_history: dict[DataCloudObjectType, set[str]]
112110
_code_type: str
113111

114112
def __new__(
115113
cls,
116114
reader: Optional[BaseDataCloudReader] = None,
117115
writer: Optional["BaseDataCloudWriter"] = None,
118-
proxy: Optional[BaseProxyClient] = None,
119116
spark_provider: Optional["BaseSparkSessionProvider"] = None,
120117
code_type: str = "script",
121118
) -> Client:
@@ -223,11 +220,6 @@ def write_to_dmo(
223220
self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
224221
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs) # type: ignore[no-any-return]
225222

226-
def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
227-
if self._proxy is None:
228-
raise ValueError("No proxy configured; set proxy or proxy_config")
229-
return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens) # type: ignore[no-any-return]
230-
231223
def find_file_path(self, file_name: str) -> Path:
232224
"""Return a file path"""
233225

src/datacustomcode/cmd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,6 @@ def _cmd_output(
104104

105105

106106
def cmd_output(*cmd: str, **kwargs: Any) -> Union[str, None]:
107-
returncode, stdout_b, stderr_b = _cmd_output(*cmd, **kwargs)
107+
_returncode, stdout_b, _stderr_b = _cmd_output(*cmd, **kwargs)
108108
stdout = stdout_b.decode() if stdout_b is not None else None
109109
return stdout

src/datacustomcode/config.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,14 @@
3737
# This lets all readers and writers to be findable via config
3838
from datacustomcode.io import * # noqa: F403
3939
from datacustomcode.io.base import BaseDataAccessLayer
40-
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH002
41-
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH002
42-
from datacustomcode.proxy.base import BaseProxyAccessLayer
43-
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
4440
from datacustomcode.spark.base import BaseSparkSessionProvider
4541

4642
if TYPE_CHECKING:
4743
from pyspark.sql import SparkSession
4844

45+
from datacustomcode.io.reader.base import BaseDataCloudReader
46+
from datacustomcode.io.writer.base import BaseDataCloudWriter
47+
4948

5049
_T = TypeVar("_T", bound="BaseDataAccessLayer")
5150

@@ -55,7 +54,7 @@ class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):
5554

5655
def to_object(self, spark: SparkSession) -> _T:
5756
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
58-
return cast(_T, type_(spark=spark, **self.options))
57+
return cast("_T", type_(spark=spark, **self.options))
5958

6059

6160
class SparkConfig(ForceableConfig):
@@ -74,31 +73,18 @@ class SparkConfig(ForceableConfig):
7473

7574
_P = TypeVar("_P", bound=BaseSparkSessionProvider)
7675

77-
_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
78-
79-
80-
class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
81-
"""Config for proxy clients that take no constructor args (e.g. no spark)."""
82-
83-
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
84-
85-
def to_object(self) -> _PX:
86-
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
87-
return cast(_PX, type_(**self.options))
88-
8976

9077
class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
9178
type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
9279

9380
def to_object(self) -> _P:
9481
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
95-
return cast(_P, type_(**self.options))
82+
return cast("_P", type_(**self.options))
9683

9784

9885
class ClientConfig(BaseConfig):
99-
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
100-
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
101-
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
86+
reader_config: Union[AccessLayerObjectConfig["BaseDataCloudReader"], None] = None
87+
writer_config: Union[AccessLayerObjectConfig["BaseDataCloudWriter"], None] = None
10288
spark_config: Union[SparkConfig, None] = None
10389
spark_provider_config: Union[
10490
SparkProviderConfig[BaseSparkSessionProvider], None
@@ -126,7 +112,6 @@ def merge(
126112

127113
self.reader_config = merge(self.reader_config, other.reader_config)
128114
self.writer_config = merge(self.writer_config, other.writer_config)
129-
self.proxy_config = merge(self.proxy_config, other.proxy_config)
130115
self.spark_config = merge(self.spark_config, other.spark_config)
131116
self.spark_provider_config = merge(
132117
self.spark_provider_config, other.spark_provider_config

src/datacustomcode/config.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ spark_config:
1919
spark.sql.execution.arrow.pyspark.enabled: 'true'
2020
spark.driver.extraJavaOptions: -Djava.security.manager=allow
2121

22-
proxy_config:
23-
type_config_name: LocalProxyClientProvider
22+
einstein_predictions_config:
23+
type_config_name: DefaultEinsteinPredictions
2424
options:
2525
credentials_profile: default
2626

27-
einstein_predictions_config:
28-
type_config_name: DefaultEinsteinPredictions
29-
options: {}
27+
llm_gateway_config:
28+
type_config_name: DefaultLLMGateway
29+
options:
30+
credentials_profile: default

src/datacustomcode/deploy.py

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
import os
2020
import re
2121
import shutil
22-
import subprocess
2322
import tempfile
2423
import time
2524
from typing import (
26-
TYPE_CHECKING,
2725
Any,
2826
Callable,
2927
Dict,
@@ -37,15 +35,10 @@
3735
import requests
3836

3937
from datacustomcode.cmd import cmd_output
40-
from datacustomcode.credentials import AuthType
4138
from datacustomcode.scan import find_base_directory, get_package_type
4239

43-
if TYPE_CHECKING:
44-
from datacustomcode.credentials import Credentials
45-
4640
DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code"
4741
DATA_TRANSFORMS_PATH = "services/data/v63.0/ssot/data-transforms"
48-
AUTH_PATH = "services/oauth2/token"
4942
WAIT_FOR_DEPLOYMENT_TIMEOUT = 3000
5043

5144
# Available compute types for Data Cloud deployments.
@@ -163,80 +156,6 @@ class AccessTokenResponse(BaseModel):
163156
instance_url: str
164157

165158

166-
def _retrieve_access_token(credentials: Credentials) -> AccessTokenResponse:
167-
"""Get an access token for the Salesforce API."""
168-
logger.debug("Getting oauth token...")
169-
170-
url = f"{credentials.login_url.rstrip('/')}/{AUTH_PATH.lstrip('/')}"
171-
172-
if credentials.auth_type == AuthType.OAUTH_TOKENS:
173-
data = {
174-
"grant_type": "refresh_token",
175-
"refresh_token": credentials.refresh_token,
176-
"client_id": credentials.client_id,
177-
"client_secret": credentials.client_secret,
178-
}
179-
elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS:
180-
data = {
181-
"grant_type": "client_credentials",
182-
"client_id": credentials.client_id,
183-
"client_secret": credentials.client_secret,
184-
}
185-
else:
186-
raise ValueError(f"Unsupported auth_type: {credentials.auth_type}")
187-
188-
response = _make_api_call(url, "POST", data=data)
189-
return AccessTokenResponse(**response)
190-
191-
192-
def _retrieve_access_token_from_sf_cli(sf_cli_org: str) -> AccessTokenResponse:
193-
"""Get an access token from the Salesforce CLI."""
194-
try:
195-
result = subprocess.run(
196-
["sf", "org", "display", "--target-org", sf_cli_org, "--json"],
197-
capture_output=True,
198-
text=True,
199-
check=True,
200-
timeout=30,
201-
)
202-
except FileNotFoundError as exc:
203-
raise RuntimeError(
204-
"The 'sf' command was not found. "
205-
"Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli"
206-
) from exc
207-
except subprocess.TimeoutExpired as exc:
208-
raise RuntimeError(
209-
f"'sf org display' timed out for org '{sf_cli_org}'"
210-
) from exc
211-
except subprocess.CalledProcessError as exc:
212-
raise RuntimeError(
213-
f"'sf org display' failed for org '{sf_cli_org}'.\n"
214-
f"Ensure the org is authenticated via 'sf org login web'.\n"
215-
f"stderr: {exc.stderr.strip()}"
216-
) from exc
217-
218-
try:
219-
data = json.loads(result.stdout)
220-
except json.JSONDecodeError as exc:
221-
raise RuntimeError(f"Failed to parse 'sf org display' output: {exc}") from exc
222-
223-
if data.get("status") != 0:
224-
raise RuntimeError(
225-
f"SF CLI error for org '{sf_cli_org}': "
226-
f"{data.get('message', 'unknown error')}"
227-
)
228-
229-
org_result = data.get("result", {})
230-
access_token = org_result.get("accessToken")
231-
instance_url = org_result.get("instanceUrl")
232-
if not access_token or not instance_url:
233-
raise RuntimeError(
234-
f"'sf org display' did not return an access token or instance URL "
235-
f"for org '{sf_cli_org}'"
236-
)
237-
return AccessTokenResponse(access_token=access_token, instance_url=instance_url)
238-
239-
240159
class CreateDeploymentResponse(BaseModel):
241160
fileUploadUrl: str
242161

@@ -567,16 +486,11 @@ def zip(
567486
def deploy_full(
568487
directory: str,
569488
metadata: CodeExtensionMetadata,
570-
credentials: Union["Credentials", AccessTokenResponse],
489+
access_token: AccessTokenResponse,
571490
docker_network: str,
572491
callback=None,
573492
) -> AccessTokenResponse:
574493
"""Deploy a data transform in the DataCloud."""
575-
if isinstance(credentials, AccessTokenResponse):
576-
access_token = credentials
577-
else:
578-
access_token = _retrieve_access_token(credentials)
579-
580494
# prepare payload
581495
config = get_config(directory)
582496

@@ -587,7 +501,6 @@ def deploy_full(
587501
wait_for_deployment(access_token, metadata, callback)
588502

589503
# create data transform
590-
591504
if isinstance(config, DataTransformConfig):
592505
create_data_transform(directory, access_token, metadata, config)
593506
return access_token

0 commit comments

Comments
 (0)