Add download timeout and retry logic (#8149)

* Add timeout to download

* Retry failed downloads on network errors

Previously, the download would either fail immediately or hang
indefinitely when the user e.g. changed their tunnel state.

* Fix progress when resuming download

* Import thiserror on all platforms

* Add to installer downloader changelog
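
The gist of the change: each download attempt builds a fresh reqwest client with a read
timeout that doubles on every retry, and only timeout/connection failures are retried.
A minimal standalone sketch of that pattern (the `fetch_with_retry` helper and its URL
handling are illustrative, not part of this commit):

```rust
use std::time::Duration;

const READ_TIMEOUT: Duration = Duration::from_secs(1);
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
const MAX_RETRY_ATTEMPTS: u32 = 4;

/// Illustrative helper: fetch `url`, doubling the read timeout on each
/// retryable (timeout/connect) failure, up to MAX_RETRY_ATTEMPTS.
async fn fetch_with_retry(url: &str) -> Result<String, reqwest::Error> {
    let mut attempts = 0;
    let mut read_timeout = READ_TIMEOUT;
    loop {
        // A fresh client per attempt avoids reusing a connection that went
        // stale when the user switched networks.
        let client = reqwest::Client::builder()
            .read_timeout(read_timeout)
            .connect_timeout(CONNECT_TIMEOUT)
            .build()?;
        match client.get(url).send().await {
            Ok(response) => return response.text().await,
            Err(err)
                if (err.is_timeout() || err.is_connect())
                    && attempts + 1 < MAX_RETRY_ATTEMPTS =>
            {
                attempts += 1;
                read_timeout *= 2;
                log::warn!("Download failed: {err}. Retrying with timeout: {read_timeout:?}");
            }
            Err(err) => return Err(err),
        }
    }
}
```
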
Sebastian Holmin 2025-05-12 15:18:17 +00:00
parent 9191354880
commit 10dc35368c
6 changed files with 179 additions and 31 deletions

Cargo.lock generated

@@ -3238,6 +3238,7 @@ dependencies = [
"hex",
"insta",
"json-canon",
"log",
"mockito",
"mullvad-version",
"rand 0.8.5",


@@ -21,10 +21,10 @@ Line wrap the file at 100 chars. Th
## [Unreleased]
### Fixed
- Fix downloads hanging indefinitely on switching networks
#### macOS
- Fix rendering issues on old (unsupported) macOS versions.
## [1.0.0] - 2025-05-13
### Fixed
#### Windows


@@ -24,6 +24,7 @@ hex = { version = "0.4" }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
zeroize = { version = "1.8", features = ["zeroize_derive"] }
log = { workspace = true }
reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"], optional = true }
sha2 = { workspace = true, optional = true }
@@ -36,7 +37,6 @@ mullvad-version = { path = "../mullvad-version", features = ["serde"] }
clap = { workspace = true, optional = true }
rand = { version = "0.8.5", optional = true }
[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
thiserror = { workspace = true, optional = true }
[dev-dependencies]


@@ -19,7 +19,7 @@ use crate::{
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
#[error("Failed to download app")]
FetchApp(#[source] anyhow::Error),
FetchApp(#[from] anyhow::Error),
#[error("Failed to verify app")]
Verification(#[source] anyhow::Error),
#[error("Failed to launch app")]


@@ -1,9 +1,11 @@
//! A downloader that supports HTTP range requests and resuming downloads
use std::{
error::Error,
path::Path,
pin::Pin,
task::{ready, Poll},
time::Duration,
};
use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE};
@@ -12,7 +14,110 @@ use tokio::{
io::{self, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter},
};
use anyhow::Context;
use thiserror::Error;
/// Start value of the read timeout. This is doubled on each retry.
const READ_TIMEOUT: Duration = Duration::from_secs(1);
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum number of retry attempts for timeouts
const MAX_RETRY_ATTEMPTS: u32 = 4;
/// Custom error type for download operations
#[derive(Error, Debug)]
pub enum DownloadError {
/// Failed to initialize client
#[error("Failed to initialize HTTP client")]
ClientInitialization(#[source] reqwest::Error),
/// Failed to get content length
#[error("Failed to request download")]
HeadRequest(#[source] reqwest::Error),
/// Server returned error status
#[error("Download failed: {0}")]
HttpStatus(reqwest::StatusCode),
/// Invalid content length header
#[error("Invalid content length header: {0}")]
InvalidContentLength(&'static str),
/// Failed to make range request
#[error("Failed to retrieve range")]
RangeRequest(#[source] reqwest::Error),
/// Failed to read chunk
#[error("Failed to read chunk")]
ChunkRead(#[source] reqwest::Error),
/// Failed to write chunk
#[error("Failed to write chunk")]
ChunkWrite(#[source] io::Error),
/// Failed to get stream position
#[error("Failed to get existing file size")]
StreamPosition(#[source] io::Error),
/// Failed to flush writer
#[error("Failed to flush writer")]
Flush(#[source] io::Error),
/// Size validation error
#[error("Size validation failed: {0}")]
SizeValidation(String),
/// File operation error
#[error("File operation failed: {0}")]
FileOperation(#[source] io::Error),
/// Other error
#[error("{0}")]
Other(&'static str),
}
impl DownloadError {
/// Checks if the error is caused by a timeout or network issue that can be retried
pub fn should_retry(&self) -> bool {
match self {
DownloadError::HeadRequest(e)
| DownloadError::RangeRequest(e)
| DownloadError::ChunkRead(e)
| DownloadError::ClientInitialization(e) => is_network_error(e),
DownloadError::HttpStatus(status) => {
// Retry server errors and timeout status
status.is_server_error() || *status == reqwest::StatusCode::REQUEST_TIMEOUT
}
// Don't retry other types of errors
_ => false,
}
}
}
/// Checks if the error is a network-related error that can be retried
fn is_network_error(error: &reqwest::Error) -> bool {
// Retry on timeout errors
// Retry on connection errors (which often happen when switching networks)
// Retry on request errors (like "connection reset")
if error.is_timeout() || error.is_connect() || error.is_request() {
return true;
}
let mut error = error as &dyn Error;
loop {
if let Some(io_err) = error.downcast_ref::<std::io::Error>() {
// Check if the error is a timeout or connection error
if io_err.kind() == io::ErrorKind::TimedOut
|| io_err.kind() == io::ErrorKind::ConnectionReset
{
return true;
}
}
if let Some(source) = error.source() {
error = source;
} else {
break false;
}
}
}
/// Receiver of the current progress so far
pub trait ProgressUpdater: Send + 'static {
@@ -67,9 +172,26 @@ pub async fn get_to_file(
progress_updater: &mut impl ProgressUpdater,
size_hint: SizeHint,
) -> anyhow::Result<()> {
let file = create_or_append(file).await?;
let file = BufWriter::new(file);
get_to_writer(file, url, progress_updater, size_hint).await
let file = create_or_append(file)
.await
.map_err(DownloadError::FileOperation)?;
let mut file = BufWriter::new(file);
let mut attempts = 0;
let mut read_timeout = READ_TIMEOUT;
while let Err(err) =
get_to_writer(&mut file, url, progress_updater, size_hint, read_timeout).await
{
if !err.should_retry() {
anyhow::bail!(err);
}
attempts += 1;
read_timeout *= 2;
if attempts >= MAX_RETRY_ATTEMPTS {
anyhow::bail!("Max retry attempts reached: {err}");
}
log::warn!("Download failed: {err}. Retrying with timeout: {read_timeout:?}");
}
Ok(())
}
/// Download `url` to `writer`.
@@ -82,41 +204,59 @@ pub async fn get_to_writer(
url: &str,
progress_updater: &mut impl ProgressUpdater,
size_hint: SizeHint,
) -> anyhow::Result<()> {
let client = reqwest::Client::new();
read_timeout: Duration,
) -> Result<(), DownloadError> {
// Create a new client for each download attempt to prevent stale connections
let client = reqwest::Client::builder()
.read_timeout(read_timeout)
.connect_timeout(CONNECT_TIMEOUT)
.build()
.map_err(DownloadError::ClientInitialization)?;
progress_updater.set_url(url);
progress_updater.set_progress(0.);
// Fetch content length first
let response = client.head(url).send().await.context("HEAD failed")?;
let response = client
.head(url)
.send()
.await
.map_err(DownloadError::HeadRequest)?;
if !response.status().is_success() {
return response
.error_for_status()
.map(|_| ())
.context("Download failed");
return Err(DownloadError::HttpStatus(response.status()));
}
let total_size = response
.headers()
.get(CONTENT_LENGTH)
.context("Missing file size")?;
let total_size: usize = total_size.to_str()?.parse().context("invalid size")?;
size_hint.check_size(total_size)?;
.ok_or_else(|| DownloadError::InvalidContentLength("Missing file size"))?;
let total_size: usize = total_size
.to_str()
.map_err(|_| DownloadError::InvalidContentLength("Invalid content length header"))?
.parse()
.map_err(|_| DownloadError::InvalidContentLength("Invalid size format"))?;
match size_hint.check_size(total_size) {
Ok(_) => {}
Err(e) => return Err(DownloadError::SizeValidation(e.to_string())),
}
let already_fetched_bytes = writer
.stream_position()
.await
.context("failed to get existing file size")?
.map_err(DownloadError::StreamPosition)?
.try_into()
.context("invalid size")?;
.map_err(|_| DownloadError::Other("Invalid file position"))?;
progress_updater.set_progress(already_fetched_bytes as f32 / total_size as f32);
if total_size == already_fetched_bytes {
progress_updater.set_progress(1.);
return Ok(());
}
if already_fetched_bytes > total_size {
anyhow::bail!("Found existing file that was larger");
return Err(DownloadError::SizeValidation(
"Found existing file that was larger".to_string(),
));
}
// Fetch content, one range at a time
@@ -133,32 +273,32 @@ pub async fn get_to_writer(
.header(RANGE, range)
.send()
.await
.context("Failed to retrieve range")?;
.map_err(DownloadError::RangeRequest)?;
let status = response.status();
if !status.is_success() {
return response
.error_for_status()
.map(|_| ())
.context("Download failed");
return Err(DownloadError::HttpStatus(status));
}
let mut bytes_read = 0;
while let Some(chunk) = response.chunk().await.context("Failed to read chunk")? {
while let Some(chunk) = response.chunk().await.map_err(DownloadError::ChunkRead)? {
bytes_read += chunk.len();
if bytes_read > total_size - already_fetched_bytes {
// Protect against servers responding with more data than expected
anyhow::bail!("Server returned more than requested bytes");
return Err(DownloadError::SizeValidation(
"Server returned more than requested bytes".to_string(),
));
}
writer
.write_all(&chunk)
.await
.context("Failed to write chunk")?;
.map_err(DownloadError::ChunkWrite)?;
}
}
writer.shutdown().await.context("Failed to flush")?;
writer.shutdown().await.map_err(DownloadError::Flush)?;
Ok(())
}
@@ -261,6 +401,7 @@ impl<PU: ProgressUpdater, Writer: AsyncWrite + Unpin> AsyncWrite
mod test {
use std::io::Cursor;
use anyhow::Context;
use async_tempfile::TempDir;
use rand::RngCore;
use tokio::{fs, io::AsyncWriteExt};
@@ -344,6 +485,7 @@ mod test {
&file_url,
&mut progress_updater,
SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
)
.await
.context("Complete download failed")?;
@@ -378,6 +520,7 @@ mod test {
&file_url,
&mut progress_updater,
SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
)
.await
.expect_err("Expected interrupted download");
@@ -408,6 +551,7 @@ mod test {
&file_url,
&mut progress_updater,
SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
)
.await
.context("Partial download failed")?;
@@ -468,6 +612,7 @@ mod test {
&file_url,
&mut FakeProgressUpdater::default(),
SizeHint::Exact(1),
READ_TIMEOUT,
)
.await
.expect_err("Reject unexpected content length");
@@ -492,6 +637,7 @@ mod test {
&file_url,
&mut FakeProgressUpdater::default(),
SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
)
.await
.expect_err("Reject unexpected chunk sizes");

test/Cargo.lock generated

@@ -2217,6 +2217,7 @@ dependencies = [
"ed25519-dalek",
"hex",
"json-canon",
"log",
"mullvad-version",
"reqwest",
"serde",