Skip to content

Commit 98d5a84

Browse files
Copilotlpcox
andauthored
wazero improvements: Registry.Close, logging namespace fix, typed trap detection
- Add Registry.Close(ctx) method to close all registered guards that implement io.Closer; call it from UnifiedServer.Close() and InitiateShutdown() to release WASM runtime resources on shutdown - Fix logging namespace confusion in registry.go: replace log.Printf (which was using the guard:context namespace logger) with logger.LogInfo("guard", ...) for operational events - Use typed sys.ExitError check in isWasmTrap: check exit code 0 as a normal process exit (not a trap) before falling back to string matching - Add tests for Registry.Close and sys.ExitError handling in isWasmTrap Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/842dba68-e351-47f7-84e1-cf101f122182 Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com>
1 parent fd32e4a commit 98d5a84

5 files changed

Lines changed: 129 additions & 7 deletions

File tree

internal/guard/guard_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package guard
22

33
import (
44
"context"
5+
"errors"
56
"sync"
67
"testing"
78

@@ -17,6 +18,18 @@ type mockGuard struct {
1718
id string
1819
}
1920

21+
// mockClosableGuard is a guard that tracks whether Close was called
22+
type mockClosableGuard struct {
23+
mockGuard
24+
closed bool
25+
closeErr error
26+
}
27+
28+
func (m *mockClosableGuard) Close(ctx context.Context) error {
29+
m.closed = true
30+
return m.closeErr
31+
}
32+
2033
func (m *mockGuard) Name() string { return "mock-" + m.id }
2134
func (m *mockGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) {
2235
return &LabelAgentResult{DIFCMode: difc.ModeStrict}, nil
@@ -450,6 +463,59 @@ func TestGuardRegistry_HasNonNoopGuard(t *testing.T) {
450463
})
451464
}
452465

