Skip to content

Commit eb773ed

Browse files
committed
- integration test for all transport
- add web-socket transport
1 parent 9532561 commit eb773ed

3 files changed

Lines changed: 572 additions & 3 deletions

File tree

modules/jooby-mcp/src/main/java/io/jooby/mcp/McpModule.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.jooby.mcp.transport.JoobySseTransportProvider;
2424
import io.jooby.mcp.transport.JoobyStatelessServerTransport;
2525
import io.jooby.mcp.transport.JoobyStreamableServerTransportProvider;
26+
import io.jooby.mcp.transport.JoobyWebSocketServerTransportProvider;
2627
import io.modelcontextprotocol.common.McpTransportContext;
2728
import io.modelcontextprotocol.json.McpJsonMapper;
2829
import io.modelcontextprotocol.json.jackson3.JacksonMcpJsonMapper;
@@ -183,6 +184,10 @@ public void install(@NonNull Jooby app) {
183184
McpServer.sync(
184185
new JoobySseTransportProvider(
185186
app, mcpConfig, mcpJsonMapper, CTX_EXTRACTOR));
187+
case WEBSOCKET ->
188+
McpServer.sync(
189+
new JoobyWebSocketServerTransportProvider(
190+
app, mcpConfig, mcpJsonMapper, CTX_EXTRACTOR));
186191
default ->
187192
throw new IllegalStateException(
188193
"Unsupported transport: " + mcpConfig.getTransport());
@@ -257,7 +262,8 @@ public McpModule mcpJsonMapper(McpJsonMapper mcpJsonMapper) {
257262
public enum Transport {
258263
SSE("sse"),
259264
STREAMABLE_HTTP("streamable-http"),
260-
STATELESS_STREAMABLE_HTTP("stateless-streamable-http");
265+
STATELESS_STREAMABLE_HTTP("stateless-streamable-http"),
266+
WEBSOCKET("websocket");
261267

262268
private final String value;
263269

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/*
2+
* Jooby https://jooby.io
3+
* Apache License Version 2.0 https://jooby.io/LICENSE.txt
4+
* Copyright 2014 Edgar Espina
5+
*/
6+
package io.jooby.mcp.transport;
7+
8+
import java.io.IOException;
9+
import java.util.concurrent.ConcurrentHashMap;
10+
import java.util.concurrent.atomic.AtomicBoolean;
11+
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
14+
15+
import io.jooby.Context;
16+
import io.jooby.Jooby;
17+
import io.jooby.WebSocket;
18+
import io.jooby.WebSocketCloseStatus;
19+
import io.jooby.WebSocketMessage;
20+
import io.jooby.internal.mcp.McpServerConfig;
21+
import io.modelcontextprotocol.common.McpTransportContext;
22+
import io.modelcontextprotocol.json.McpJsonMapper;
23+
import io.modelcontextprotocol.json.TypeRef;
24+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
25+
import io.modelcontextprotocol.spec.McpSchema;
26+
import io.modelcontextprotocol.spec.McpServerSession;
27+
import io.modelcontextprotocol.spec.McpServerTransport;
28+
import io.modelcontextprotocol.spec.McpServerTransportProvider;
29+
import reactor.core.publisher.Flux;
30+
import reactor.core.publisher.Mono;
31+
32+
/**
33+
* Provides WebSocket transport implementation for MCP server using Jooby framework. Handles
34+
* bidirectional client connections, message routing, and session management.
35+
*/
36+
@SuppressWarnings("PMD")
37+
public class JoobyWebSocketServerTransportProvider implements McpServerTransportProvider {
38+
39+
private static final Logger LOG =
40+
LoggerFactory.getLogger(JoobyWebSocketServerTransportProvider.class);
41+
private static final String MCP_SESSION_ATTRIBUTE = "mcpSessionId";
42+
43+
private final McpJsonMapper mcpJsonMapper;
44+
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
45+
private final McpTransportContextExtractor<Context> contextExtractor;
46+
47+
private McpServerSession.Factory sessionFactory;
48+
private final AtomicBoolean isClosing = new AtomicBoolean(false);
49+
50+
/**
51+
* Constructs a new Jooby WebSocket transport provider instance.
52+
*
53+
* @param app The Jooby application instance to register endpoints with
54+
* @param serverConfig The MCP server configuration containing endpoint settings
55+
* @param mcpJsonMapper The MCP JSON mapper for message serialization/deserialization
56+
* @param contextExtractor The extractor for transport context
57+
*/
58+
public JoobyWebSocketServerTransportProvider(
59+
Jooby app,
60+
McpServerConfig serverConfig,
61+
McpJsonMapper mcpJsonMapper,
62+
McpTransportContextExtractor<Context> contextExtractor) {
63+
this.mcpJsonMapper = mcpJsonMapper;
64+
this.contextExtractor = contextExtractor;
65+
66+
String wsEndpoint = serverConfig.getMcpEndpoint();
67+
68+
app.ws(
69+
wsEndpoint,
70+
(ctx, ws) -> {
71+
ws.onConnect(this::handleConnect);
72+
ws.onMessage(this::handleMessage);
73+
ws.onClose(this::handleClose);
74+
ws.onError(this::handleError);
75+
});
76+
}
77+
78+
@Override
79+
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
80+
this.sessionFactory = sessionFactory;
81+
}
82+
83+
@Override
84+
public Mono<Void> notifyClients(String method, Object params) {
85+
if (sessions.isEmpty()) {
86+
LOG.debug("No active WebSocket sessions to broadcast a message to");
87+
return Mono.empty();
88+
}
89+
90+
if (LOG.isDebugEnabled()) {
91+
LOG.debug("Attempting to broadcast a message to {} active WS sessions", sessions.size());
92+
}
93+
94+
return Flux.fromIterable(sessions.values())
95+
.flatMap(
96+
session ->
97+
session
98+
.sendNotification(method, params)
99+
.doOnError(
100+
e ->
101+
LOG.error(
102+
"Failed to send message to WS session {}: {}",
103+
session.getId(),
104+
e.getMessage()))
105+
.onErrorComplete())
106+
.then();
107+
}
108+
109+
@Override
110+
public Mono<Void> closeGracefully() {
111+
return Flux.fromIterable(sessions.values())
112+
.doFirst(
113+
() -> {
114+
isClosing.set(true);
115+
if (LOG.isDebugEnabled()) {
116+
LOG.debug(
117+
"Initiating graceful shutdown with {} active WS sessions", sessions.size());
118+
}
119+
})
120+
.flatMap(McpServerSession::closeGracefully)
121+
.doFinally(signalType -> sessions.clear())
122+
.then();
123+
}
124+
125+
private void handleConnect(WebSocket ws) {
126+
if (isClosing.get()) {
127+
ws.close(WebSocketCloseStatus.SERVICE_RESTARTED);
128+
return;
129+
}
130+
131+
JoobyMcpWebSocketTransport transport = new JoobyMcpWebSocketTransport(ws);
132+
McpServerSession session = sessionFactory.create(transport);
133+
String sessionId = session.getId();
134+
135+
ws.attribute(MCP_SESSION_ATTRIBUTE, sessionId);
136+
sessions.put(sessionId, session);
137+
138+
LOG.debug("New WebSocket connection established. Session ID: {}", sessionId);
139+
}
140+
141+
private void handleMessage(WebSocket ws, WebSocketMessage msg) {
142+
String sessionId = ws.attribute(MCP_SESSION_ATTRIBUTE);
143+
if (sessionId == null) {
144+
LOG.warn("Received message on WebSocket without an associated MCP session");
145+
return;
146+
}
147+
148+
McpServerSession session = sessions.get(sessionId);
149+
if (session == null) {
150+
LOG.warn("Received message for unknown WS session ID: {}", sessionId);
151+
return;
152+
}
153+
154+
try {
155+
Context ctx = ws.getContext();
156+
McpTransportContext transportContext = this.contextExtractor.extract(ctx);
157+
String body = msg.value();
158+
159+
McpSchema.JSONRPCMessage message =
160+
McpSchema.deserializeJsonRpcMessage(this.mcpJsonMapper, body);
161+
162+
// Unlike HTTP POSTs, WebSockets are fully asynchronous streams, so we just subscribe
163+
// rather than blocking and returning an HTTP StatusCode.
164+
session
165+
.handle(message)
166+
.contextWrite(
167+
reactorCtx ->
168+
reactorCtx
169+
.put(io.modelcontextprotocol.common.McpTransportContext.KEY, transportContext)
170+
.put("CTX", ctx))
171+
.subscribe(
172+
null,
173+
error ->
174+
LOG.error(
175+
"Error processing WS message for session {}: {}",
176+
sessionId,
177+
error.getMessage()));
178+
} catch (IOException | IllegalArgumentException e) {
179+
LOG.error("Failed to deserialize WS message: {}", e.getMessage());
180+
}
181+
}
182+
183+
private void handleClose(WebSocket ws, WebSocketCloseStatus status) {
184+
String sessionId = ws.attribute(MCP_SESSION_ATTRIBUTE);
185+
if (sessionId != null) {
186+
LOG.debug(
187+
"WebSocket connection closed for session: {} with status: {}",
188+
sessionId,
189+
status.getCode());
190+
sessions.remove(sessionId);
191+
}
192+
}
193+
194+
private void handleError(WebSocket ws, Throwable cause) {
195+
String sessionId = ws.attribute(MCP_SESSION_ATTRIBUTE);
196+
LOG.error("WebSocket error for session: {}", sessionId, cause);
197+
}
198+
199+
private class JoobyMcpWebSocketTransport implements McpServerTransport {
200+
201+
private final WebSocket ws;
202+
private volatile boolean closed = false;
203+
204+
public JoobyMcpWebSocketTransport(WebSocket ws) {
205+
this.ws = ws;
206+
}
207+
208+
@Override
209+
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
210+
return Mono.fromRunnable(
211+
() -> {
212+
try {
213+
if (!closed) {
214+
String jsonText = mcpJsonMapper.writeValueAsString(message);
215+
ws.send(jsonText);
216+
}
217+
} catch (Exception e) {
218+
LOG.error("Failed to send WebSocket message: {}", e.getMessage());
219+
}
220+
});
221+
}
222+
223+
@Override
224+
public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
225+
return mcpJsonMapper.convertValue(data, typeRef);
226+
}
227+
228+
@Override
229+
public Mono<Void> closeGracefully() {
230+
return Mono.fromRunnable(this::close);
231+
}
232+
233+
@Override
234+
public void close() {
235+
if (!closed) {
236+
closed = true;
237+
ws.close(WebSocketCloseStatus.NORMAL);
238+
}
239+
}
240+
}
241+
}

0 commit comments

Comments
 (0)