Skip to content

Commit c4c3d6a

Browse files
authored
Add decompressed size limits to snapshot downloads. (#2479)
1 parent eac9ef7 commit c4c3d6a

2 files changed

Lines changed: 89 additions & 9 deletions

File tree

cmd/soroban-cli/src/commands/snapshot/create.rs

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use stellar_xdr::curr::{
2020
ScAddress, ScContractInstance, ScVal,
2121
};
2222
use tokio::fs::OpenOptions;
23-
use tokio::io::BufReader;
23+
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
2424
use tokio_util::io::StreamReader;
2525
use url::Url;
2626

@@ -168,6 +168,9 @@ pub enum Error {
168168

169169
#[error("corrupted bucket file: expected hash {expected}, got {actual}")]
170170
CorruptedBucket { expected: String, actual: String },
171+
172+
#[error("decompressed size exceeds maximum of {max}")]
173+
DecompressedSizeLimitExceeded { max: ByteSize },
171174
}
172175

173176
/// Checkpoint frequency is usually 64 ledgers, but in local test nets it'll
@@ -176,6 +179,12 @@ pub enum Error {
176179
/// select good ledger numbers when they select one that doesn't exist.
177180
const CHECKPOINT_FREQUENCY: u32 = 64;
178181

182+
/// Maximum decompressed size for bucket files (10 GiB).
183+
const MAX_BUCKET_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024 * 1024;
184+
185+
/// Maximum decompressed size for ledger header files (100 MiB).
186+
const MAX_LEDGER_HEADER_DECOMPRESSED_SIZE: u64 = 100 * 1024 * 1024;
187+
179188
impl Cmd {
180189
#[allow(clippy::too_many_lines)]
181190
pub async fn run(&self, global_args: &global::Args) -> Result<(), Error> {
@@ -501,6 +510,35 @@ impl Cmd {
501510
}
502511
}
503512

513+
/// Copy decompressed data from `reader` to `writer`, enforcing a maximum
514+
/// decompressed size. Returns an error if the decompressed output exceeds
515+
/// `max_bytes`.
516+
async fn copy_with_limit<R: AsyncRead + Unpin, W: tokio::io::AsyncWrite + Unpin>(
517+
reader: R,
518+
writer: &mut W,
519+
max_bytes: u64,
520+
) -> Result<(), Error> {
521+
let mut limited = reader.take(max_bytes);
522+
tokio::io::copy(&mut limited, writer)
523+
.await
524+
.map_err(Error::StreamingBucket)?;
525+
526+
// If the underlying reader still has data, the limit was exceeded.
527+
let mut decoder = limited.into_inner();
528+
let mut overflow = [0u8; 1];
529+
if decoder
530+
.read(&mut overflow)
531+
.await
532+
.map_err(Error::StreamingBucket)?
533+
> 0
534+
{
535+
return Err(Error::DecompressedSizeLimitExceeded {
536+
max: ByteSize(max_bytes),
537+
});
538+
}
539+
Ok(())
540+
}
541+
504542
fn ledger_to_path_components(ledger: u32) -> (String, String, String, String) {
505543
let ledger_hex = format!("{ledger:08x}");
506544
let ledger_hex_0 = ledger_hex[0..=1].to_string();
@@ -597,7 +635,7 @@ async fn get_ledger_metadata_from_archive(
597635
.map(|result| result.map_err(std::io::Error::other));
598636
let stream_reader = StreamReader::new(stream);
599637
let buf_reader = BufReader::new(stream_reader);
600-
let mut decoder = GzipDecoder::new(buf_reader);
638+
let decoder = GzipDecoder::new(buf_reader);
601639

602640
let mut file = OpenOptions::new()
603641
.create(true)
@@ -607,9 +645,10 @@ async fn get_ledger_metadata_from_archive(
607645
.await
608646
.map_err(Error::WriteOpeningCachedBucket)?;
609647

610-
tokio::io::copy(&mut decoder, &mut file)
611-
.await
612-
.map_err(Error::StreamingBucket)?;
648+
if let Err(e) = copy_with_limit(decoder, &mut file, MAX_LEDGER_HEADER_DECOMPRESSED_SIZE).await {
649+
let _ = fs::remove_file(&dl_path);
650+
return Err(e);
651+
}
613652

614653
fs::rename(&dl_path, &cache_path).map_err(Error::RenameDownloadFile)?;
615654

@@ -709,7 +748,7 @@ async fn cache_bucket(
709748
.map(|result| result.map_err(std::io::Error::other));
710749
let stream_reader = StreamReader::new(stream);
711750
let buf_reader = BufReader::new(stream_reader);
712-
let mut decoder = GzipDecoder::new(buf_reader);
751+
let decoder = GzipDecoder::new(buf_reader);
713752
let dl_path = cache_path.with_extension("dl");
714753
let mut file = OpenOptions::new()
715754
.create(true)
@@ -718,9 +757,12 @@ async fn cache_bucket(
718757
.open(&dl_path)
719758
.await
720759
.map_err(Error::WriteOpeningCachedBucket)?;
721-
tokio::io::copy(&mut decoder, &mut file)
722-
.await
723-
.map_err(Error::StreamingBucket)?;
760+
761+
if let Err(e) = copy_with_limit(decoder, &mut file, MAX_BUCKET_DECOMPRESSED_SIZE).await {
762+
let _ = fs::remove_file(&dl_path);
763+
return Err(e);
764+
}
765+
724766
fs::rename(&dl_path, &cache_path).map_err(Error::RenameDownloadFile)?;
725767
}
726768
Ok(cache_path)
@@ -740,3 +782,35 @@ struct HistoryBucket {
740782
curr: String,
741783
snap: String,
742784
}
785+
786+
#[cfg(test)]
mod test {
    use super::*;

    // Input smaller than the cap: everything is copied through untouched.
    #[tokio::test]
    async fn test_copy_with_limit_under_limit() {
        let data: &[u8] = b"hello";
        let mut sink = Vec::new();
        copy_with_limit(data, &mut sink, 10).await.unwrap();
        assert_eq!(sink, b"hello");
    }

    // Input exactly at the cap: still succeeds (the limit is inclusive).
    #[tokio::test]
    async fn test_copy_with_limit_exact_limit() {
        let data: &[u8] = b"hello";
        let mut sink = Vec::new();
        copy_with_limit(data, &mut sink, 5).await.unwrap();
        assert_eq!(sink, b"hello");
    }

    // Input past the cap: the dedicated size-limit error is returned.
    #[tokio::test]
    async fn test_copy_with_limit_over_limit() {
        let data: &[u8] = b"hello world, this exceeds the limit";
        let mut sink = Vec::new();
        let err = copy_with_limit(data, &mut sink, 10).await.unwrap_err();
        assert!(
            matches!(err, Error::DecompressedSizeLimitExceeded { .. }),
            "expected DecompressedSizeLimitExceeded, got: {err}"
        );
    }
}

cmd/soroban-cli/src/utils.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,15 @@ pub fn get_name_from_stellar_asset_contract_storage(storage: &ScMap) -> Option<S
211211
}
212212

213213
pub mod http {
214+
use std::time::Duration;
215+
214216
use crate::commands::version;
215217
fn user_agent() -> String {
216218
format!("{}/{}", env!("CARGO_PKG_NAME"), version::pkg())
217219
}
218220

221+
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
222+
219223
/// Creates and returns a configured `reqwest::Client`.
220224
///
221225
/// # Panics
@@ -228,6 +232,7 @@ pub mod http {
228232
// 3. This simplifies error handling for callers, as they can assume a valid client.
229233
reqwest::Client::builder()
230234
.user_agent(user_agent())
235+
.connect_timeout(CONNECT_TIMEOUT)
231236
.build()
232237
.expect("Failed to build reqwest client")
233238
}
@@ -240,6 +245,7 @@ pub mod http {
240245
pub fn blocking_client() -> reqwest::blocking::Client {
241246
reqwest::blocking::Client::builder()
242247
.user_agent(user_agent())
248+
.connect_timeout(CONNECT_TIMEOUT)
243249
.build()
244250
.expect("Failed to build reqwest blocking client")
245251
}

0 commit comments

Comments
 (0)