466+
func TestGuardRegistry_Close(t *testing.T) {
467+
t.Run("close calls Close on guards that implement it", func(t *testing.T) {
468+
registry := NewRegistry()
469+
g := &mockClosableGuard{mockGuard: mockGuard{id: "wasm"}}
470+
registry.Register("server1", g)
471+
472+
registry.Close(context.Background())
473+
474+
assert.True(t, g.closed, "expected guard Close to be called")
475+
})
476+
477+
t.Run("close skips guards that do not implement Close", func(t *testing.T) {
478+
registry := NewRegistry()
479+
registry.Register("server1", NewNoopGuard())
480+
481+
// Should not panic
482+
registry.Close(context.Background())
483+
})
484+
485+
t.Run("close on empty registry is safe", func(t *testing.T) {
486+
registry := NewRegistry()
487+
// Should not panic
488+
registry.Close(context.Background())
489+
})
490+
491+
t.Run("close calls Close on all closable guards", func(t *testing.T) {
492+
registry := NewRegistry()
493+
g1 := &mockClosableGuard{mockGuard: mockGuard{id: "wasm1"}}
494+
g2 := &mockClosableGuard{mockGuard: mockGuard{id: "wasm2"}}
495+
registry.Register("server1", g1)
496+
registry.Register("server2", g2)
497+
498+
registry.Close(context.Background())
499+
500+
assert.True(t, g1.closed, "expected guard 1 Close to be called")
501+
assert.True(t, g2.closed, "expected guard 2 Close to be called")
502+
})
503+
504+
t.Run("close continues when one guard returns an error", func(t *testing.T) {
505+
registry := NewRegistry()
506+
g1 := &mockClosableGuard{mockGuard: mockGuard{id: "failing"}, closeErr: errors.New("close failed")}
507+
g2 := &mockClosableGuard{mockGuard: mockGuard{id: "ok"}}
508+
registry.Register("server1", g1)
509+
registry.Register("server2", g2)
510+
511+
// Should not panic even when one guard returns an error
512+
registry.Close(context.Background())
513+
514+
assert.True(t, g1.closed, "expected failing guard Close to be called")
515+
assert.True(t, g2.closed, "expected ok guard Close to be called")
516+
})
517+
}
518+
453519
func TestCreateGuard(t *testing.T) {
454520
tests := []struct {
455521
name string

internal/guard/registry.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package guard
22

33
import (
4+
"context"
45
"fmt"
56
"sync"
67

@@ -30,7 +31,7 @@ func (r *Registry) Register(serverID string, guard Guard) {
3031
defer r.mu.Unlock()
3132

3233
r.guards[serverID] = guard
33-
log.Printf("[Guard] Registered guard '%s' for server '%s'", guard.Name(), serverID)
34+
logger.LogInfo("guard", "Registered guard '%s' for server '%s'", guard.Name(), serverID)
3435
}
3536

3637
// Get retrieves the guard for a server, or returns a noop guard if not found
@@ -46,7 +47,6 @@ func (r *Registry) Get(serverID string) Guard {
4647

4748
// Return noop guard as default
4849
debugLog.Printf("No guard registered for serverID=%s, returning noop guard", serverID)
49-
log.Printf("[Guard] No guard registered for server '%s', using noop guard", serverID)
5050
return NewNoopGuard()
5151
}
5252

@@ -76,7 +76,7 @@ func (r *Registry) Remove(serverID string) {
7676
r.mu.Lock()
7777
defer r.mu.Unlock()
7878
delete(r.guards, serverID)
79-
log.Printf("[Guard] Removed guard for server '%s'", serverID)
79+
logger.LogInfo("guard", "Removed guard for server '%s'", serverID)
8080
}
8181

8282
// List returns all registered server IDs
@@ -103,6 +103,20 @@ func (r *Registry) GetGuardInfo() map[string]string {
103103
return info
104104
}
105105

106+
// Close closes all registered guards that implement io.Closer.
107+
// It should be called during server shutdown to release WASM runtime resources.
108+
func (r *Registry) Close(ctx context.Context) {
109+
r.mu.Lock()
110+
defer r.mu.Unlock()
111+
for id, g := range r.guards {
112+
if c, ok := g.(interface{ Close(context.Context) error }); ok {
113+
if err := c.Close(ctx); err != nil {
114+
logger.LogWarn("guard", "Failed to close guard for server %s: %v", id, err)
115+
}
116+
}
117+
}
118+
}
119+
106120
// GuardFactory is a function that creates a guard instance
107121
type GuardFactory func() (Guard, error)
108122

@@ -116,7 +130,7 @@ func RegisterGuardType(name string, factory GuardFactory) {
116130
registeredGuardsMu.Lock()
117131
defer registeredGuardsMu.Unlock()
118132
registeredGuards[name] = factory
119-
log.Printf("[Guard] Registered guard type: %s", name)
133+
logger.LogInfo("guard", "Registered guard type: %s", name)
120134
}
121135

122136
// CreateGuard creates a guard instance by name using registered factories

internal/guard/wasm.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/tetratelabs/wazero"
1616
"github.com/tetratelabs/wazero/api"
1717
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
18+
"github.com/tetratelabs/wazero/sys"
1819
)
1920

2021
var logWasm = logger.New("guard:wasm")
@@ -830,10 +831,22 @@ func parsePathLabeledResponse(responseJSON []byte, originalData interface{}) (di
830831
return pld.ToCollectionLabeledData(), nil
831832
}
832833

833-
// isWasmTrap reports whether err is a WASM execution trap such as the
834-
// "wasm error: unreachable" produced when a Rust-compiled guard panics.
834+
// isWasmTrap reports whether err represents a WASM execution trap that should
835+
// permanently poison the guard. Normal process exits (exit code 0, e.g. TinyGo
836+
// init) are NOT considered traps. A non-zero exit code is treated as a trap.
837+
// As a fallback for wazero execution faults (e.g. Rust panic → unreachable),
838+
// the function also matches on wazero's "wasm error:" message prefix.
835839
func isWasmTrap(err error) bool {
836-
return err != nil && strings.Contains(err.Error(), "wasm error:")
840+
if err == nil {
841+
return false
842+
}
843+
// A normal WASI process exit (exit code 0) is not a trap — don't poison the guard.
844+
var exitErr *sys.ExitError
845+
if errors.As(err, &exitErr) {
846+
return exitErr.ExitCode() != 0
847+
}
848+
// Fallback for wazero execution traps (e.g. Rust panic → unreachable).
849+
return strings.Contains(err.Error(), "wasm error:")
837850
}
838851

839852
// callWasmFunction calls an exported function in the WASM module.

internal/guard/wasm_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616

1717
"github.com/github/gh-aw-mcpg/internal/difc"
1818
"github.com/tetratelabs/wazero"
19+
"github.com/tetratelabs/wazero/sys"
1920
)
2021

2122
func TestMain(m *testing.M) {
@@ -1152,6 +1153,26 @@ func TestIsWasmTrap(t *testing.T) {
11521153
err := errors.New("wasm error: out of bounds memory access")
11531154
assert.True(t, isWasmTrap(err))
11541155
})
1156+
1157+
t.Run("sys.ExitError with exit code 0 is not a trap", func(t *testing.T) {
1158+
err := sys.NewExitError(0)
1159+
assert.False(t, isWasmTrap(err))
1160+
})
1161+
1162+
t.Run("sys.ExitError with non-zero exit code is a trap", func(t *testing.T) {
1163+
err := sys.NewExitError(1)
1164+
assert.True(t, isWasmTrap(err))
1165+
})
1166+
1167+
t.Run("wrapped sys.ExitError with exit code 0 is not a trap", func(t *testing.T) {
1168+
err := fmt.Errorf("wrapper: %w", sys.NewExitError(0))
1169+
assert.False(t, isWasmTrap(err))
1170+
})
1171+
1172+
t.Run("wrapped sys.ExitError with non-zero exit code is a trap", func(t *testing.T) {
1173+
err := fmt.Errorf("wrapper: %w", sys.NewExitError(2))
1174+
assert.True(t, isWasmTrap(err))
1175+
})
11551176
}
11561177

11571178
func TestWasmGuardFailedState(t *testing.T) {

internal/server/unified.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,9 @@ func (us *UnifiedServer) GetToolHandler(backendID string, toolName string) func(
718718

719719
// Close cleans up resources
720720
func (us *UnifiedServer) Close() error {
721+
if us.guardRegistry != nil {
722+
us.guardRegistry.Close(context.Background())
723+
}
721724
us.launcher.Close()
722725
return nil
723726
}
@@ -753,6 +756,11 @@ func (us *UnifiedServer) InitiateShutdown() int {
753756
logger.LogInfo("shutdown", "Terminating %d backend servers", serversTerminated)
754757
us.launcher.Close()
755758

759+
// Release WASM runtime resources held by guards
760+
if us.guardRegistry != nil {
761+
us.guardRegistry.Close(context.Background())
762+
}
763+
756764
logger.LogInfo("shutdown", "Backend servers terminated successfully")
757765
})
758766
return serversTerminated

0 commit comments

Comments
 (0)