Skip to content

Commit 6ff023c

Browse files
committed
Added Oauth flow logic.
Do the flow as part of configure.
1 parent f2a715a commit 6ff023c

2 files changed

Lines changed: 234 additions & 9 deletions

File tree

src/datacustomcode/cli.py

Lines changed: 227 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,26 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import http.server
1516
from importlib import metadata
1617
import json
1718
import os
19+
import queue
20+
import socketserver
1821
import sys
19-
from typing import List, Union
22+
import threading
23+
import time
24+
from typing import (
25+
Any,
26+
List,
27+
Union,
28+
)
29+
from urllib.parse import parse_qs, urlparse
30+
import webbrowser
2031

2132
import click
2233
from loguru import logger
34+
import requests
2335

2436
from datacustomcode.scan import find_base_directory, get_package_type
2537

@@ -45,6 +57,208 @@ def version():
4557
click.echo("Version information not available")
4658

4759

60+
class OAuthCallbackHandler(http.server.SimpleHTTPRequestHandler):
61+
"""HTTP request handler to capture OAuth callback."""
62+
63+
def __init__(self, *args, auth_code_queue=None, **kwargs):
64+
self.auth_code_queue = auth_code_queue
65+
super().__init__(*args, **kwargs)
66+
67+
def do_GET(self):
68+
"""Handle GET request from OAuth callback."""
69+
parsed_path = urlparse(self.path)
70+
query_params = parse_qs(parsed_path.query)
71+
72+
if "code" in query_params:
73+
auth_code = query_params["code"][0]
74+
self.auth_code_queue.put(auth_code)
75+
self.send_response(200)
76+
self.send_header("Content-type", "text/html")
77+
self.end_headers()
78+
self.wfile.write(
79+
b"<html><body><h1>Authentication successful!</h1>"
80+
b"<p>You can close this window and return to the terminal.</p>"
81+
b"</body></html>"
82+
)
83+
elif "error" in query_params:
84+
error = query_params["error"][0]
85+
error_description = query_params.get("error_description", [""])[0]
86+
self.auth_code_queue.put(f"ERROR:{error}:{error_description}")
87+
self.send_response(400)
88+
self.send_header("Content-type", "text/html")
89+
self.end_headers()
90+
self.wfile.write(
91+
f"<html><body><h1>Authentication failed</h1>"
92+
f"<p>Error: {error}</p>"
93+
f"<p>{error_description}</p></body></html>".encode()
94+
)
95+
else:
96+
self.send_response(400)
97+
self.send_header("Content-type", "text/html")
98+
self.end_headers()
99+
self.wfile.write(b"<html><body><h1>Invalid callback</h1></body></html>")
100+
101+
def log_message(self, format, *args):
102+
"""Suppress default logging."""
103+
104+
105+
def _run_oauth_callback_server(
106+
redirect_uri: str, auth_code_queue: "queue.Queue[str]"
107+
) -> tuple[socketserver.TCPServer, int]:
108+
"""Start a local HTTP server to catch OAuth callback.
109+
110+
Args:
111+
redirect_uri: The redirect URI configured in the OAuth app
112+
auth_code_queue: Queue to put the authorization code in
113+
114+
Returns:
115+
Tuple of (server instance, actual port number)
116+
"""
117+
parsed_uri = urlparse(redirect_uri)
118+
host = parsed_uri.hostname or "localhost"
119+
port = parsed_uri.port or 5555
120+
121+
# Create a custom handler factory
122+
def handler_factory(*args, **kwargs):
123+
return OAuthCallbackHandler(*args, auth_code_queue=auth_code_queue, **kwargs)
124+
125+
server = socketserver.TCPServer((host, port), handler_factory)
126+
server.allow_reuse_address = True
127+
128+
def serve():
129+
server.serve_forever()
130+
131+
server_thread = threading.Thread(target=serve, daemon=True)
132+
server_thread.start()
133+
134+
# Wait a moment for server to start
135+
time.sleep(0.5)
136+
137+
return server, port
138+
139+
140+
def _exchange_code_for_tokens(
141+
login_url: str,
142+
client_id: str,
143+
client_secret: str,
144+
redirect_uri: str,
145+
auth_code: str,
146+
) -> Any:
147+
"""Exchange authorization code for access and refresh tokens.
148+
149+
Args:
150+
login_url: Salesforce login URL
151+
client_id: OAuth client ID
152+
client_secret: OAuth client secret
153+
redirect_uri: Redirect URI used in authorization
154+
auth_code: Authorization code from callback
155+
156+
Returns:
157+
Dictionary containing access_token and refresh_token
158+
159+
Raises:
160+
click.ClickException: If token exchange fails
161+
"""
162+
token_url = f"{login_url.rstrip('/')}/services/oauth2/token"
163+
data = {
164+
"grant_type": "authorization_code",
165+
"code": auth_code,
166+
"client_id": client_id,
167+
"client_secret": client_secret,
168+
"redirect_uri": redirect_uri,
169+
}
170+
171+
try:
172+
response = requests.post(token_url, data=data, timeout=30)
173+
response.raise_for_status()
174+
return response.json()
175+
except requests.exceptions.RequestException as e:
176+
raise click.ClickException(
177+
f"Failed to exchange authorization code for tokens: {e}"
178+
) from e
179+
180+
181+
def _perform_oauth_browser_flow(
182+
login_url: str, client_id: str, client_secret: str, redirect_uri: str
183+
) -> tuple[str, str]:
184+
"""Perform OAuth browser flow to obtain tokens.
185+
186+
Args:
187+
login_url: Salesforce login URL
188+
client_id: OAuth client ID
189+
client_secret: OAuth client secret
190+
redirect_uri: Redirect URI configured in OAuth app
191+
192+
Returns:
193+
Tuple of (refresh_token, access_token)
194+
195+
Raises:
196+
click.ClickException: If OAuth flow fails
197+
"""
198+
# Parse redirect_uri and ensure it has a port
199+
parsed_redirect = urlparse(redirect_uri)
200+
if not parsed_redirect.port:
201+
# If no port specified, default to 5555 and update redirect_uri
202+
default_port = 5555
203+
redirect_uri = f"{parsed_redirect.scheme}://{parsed_redirect.hostname}:{default_port}{parsed_redirect.path}"
204+
205+
# Create queue for communication between server and main thread
206+
auth_code_queue: queue.Queue[str] = queue.Queue()
207+
208+
# Start callback server
209+
click.echo(f"\nStarting local callback server on {redirect_uri}...")
210+
server, actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue)
211+
212+
# Build authorization URL with final redirect_uri
213+
auth_url = (
214+
f"{login_url.rstrip('/')}/services/oauth2/authorize"
215+
f"?response_type=code"
216+
f"&client_id={client_id}"
217+
f"&redirect_uri={redirect_uri}"
218+
)
219+
220+
# Open browser
221+
click.echo("Opening browser for authentication...")
222+
click.echo(f"If the browser doesn't open automatically, visit:\n{auth_url}\n")
223+
webbrowser.open(auth_url)
224+
225+
# Wait for callback (with timeout)
226+
click.echo("Waiting for authentication...")
227+
try:
228+
result = auth_code_queue.get(timeout=60) # 1 minute timeout
229+
except queue.Empty:
230+
server.shutdown()
231+
raise click.ClickException(
232+
"Authentication timeout. Please try again."
233+
) from None
234+
235+
# Shutdown server
236+
server.shutdown()
237+
238+
# Check for errors
239+
if result.startswith("ERROR:"):
240+
_, error, error_description = result.split(":", 2)
241+
raise click.ClickException(f"OAuth error: {error}. {error_description}")
242+
243+
auth_code = result
244+
245+
# Exchange code for tokens
246+
click.echo("Exchanging authorization code for tokens...")
247+
token_response = _exchange_code_for_tokens(
248+
login_url, client_id, client_secret, redirect_uri, auth_code
249+
)
250+
251+
refresh_token = token_response.get("refresh_token")
252+
access_token = token_response.get("access_token")
253+
254+
if not refresh_token:
255+
raise click.ClickException(
256+
"No refresh_token in response. Please check your OAuth app configuration."
257+
)
258+
259+
return refresh_token, access_token
260+
261+
48262
def _configure_oauth_tokens(
49263
login_url: str,
50264
client_id: str,
@@ -53,21 +267,25 @@ def _configure_oauth_tokens(
53267
"""Configure credentials for OAuth Tokens authentication."""
54268
from datacustomcode.credentials import AuthType, Credentials
55269

56-
client_secret = click.prompt("Client Secret")
57-
refresh_token = click.prompt("Refresh Token")
58-
core_token = click.prompt(
59-
"Core Token (optional, press Enter to skip)",
60-
default="",
61-
show_default=False,
62-
)
270+
client_secret = click.prompt("Client Secret", hide_input=True)
271+
redirect_uri = click.prompt("Redirect URI")
272+
273+
# Perform OAuth browser flow
274+
try:
275+
refresh_token, access_token = _perform_oauth_browser_flow(
276+
login_url, client_id, client_secret, redirect_uri
277+
)
278+
except click.ClickException as e:
279+
click.secho(f"Error: {e}", fg="red")
280+
raise click.Abort() from None
63281

64282
credentials = Credentials(
65283
login_url=login_url,
66284
client_id=client_id,
67285
auth_type=AuthType.OAUTH_TOKENS,
68286
client_secret=client_secret,
69287
refresh_token=refresh_token,
70-
core_token=core_token if core_token else None,
288+
core_token=access_token,
71289
)
72290
credentials.update_ini(profile=profile)
73291
click.secho(

src/datacustomcode/credentials.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,11 @@ def update_ini(self, profile: str = "default", ini_file: str = INI_FILE) -> None
265265
with open(expanded_ini_file, "w") as f:
266266
config.write(f)
267267

268+
# Set secure file permissions (0o600 - readable/writable by owner only)
269+
try:
270+
os.chmod(expanded_ini_file, 0o600)
271+
except OSError:
272+
# Ignore errors if we can't set file permissions (e.g., on Windows)
273+
pass
274+
268275
logger.debug(f"Saved credentials to {expanded_ini_file} [{profile}]")

0 commit comments

Comments
 (0)