Skip to content

Commit d571c06

Browse files
committed
Use SDK HTTP client for network checks, return error on check failure
1 parent a6e326a commit d571c06

3 files changed

Lines changed: 117 additions & 22 deletions

File tree

cmd/doctor/checks.go

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package doctor
22

33
import (
4+
"context"
5+
"crypto/tls"
46
"errors"
57
"fmt"
68
"io"
@@ -11,6 +13,7 @@ import (
1113
"github.com/databricks/cli/internal/build"
1214
"github.com/databricks/cli/libs/databrickscfg/profile"
1315
"github.com/databricks/cli/libs/env"
16+
"github.com/databricks/cli/libs/log"
1417
"github.com/databricks/databricks-sdk-go"
1518
"github.com/databricks/databricks-sdk-go/config"
1619
"github.com/spf13/cobra"
@@ -230,14 +233,16 @@ func checkNetwork(cmd *cobra.Command, cfg *config.Config, resolveErr error, w *d
230233
}
231234

232235
if w != nil {
233-
return checkNetworkWithHost(cmd, w.Config.Host)
236+
return checkNetworkWithHost(cmd, w.Config.Host, configuredNetworkHTTPClient(w.Config))
234237
}
235238

236-
return checkNetworkWithHost(cmd, cfg.Host)
239+
log.Warnf(cmd.Context(), "workspace client unavailable for network check, falling back to default HTTP client")
240+
return checkNetworkWithHost(cmd, cfg.Host, http.DefaultClient)
237241
}
238242

239-
func checkNetworkWithHost(cmd *cobra.Command, host string) CheckResult {
240-
ctx := cmd.Context()
243+
func checkNetworkWithHost(cmd *cobra.Command, host string, client *http.Client) CheckResult {
244+
ctx, cancel := context.WithTimeout(cmd.Context(), networkTimeout)
245+
defer cancel()
241246

242247
if host == "" {
243248
return CheckResult{
@@ -257,7 +262,6 @@ func checkNetworkWithHost(cmd *cobra.Command, host string) CheckResult {
257262
}
258263
}
259264

260-
client := &http.Client{Timeout: networkTimeout}
261265
resp, err := client.Do(req)
262266
if err != nil {
263267
return CheckResult{
@@ -276,3 +280,33 @@ func checkNetworkWithHost(cmd *cobra.Command, host string) CheckResult {
276280
Message: host + " is reachable",
277281
}
278282
}
283+
284+
func configuredNetworkHTTPClient(cfg *config.Config) *http.Client {
285+
return &http.Client{
286+
Transport: configuredNetworkHTTPTransport(cfg),
287+
}
288+
}
289+
290+
func configuredNetworkHTTPTransport(cfg *config.Config) http.RoundTripper {
291+
if cfg.HTTPTransport != nil {
292+
return cfg.HTTPTransport
293+
}
294+
295+
if !cfg.InsecureSkipVerify {
296+
return http.DefaultTransport
297+
}
298+
299+
transport, ok := http.DefaultTransport.(*http.Transport)
300+
if !ok {
301+
return http.DefaultTransport
302+
}
303+
304+
clone := transport.Clone()
305+
if clone.TLSClientConfig != nil {
306+
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
307+
} else {
308+
clone.TLSClientConfig = &tls.Config{}
309+
}
310+
clone.TLSClientConfig.InsecureSkipVerify = true
311+
return clone
312+
}

cmd/doctor/doctor.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package doctor
22

33
import (
44
"encoding/json"
5+
"errors"
56
"fmt"
67
"io"
78

@@ -22,10 +23,12 @@ type CheckResult struct {
2223
// New returns the doctor command.
2324
func New() *cobra.Command {
2425
cmd := &cobra.Command{
25-
Use: "doctor",
26-
Args: root.NoArgs,
27-
Short: "Validate your Databricks CLI setup",
28-
GroupID: "development",
26+
Use: "doctor",
27+
Args: root.NoArgs,
28+
Short: "Validate your Databricks CLI setup",
29+
GroupID: "development",
30+
SilenceUsage: true,
31+
SilenceErrors: true,
2932
}
3033

3134
cmd.RunE = func(cmd *cobra.Command, args []string) error {
@@ -39,13 +42,19 @@ func New() *cobra.Command {
3942
}
4043
buf = append(buf, '\n')
4144
_, err = cmd.OutOrStdout().Write(buf)
42-
return err
45+
if err != nil {
46+
return err
47+
}
4348
case flags.OutputText:
4449
renderResults(cmd.OutOrStdout(), results)
45-
return nil
4650
default:
4751
return fmt.Errorf("unknown output type %s", root.OutputType(cmd))
4852
}
53+
54+
if hasFailedChecks(results) {
55+
return errors.New("one or more checks failed")
56+
}
57+
return nil
4958
}
5059

5160
return cmd
@@ -77,3 +86,12 @@ func renderResults(w io.Writer, results []CheckResult) {
7786
fmt.Fprintln(w, msg)
7887
}
7988
}
89+
90+
func hasFailedChecks(results []CheckResult) bool {
91+
for _, result := range results {
92+
if result.Status == statusFail {
93+
return true
94+
}
95+
}
96+
return false
97+
}

cmd/doctor/doctor_test.go

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ func (m *noConfigProfiler) GetPath(_ context.Context) (string, error) {
6060
return m.path, nil
6161
}
6262

63+
type roundTripFunc func(*http.Request) (*http.Response, error)
64+
65+
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
66+
return f(r)
67+
}
68+
6369
func newTestCmd(ctx context.Context) *cobra.Command {
6470
cmd := &cobra.Command{}
6571
cmd.SetContext(ctx)
@@ -285,7 +291,7 @@ func TestCheckNetworkReachable(t *testing.T) {
285291
ctx := cmdio.MockDiscard(t.Context())
286292
cmd := newTestCmd(ctx)
287293

288-
result := checkNetworkWithHost(cmd, srv.URL)
294+
result := checkNetworkWithHost(cmd, srv.URL, http.DefaultClient)
289295
assert.Equal(t, "Network", result.Name)
290296
assert.Equal(t, statusPass, result.Status)
291297
assert.Contains(t, result.Message, "reachable")
@@ -295,21 +301,26 @@ func TestCheckNetworkNoHost(t *testing.T) {
295301
ctx := cmdio.MockDiscard(t.Context())
296302
cmd := newTestCmd(ctx)
297303

298-
result := checkNetworkWithHost(cmd, "")
304+
result := checkNetworkWithHost(cmd, "", http.DefaultClient)
299305
assert.Equal(t, "Network", result.Name)
300306
assert.Equal(t, statusFail, result.Status)
301307
assert.Contains(t, result.Message, "No host configured")
302308
}
303309

304-
func TestCheckNetworkWithClient(t *testing.T) {
305-
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
306-
w.WriteHeader(http.StatusOK)
307-
}))
308-
defer srv.Close()
309-
310+
func TestCheckNetworkUsesWorkspaceClientTransport(t *testing.T) {
310311
w, err := databricks.NewWorkspaceClient((*databricks.Config)(&config.Config{
311-
Host: srv.URL,
312+
Host: "https://example.com",
312313
Token: "test-token",
314+
HTTPTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
315+
assert.Equal(t, http.MethodHead, r.Method)
316+
assert.Equal(t, "https://example.com", r.URL.String())
317+
return &http.Response{
318+
StatusCode: http.StatusOK,
319+
Body: http.NoBody,
320+
Header: make(http.Header),
321+
Request: r,
322+
}, nil
323+
}),
313324
}))
314325
require.NoError(t, err)
315326

@@ -390,6 +401,38 @@ func TestRenderResultsJSONOmitsEmptyDetail(t *testing.T) {
390401
assert.NotContains(t, string(buf), "detail")
391402
}
392403

404+
func TestHasFailedChecks(t *testing.T) {
405+
tests := []struct {
406+
name string
407+
results []CheckResult
408+
want bool
409+
}{
410+
{
411+
name: "no failures",
412+
results: []CheckResult{
413+
{Name: "Test", Status: statusPass},
414+
{Name: "Info", Status: statusInfo},
415+
{Name: "Warn", Status: statusWarn},
416+
},
417+
want: false,
418+
},
419+
{
420+
name: "has failure",
421+
results: []CheckResult{
422+
{Name: "Test", Status: statusPass},
423+
{Name: "Broken", Status: statusFail},
424+
},
425+
want: true,
426+
},
427+
}
428+
429+
for _, tt := range tests {
430+
t.Run(tt.name, func(t *testing.T) {
431+
assert.Equal(t, tt.want, hasFailedChecks(tt.results))
432+
})
433+
}
434+
}
435+
393436
func TestNewCommandJSON(t *testing.T) {
394437
clearConfigEnv(t)
395438

@@ -412,7 +455,7 @@ func TestNewCommandJSON(t *testing.T) {
412455
cmd.SetArgs([]string{"--output", "json"})
413456

414457
err := cmd.Execute()
415-
require.NoError(t, err)
458+
require.ErrorContains(t, err, "one or more checks failed")
416459

417460
var results []CheckResult
418461
err = json.Unmarshal(buf.Bytes(), &results)
@@ -445,7 +488,7 @@ func TestNewCommandJSONTrailingNewline(t *testing.T) {
445488
cmd.SetArgs([]string{"--output", "json"})
446489

447490
err := cmd.Execute()
448-
require.NoError(t, err)
491+
require.ErrorContains(t, err, "one or more checks failed")
449492
assert.Positive(t, buf.Len())
450493
assert.Equal(t, byte('\n'), buf.Bytes()[buf.Len()-1])
451494
}

0 commit comments

Comments
 (0)