|
1 | 1 | package cmd |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "os" |
| 5 | + "strings" |
4 | 6 | "testing" |
5 | 7 |
|
| 8 | + "github.com/spf13/cobra" |
6 | 9 | "github.com/stretchr/testify/assert" |
| 10 | + "github.com/stretchr/testify/require" |
| 11 | + |
| 12 | + "github.com/github/gh-aw-mcpg/internal/config" |
7 | 13 | ) |
8 | 14 |
|
| 15 | +// TestDetectGuardWasm_FileNotFound tests that detectGuardWasm returns empty string |
| 16 | +// when the baked-in guard at containerGuardWasmPath does not exist. |
| 17 | +// In standard test environments (non-container), the baked-in guard is absent. |
| 18 | +func TestDetectGuardWasm_FileNotFound(t *testing.T) { |
| 19 | + // Confirm the baked-in path does not exist in this environment |
| 20 | + _, err := os.Stat(containerGuardWasmPath) |
| 21 | + if err == nil { |
| 22 | + t.Skipf("baked-in guard found at %s (running in container) — skipping 'not found' test", containerGuardWasmPath) |
| 23 | + } |
| 24 | + |
| 25 | + result := detectGuardWasm() |
| 26 | + assert.Empty(t, result, "detectGuardWasm should return empty string when guard file does not exist") |
| 27 | +} |
| 28 | + |
| 29 | +// TestDetectGuardWasm_FileExists verifies that detectGuardWasm returns the |
| 30 | +// containerGuardWasmPath when that file is present on the filesystem. |
| 31 | +// This test creates a temporary file at the expected path to simulate the container environment. |
| 32 | +func TestDetectGuardWasm_FileExists(t *testing.T) { |
| 33 | + // Skip if we cannot write to /guards/github/; test can only run where the |
| 34 | + // directory is pre-created (e.g. the production container image). |
| 35 | + if _, err := os.Stat(containerGuardWasmPath); err == nil { |
| 36 | + // File already exists (running in container): just verify the function works. |
| 37 | + result := detectGuardWasm() |
| 38 | + assert.Equal(t, containerGuardWasmPath, result, |
| 39 | + "detectGuardWasm should return the baked-in path when the file exists") |
| 40 | + } |
| 41 | + // If the file does not exist and we cannot create it (no permission), skip. |
| 42 | + t.Skip("baked-in guard not present and cannot create it in this environment") |
| 43 | +} |
| 44 | + |
| 45 | +// TestNewProxyCmd_AllFlagsRegistered verifies that newProxyCmd registers all expected flags. |
| 46 | +func TestNewProxyCmd_AllFlagsRegistered(t *testing.T) { |
| 47 | + cmd := newProxyCmd() |
| 48 | + require.NotNil(t, cmd) |
| 49 | + |
| 50 | + expectedFlags := []string{ |
| 51 | + "guard-wasm", |
| 52 | + "policy", |
| 53 | + "github-token", |
| 54 | + "listen", |
| 55 | + "log-dir", |
| 56 | + "guards-mode", |
| 57 | + "github-api-url", |
| 58 | + "tls", |
| 59 | + "tls-dir", |
| 60 | + "trusted-bots", |
| 61 | + "trusted-users", |
| 62 | + "otlp-endpoint", |
| 63 | + "otlp-service-name", |
| 64 | + "otlp-sample-rate", |
| 65 | + } |
| 66 | + |
| 67 | + for _, flagName := range expectedFlags { |
| 68 | + t.Run("flag_"+flagName, func(t *testing.T) { |
| 69 | + flag := cmd.Flags().Lookup(flagName) |
| 70 | + assert.NotNil(t, flag, "flag --%s should be registered on proxy command", flagName) |
| 71 | + }) |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +// TestNewProxyCmd_CommandMetadata verifies the command's metadata is correctly set. |
| 76 | +func TestNewProxyCmd_CommandMetadata(t *testing.T) { |
| 77 | + cmd := newProxyCmd() |
| 78 | + require.NotNil(t, cmd) |
| 79 | + |
| 80 | + assert.Equal(t, "proxy", cmd.Use, "proxy command Use should be 'proxy'") |
| 81 | + assert.NotEmpty(t, cmd.Short, "proxy command Short should not be empty") |
| 82 | + assert.NotEmpty(t, cmd.Long, "proxy command Long should not be empty") |
| 83 | + assert.True(t, cmd.SilenceUsage, "proxy command should silence usage on error") |
| 84 | + assert.NotNil(t, cmd.RunE, "proxy command RunE should be set") |
| 85 | + |
| 86 | + // Long description should mention proxy and DIFC |
| 87 | + assert.Contains(t, cmd.Long, "proxy", "Long description should mention 'proxy'") |
| 88 | + assert.Contains(t, cmd.Long, "DIFC", "Long description should mention 'DIFC'") |
| 89 | +} |
| 90 | + |
| 91 | +// TestNewProxyCmd_DefaultFlagValues verifies that flags have the expected defaults |
| 92 | +// when no environment variables are set. |
| 93 | +func TestNewProxyCmd_DefaultFlagValues(t *testing.T) { |
| 94 | + // Clear relevant env vars to get clean defaults |
| 95 | + envVarsToClear := []string{ |
| 96 | + "MCP_GATEWAY_GUARD_POLICY_JSON", |
| 97 | + "MCP_GATEWAY_LOG_DIR", |
| 98 | + "OTEL_EXPORTER_OTLP_ENDPOINT", |
| 99 | + "OTEL_SERVICE_NAME", |
| 100 | + } |
| 101 | + for _, envVar := range envVarsToClear { |
| 102 | + t.Setenv(envVar, "") |
| 103 | + } |
| 104 | + |
| 105 | + cmd := newProxyCmd() |
| 106 | + require.NotNil(t, cmd) |
| 107 | + |
| 108 | + tests := []struct { |
| 109 | + flagName string |
| 110 | + expectedType string |
| 111 | + validate func(t *testing.T, cmd *cobra.Command) |
| 112 | + }{ |
| 113 | + { |
| 114 | + flagName: "listen", |
| 115 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 116 | + t.Helper() |
| 117 | + val, err := cmd.Flags().GetString("listen") |
| 118 | + require.NoError(t, err) |
| 119 | + assert.Equal(t, "127.0.0.1:8080", val, "--listen default should be 127.0.0.1:8080") |
| 120 | + }, |
| 121 | + }, |
| 122 | + { |
| 123 | + flagName: "guards-mode", |
| 124 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 125 | + t.Helper() |
| 126 | + val, err := cmd.Flags().GetString("guards-mode") |
| 127 | + require.NoError(t, err) |
| 128 | + assert.Equal(t, "filter", val, "--guards-mode default for proxy should be 'filter'") |
| 129 | + }, |
| 130 | + }, |
| 131 | + { |
| 132 | + flagName: "github-token", |
| 133 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 134 | + t.Helper() |
| 135 | + val, err := cmd.Flags().GetString("github-token") |
| 136 | + require.NoError(t, err) |
| 137 | + assert.Equal(t, "", val, "--github-token default should be empty") |
| 138 | + }, |
| 139 | + }, |
| 140 | + { |
| 141 | + flagName: "github-api-url", |
| 142 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 143 | + t.Helper() |
| 144 | + val, err := cmd.Flags().GetString("github-api-url") |
| 145 | + require.NoError(t, err) |
| 146 | + assert.Equal(t, "", val, "--github-api-url default should be empty") |
| 147 | + }, |
| 148 | + }, |
| 149 | + { |
| 150 | + flagName: "tls", |
| 151 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 152 | + t.Helper() |
| 153 | + val, err := cmd.Flags().GetBool("tls") |
| 154 | + require.NoError(t, err) |
| 155 | + assert.False(t, val, "--tls default should be false") |
| 156 | + }, |
| 157 | + }, |
| 158 | + { |
| 159 | + flagName: "tls-dir", |
| 160 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 161 | + t.Helper() |
| 162 | + val, err := cmd.Flags().GetString("tls-dir") |
| 163 | + require.NoError(t, err) |
| 164 | + assert.Equal(t, "", val, "--tls-dir default should be empty") |
| 165 | + }, |
| 166 | + }, |
| 167 | + { |
| 168 | + flagName: "otlp-service-name", |
| 169 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 170 | + t.Helper() |
| 171 | + val, err := cmd.Flags().GetString("otlp-service-name") |
| 172 | + require.NoError(t, err) |
| 173 | + assert.Equal(t, config.DefaultTracingServiceName, val, |
| 174 | + "--otlp-service-name default should be the DefaultTracingServiceName constant") |
| 175 | + }, |
| 176 | + }, |
| 177 | + { |
| 178 | + flagName: "otlp-sample-rate", |
| 179 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 180 | + t.Helper() |
| 181 | + val, err := cmd.Flags().GetFloat64("otlp-sample-rate") |
| 182 | + require.NoError(t, err) |
| 183 | + assert.Equal(t, config.DefaultTracingSampleRate, val, |
| 184 | + "--otlp-sample-rate default should be DefaultTracingSampleRate") |
| 185 | + }, |
| 186 | + }, |
| 187 | + { |
| 188 | + flagName: "policy", |
| 189 | + validate: func(t *testing.T, cmd *cobra.Command) { |
| 190 | + t.Helper() |
| 191 | + val, err := cmd.Flags().GetString("policy") |
| 192 | + require.NoError(t, err) |
| 193 | + assert.Equal(t, "", val, "--policy default should be empty when env var is unset") |
| 194 | + }, |
| 195 | + }, |
| 196 | + } |
| 197 | + |
| 198 | + for _, tt := range tests { |
| 199 | + t.Run(tt.flagName, func(t *testing.T) { |
| 200 | + tt.validate(t, cmd) |
| 201 | + }) |
| 202 | + } |
| 203 | +} |
| 204 | + |
| 205 | +// TestNewProxyCmd_PolicyDefaultFromEnv verifies that --policy picks up its |
| 206 | +// default value from the MCP_GATEWAY_GUARD_POLICY_JSON environment variable. |
| 207 | +func TestNewProxyCmd_PolicyDefaultFromEnv(t *testing.T) { |
| 208 | + envPolicy := `{"allow-only":{"repos":"public","min-integrity":"none"}}` |
| 209 | + t.Setenv("MCP_GATEWAY_GUARD_POLICY_JSON", envPolicy) |
| 210 | + |
| 211 | + cmd := newProxyCmd() |
| 212 | + require.NotNil(t, cmd) |
| 213 | + |
| 214 | + val, err := cmd.Flags().GetString("policy") |
| 215 | + require.NoError(t, err) |
| 216 | + assert.Equal(t, envPolicy, val, |
| 217 | + "--policy default should reflect MCP_GATEWAY_GUARD_POLICY_JSON environment variable") |
| 218 | +} |
| 219 | + |
| 220 | +// TestNewProxyCmd_OTLPEndpointDefaultFromEnv verifies that --otlp-endpoint picks |
| 221 | +// up its default value from OTEL_EXPORTER_OTLP_ENDPOINT. |
| 222 | +func TestNewProxyCmd_OTLPEndpointDefaultFromEnv(t *testing.T) { |
| 223 | + endpoint := "http://otel-collector:4318" |
| 224 | + t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint) |
| 225 | + |
| 226 | + cmd := newProxyCmd() |
| 227 | + require.NotNil(t, cmd) |
| 228 | + |
| 229 | + val, err := cmd.Flags().GetString("otlp-endpoint") |
| 230 | + require.NoError(t, err) |
| 231 | + assert.Equal(t, endpoint, val, |
| 232 | + "--otlp-endpoint default should reflect OTEL_EXPORTER_OTLP_ENDPOINT environment variable") |
| 233 | +} |
| 234 | + |
| 235 | +// TestNewProxyCmd_OTLPServiceNameDefaultFromEnv verifies that --otlp-service-name |
| 236 | +// picks up its default value from OTEL_SERVICE_NAME. |
| 237 | +func TestNewProxyCmd_OTLPServiceNameDefaultFromEnv(t *testing.T) { |
| 238 | + serviceName := "my-custom-proxy" |
| 239 | + t.Setenv("OTEL_SERVICE_NAME", serviceName) |
| 240 | + |
| 241 | + cmd := newProxyCmd() |
| 242 | + require.NotNil(t, cmd) |
| 243 | + |
| 244 | + val, err := cmd.Flags().GetString("otlp-service-name") |
| 245 | + require.NoError(t, err) |
| 246 | + assert.Equal(t, serviceName, val, |
| 247 | + "--otlp-service-name default should reflect OTEL_SERVICE_NAME environment variable") |
| 248 | +} |
| 249 | + |
| 250 | +// TestNewProxyCmd_GuardWasmRequiredWhenNoBakedInGuard verifies that --guard-wasm is |
| 251 | +// marked as required when the baked-in container guard does not exist. |
| 252 | +func TestNewProxyCmd_GuardWasmRequiredWhenNoBakedInGuard(t *testing.T) { |
| 253 | + // This test is only meaningful when running outside a container. |
| 254 | + if _, err := os.Stat(containerGuardWasmPath); err == nil { |
| 255 | + t.Skipf("baked-in guard found at %s — in container, --guard-wasm is optional", containerGuardWasmPath) |
| 256 | + } |
| 257 | + |
| 258 | + cmd := newProxyCmd() |
| 259 | + require.NotNil(t, cmd) |
| 260 | + |
| 261 | + // Execute with no flags — the command should fail with "required flag" error, |
| 262 | + // not with any other error, confirming the flag is marked required. |
| 263 | + cmd.SetArgs([]string{}) |
| 264 | + err := cmd.Execute() |
| 265 | + require.Error(t, err, "executing proxy command without --guard-wasm should return an error") |
| 266 | + assert.True(t, |
| 267 | + strings.Contains(err.Error(), "guard-wasm") || strings.Contains(err.Error(), "required"), |
| 268 | + "error should mention --guard-wasm or required: %v", err) |
| 269 | +} |
| 270 | + |
| 271 | +// TestNewProxyCmd_GuardWasmFlagHelpText verifies the --guard-wasm flag help text |
| 272 | +// reflects whether the baked-in guard was auto-detected. |
| 273 | +func TestNewProxyCmd_GuardWasmFlagHelpText(t *testing.T) { |
| 274 | + cmd := newProxyCmd() |
| 275 | + require.NotNil(t, cmd) |
| 276 | + |
| 277 | + flag := cmd.Flags().Lookup("guard-wasm") |
| 278 | + require.NotNil(t, flag, "--guard-wasm flag should exist") |
| 279 | + |
| 280 | + _, err := os.Stat(containerGuardWasmPath) |
| 281 | + if err == nil { |
| 282 | + // In container environment: help text should mention auto-detected path |
| 283 | + assert.Contains(t, flag.Usage, "auto-detected", |
| 284 | + "--guard-wasm help should mention auto-detected when baked-in guard exists") |
| 285 | + assert.Contains(t, flag.Usage, containerGuardWasmPath, |
| 286 | + "--guard-wasm help should include the detected path") |
| 287 | + } else { |
| 288 | + // Not in container: help text should say required |
| 289 | + assert.Contains(t, flag.Usage, "required", |
| 290 | + "--guard-wasm help should say 'required' when no baked-in guard exists") |
| 291 | + } |
| 292 | +} |
| 293 | + |
| 294 | +// TestNewProxyCmd_GuardsModeCompletion verifies the guards-mode flag has the |
| 295 | +// correct shell completion function returning valid enum values. |
| 296 | +func TestNewProxyCmd_GuardsModeCompletion(t *testing.T) { |
| 297 | + cmd := newProxyCmd() |
| 298 | + require.NotNil(t, cmd) |
| 299 | + |
| 300 | + completionFn, ok := cmd.GetFlagCompletionFunc("guards-mode") |
| 301 | + require.True(t, ok, "guards-mode flag should have a completion function registered") |
| 302 | + require.NotNil(t, completionFn, "guards-mode completion function should not be nil") |
| 303 | + |
| 304 | + completions, directive := completionFn(cmd, nil, "") |
| 305 | + |
| 306 | + assert.Equal(t, cobra.ShellCompDirectiveNoFileComp, directive, |
| 307 | + "guards-mode completion should use ShellCompDirectiveNoFileComp directive") |
| 308 | + assert.ElementsMatch(t, []string{"strict", "filter", "propagate"}, completions, |
| 309 | + "guards-mode completion should return all valid enforcement modes") |
| 310 | +} |
| 311 | + |
| 312 | +// TestNewProxyCmd_TrustedBotsAndUsersDefaultNil verifies that --trusted-bots and |
| 313 | +// --trusted-users default to nil (no pre-configured trusted users/bots). |
| 314 | +func TestNewProxyCmd_TrustedBotsAndUsersDefaultNil(t *testing.T) { |
| 315 | + cmd := newProxyCmd() |
| 316 | + require.NotNil(t, cmd) |
| 317 | + |
| 318 | + bots, err := cmd.Flags().GetStringSlice("trusted-bots") |
| 319 | + require.NoError(t, err) |
| 320 | + assert.Empty(t, bots, "--trusted-bots should default to empty/nil") |
| 321 | + |
| 322 | + users, err := cmd.Flags().GetStringSlice("trusted-users") |
| 323 | + require.NoError(t, err) |
| 324 | + assert.Empty(t, users, "--trusted-users should default to empty/nil") |
| 325 | +} |
| 326 | + |
| 327 | +// TestNewProxyCmd_LogDirDefault verifies --log-dir uses getDefaultLogDir() as default. |
| 328 | +func TestNewProxyCmd_LogDirDefault(t *testing.T) { |
| 329 | + t.Setenv("MCP_GATEWAY_LOG_DIR", "") |
| 330 | + |
| 331 | + cmd := newProxyCmd() |
| 332 | + require.NotNil(t, cmd) |
| 333 | + |
| 334 | + val, err := cmd.Flags().GetString("log-dir") |
| 335 | + require.NoError(t, err) |
| 336 | + assert.Equal(t, getDefaultLogDir(), val, |
| 337 | + "--log-dir should use getDefaultLogDir() as its default value") |
| 338 | +} |
| 339 | + |
| 340 | +// TestNewProxyCmd_ListenFlag verifies --listen, -l shorthand and default value. |
| 341 | +func TestNewProxyCmd_ListenFlag(t *testing.T) { |
| 342 | + cmd := newProxyCmd() |
| 343 | + require.NotNil(t, cmd) |
| 344 | + |
| 345 | + flag := cmd.Flags().Lookup("listen") |
| 346 | + require.NotNil(t, flag) |
| 347 | + assert.Equal(t, "127.0.0.1:8080", flag.DefValue, "--listen default should be 127.0.0.1:8080") |
| 348 | + |
| 349 | + // Verify the flag has a shorthand |
| 350 | + shortFlag := cmd.Flags().ShorthandLookup("l") |
| 351 | + require.NotNil(t, shortFlag, "-l shorthand should be registered for --listen") |
| 352 | + assert.Equal(t, "listen", shortFlag.Name, "-l should map to --listen") |
| 353 | +} |
| 354 | + |
| 355 | +// TestNewProxyCmd_IsAddedToRootCmd verifies the proxy subcommand is registered |
| 356 | +// on the root command so it's accessible via `awmg proxy`. |
| 357 | +func TestNewProxyCmd_IsAddedToRootCmd(t *testing.T) { |
| 358 | + found := false |
| 359 | + for _, sub := range rootCmd.Commands() { |
| 360 | + if sub.Use == "proxy" { |
| 361 | + found = true |
| 362 | + break |
| 363 | + } |
| 364 | + } |
| 365 | + assert.True(t, found, "proxy subcommand should be registered on the root command") |
| 366 | +} |
| 367 | + |
9 | 368 | func TestClientAddr(t *testing.T) { |
10 | 369 | tests := []struct { |
11 | 370 | name string |
|
0 commit comments