Skip to content

Commit 5b17a69

Browse files
authored
fix(http2): cancel sending client request body on response future drop (#4042)
When the client drops the response future (e.g. due to a timeout), send_task detects the cancellation but does not notify pipe_task. The pipe_task continues to hold the h2 SendStream, preventing a RST_STREAM from being sent and keeping flow-control window capacity locked. Add a oneshot channel between send_task and pipe_task. When send_task detects cancellation via poll_canceled, it signals pipe_task through the channel. pipe_task then calls send_reset(CANCEL) on the h2 SendStream, which sends RST_STREAM to the server and frees flow-control capacity. Closes #4040
1 parent 7211ec2 commit 5b17a69

4 files changed

Lines changed: 86 additions & 0 deletions

File tree

src/client/dispatch.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,9 @@ where
366366
}
367367
};
368368
trace!("send_when canceled");
369+
// Tell pipe_task to reset the h2 stream so that
370+
// RST_STREAM is sent and flow-control capacity freed.
371+
this.when.as_mut().cancel();
369372
Poll::Ready(())
370373
}
371374
Poll::Ready(Err((error, message))) => {

src/proto/h2/client.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ pin_project! {
461461
conn_drop_ref: Option<Sender<Infallible>>,
462462
#[pin]
463463
ping: Option<Recorder>,
464+
cancel_rx: Option<oneshot::Receiver<()>>,
464465
}
465466
}
466467

@@ -474,6 +475,26 @@ where
474475
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Self::Output> {
475476
let mut this = self.project();
476477

478+
// Check if the client cancelled the request (e.g. dropped the
479+
// response future due to a timeout). If so, reset the h2 stream
480+
// so that a RST_STREAM is sent and flow-control capacity is freed.
481+
let cancel_result = this.cancel_rx.as_mut().map(|rx| Pin::new(rx).poll(cx));
482+
match cancel_result {
483+
Some(Poll::Ready(Ok(()))) => {
484+
debug!("client request body send cancelled, resetting stream");
485+
this.pipe.as_mut().send_reset(h2::Reason::CANCEL);
486+
drop(this.conn_drop_ref.take().expect("Future polled twice"));
487+
drop(this.ping.take().expect("Future polled twice"));
488+
return Poll::Ready(());
489+
}
490+
Some(Poll::Ready(Err(_))) => {
491+
// Sender dropped without cancelling (normal response or error).
492+
// Stop polling the receiver.
493+
*this.cancel_rx = None;
494+
}
495+
Some(Poll::Pending) | None => {}
496+
}
497+
477498
match Pin::new(&mut this.pipe).poll(cx) {
478499
Poll::Ready(result) => {
479500
if let Err(_e) = result {
@@ -500,6 +521,10 @@ where
500521
fn poll_pipe(&mut self, f: FutCtx<B>, cx: &mut Context<'_>) {
501522
let ping = self.ping.clone();
502523

524+
// A one-shot channel so that send_task can tell pipe_task to
525+
// reset the stream when the client cancels the request.
526+
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
527+
503528
let send_stream = if !f.is_connect {
504529
if !f.eos {
505530
let mut pipe = PipeToSendStream::new(f.body, f.body_tx);
@@ -519,6 +544,7 @@ where
519544
pipe,
520545
conn_drop_ref: Some(conn_drop_ref),
521546
ping: Some(ping),
547+
cancel_rx: Some(cancel_rx),
522548
};
523549
// Clear send task
524550
self.executor
@@ -539,6 +565,7 @@ where
539565
ping: Some(ping),
540566
send_stream: Some(send_stream),
541567
exec: self.executor.clone(),
568+
cancel_tx: Some(cancel_tx),
542569
},
543570
call_back: Some(f.cb),
544571
},
@@ -558,6 +585,16 @@ pin_project! {
558585
#[pin]
559586
send_stream: Option<Option<SendStream<SendBuf<<B as Body>::Data>>>>,
560587
exec: E,
588+
cancel_tx: Option<oneshot::Sender<()>>,
589+
}
590+
}
591+
592+
impl<B: Body + 'static, E> ResponseFutMap<B, E> {
593+
/// Signal the pipe_task to reset the stream (e.g. on client cancellation).
594+
pub(crate) fn cancel(self: Pin<&mut Self>) {
595+
if let Some(cancel_tx) = self.project().cancel_tx.take() {
596+
let _ = cancel_tx.send(());
597+
}
561598
}
562599
}
563600

src/proto/h2/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ where
104104
stream,
105105
}
106106
}
107+
108+
#[cfg(feature = "client")]
109+
fn send_reset(self: Pin<&mut Self>, reason: h2::Reason) {
110+
self.project().body_tx.send_reset(reason);
111+
}
107112
}
108113

109114
impl<S> Future for PipeToSendStream<S>

tests/client.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2868,6 +2868,47 @@ mod conn {
28682868
Pin::new(&mut self.tcp).poll_read(cx, buf)
28692869
}
28702870
}
2871+
2872+
// https://github.com/hyperium/hyper/issues/4040
2873+
#[tokio::test]
2874+
async fn h2_pipe_task_cancelled_on_response_future_drop() {
2875+
let (client_io, server_io, _) = setup_duplex_test_server();
2876+
let (rst_tx, rst_rx) = oneshot::channel::<bool>();
2877+
2878+
tokio::spawn(async move {
2879+
let mut builder = h2::server::Builder::new();
2880+
builder.initial_window_size(0);
2881+
let mut h2 = builder.handshake::<_, Bytes>(server_io).await.unwrap();
2882+
let (req, _respond) = h2.accept().await.unwrap().unwrap();
2883+
tokio::spawn(async move {
2884+
let _ = poll_fn(|cx| h2.poll_closed(cx)).await;
2885+
});
2886+
2887+
let mut body = req.into_body();
2888+
let got_rst = tokio::time::timeout(Duration::from_secs(2), body.data())
2889+
.await
2890+
.map_or(false, |frame| matches!(frame, Some(Err(_)) | None));
2891+
let _ = rst_tx.send(got_rst);
2892+
});
2893+
2894+
let io = TokioIo::new(client_io);
2895+
let (mut client, conn) = conn::http2::Builder::new(TokioExecutor)
2896+
.handshake(io)
2897+
.await
2898+
.expect("http handshake");
2899+
tokio::spawn(async move {
2900+
let _ = conn.await;
2901+
});
2902+
2903+
let req = Request::post("http://localhost/")
2904+
.body(Full::new(Bytes::from(vec![b'x'; 50])))
2905+
.unwrap();
2906+
let res = tokio::time::timeout(Duration::from_millis(5), client.send_request(req)).await;
2907+
assert!(res.is_err(), "should timeout waiting for response");
2908+
2909+
let got_rst = rst_rx.await.expect("server task should complete");
2910+
assert!(got_rst, "server should receive RST_STREAM");
2911+
}
28712912
}
28722913

28732914
trait FutureHyperExt: TryFuture {

0 commit comments

Comments
 (0)