Add download timeout and retry logic (#8149)

* Add timeout to download

* Retry failed downloads on network errors

Previously, the download would either fail immediately or hang
indefinitely when the user e.g. changed their tunnel state.

* Fix progress when resuming download

* Import thiserror on all platforms

* Add to installer downloader changelog
This commit is contained in:
Sebastian Holmin 2025-05-12 15:18:17 +00:00
parent 9191354880
commit 10dc35368c
No known key found for this signature in database
GPG Key ID: 9C88494B3F2F9089
6 changed files with 179 additions and 31 deletions

1
Cargo.lock generated
View File

@ -3238,6 +3238,7 @@ dependencies = [
"hex", "hex",
"insta", "insta",
"json-canon", "json-canon",
"log",
"mockito", "mockito",
"mullvad-version", "mullvad-version",
"rand 0.8.5", "rand 0.8.5",

View File

@ -21,10 +21,10 @@ Line wrap the file at 100 chars. Th
## [Unreleased] ## [Unreleased]
### Fix ### Fix
- Fix downloads hanging indefinitely on switching networks
#### macOS #### macOS
- Fix rendering issues on old (unsupported) macOS versions. - Fix rendering issues on old (unsupported) macOS versions.
## [1.0.0] - 2025-05-13 ## [1.0.0] - 2025-05-13
### Fixed ### Fixed
#### Windows #### Windows

View File

@ -24,6 +24,7 @@ hex = { version = "0.4" }
serde = { workspace = true, features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true } serde_json = { workspace = true }
zeroize = { version = "1.8", features = ["zeroize_derive"] } zeroize = { version = "1.8", features = ["zeroize_derive"] }
log = { workspace = true }
reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"], optional = true } reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"], optional = true }
sha2 = { workspace = true, optional = true } sha2 = { workspace = true, optional = true }
@ -36,7 +37,6 @@ mullvad-version = { path = "../mullvad-version", features = ["serde"] }
clap = { workspace = true, optional = true } clap = { workspace = true, optional = true }
rand = { version = "0.8.5", optional = true } rand = { version = "0.8.5", optional = true }
[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
thiserror = { workspace = true, optional = true } thiserror = { workspace = true, optional = true }
[dev-dependencies] [dev-dependencies]

View File

@ -19,7 +19,7 @@ use crate::{
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum DownloadError { pub enum DownloadError {
#[error("Failed to download app")] #[error("Failed to download app")]
FetchApp(#[source] anyhow::Error), FetchApp(#[from] anyhow::Error),
#[error("Failed to verify app")] #[error("Failed to verify app")]
Verification(#[source] anyhow::Error), Verification(#[source] anyhow::Error),
#[error("Failed to launch app")] #[error("Failed to launch app")]

View File

@ -1,9 +1,11 @@
//! A downloader that supports HTTP range requests and resuming downloads //! A downloader that supports HTTP range requests and resuming downloads
use std::{ use std::{
error::Error,
path::Path, path::Path,
pin::Pin, pin::Pin,
task::{ready, Poll}, task::{ready, Poll},
time::Duration,
}; };
use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE}; use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE};
@ -12,7 +14,110 @@ use tokio::{
io::{self, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}, io::{self, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter},
}; };
use anyhow::Context; use thiserror::Error;
/// Start value of the read timeout. This is doubled on each retry.
const READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Timeout for establishing a connection to the server (not doubled on retry).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum number of retry attempts for retryable errors (see `DownloadError::should_retry`)
const MAX_RETRY_ATTEMPTS: u32 = 4;
/// Custom error type for download operations.
///
/// Variants are granular enough that [`DownloadError::should_retry`] can decide
/// whether a failed attempt represents a transient condition worth retrying.
#[derive(Error, Debug)]
pub enum DownloadError {
    /// Failed to initialize client
    #[error("Failed to initialize HTTP client")]
    ClientInitialization(#[source] reqwest::Error),
    /// Failed to get content length
    #[error("Failed to request download")]
    HeadRequest(#[source] reqwest::Error),
    /// Server returned error status
    #[error("Download failed: {0}")]
    HttpStatus(reqwest::StatusCode),
    /// Invalid content length header (missing, non-ASCII, or unparsable)
    #[error("Invalid content length header: {0}")]
    InvalidContentLength(&'static str),
    /// Failed to make range request
    #[error("Failed to retrieve range")]
    RangeRequest(#[source] reqwest::Error),
    /// Failed to read chunk
    #[error("Failed to read chunk")]
    ChunkRead(#[source] reqwest::Error),
    /// Failed to write chunk
    #[error("Failed to write chunk")]
    ChunkWrite(#[source] io::Error),
    /// Failed to get stream position
    #[error("Failed to get existing file size")]
    StreamPosition(#[source] io::Error),
    /// Failed to flush writer
    #[error("Failed to flush writer")]
    Flush(#[source] io::Error),
    /// Size validation error (size hint mismatch, oversized existing file, or
    /// server sending more bytes than requested)
    #[error("Size validation failed: {0}")]
    SizeValidation(String),
    /// File operation error
    #[error("File operation failed: {0}")]
    FileOperation(#[source] io::Error),
    /// Other error
    #[error("{0}")]
    Other(&'static str),
}
impl DownloadError {
/// Checks if the error is caused by a timeout or network issue that can be retried
pub fn should_retry(&self) -> bool {
match self {
DownloadError::HeadRequest(e)
| DownloadError::RangeRequest(e)
| DownloadError::ChunkRead(e)
| DownloadError::ClientInitialization(e) => is_network_error(e),
DownloadError::HttpStatus(status) => {
// Retry server errors and timeout status
status.is_server_error() || *status == reqwest::StatusCode::REQUEST_TIMEOUT
}
// Don't retry other types of errors
_ => false,
}
}
}
/// Returns `true` if `error` is a network-related failure that can be retried:
/// a timeout, a connection problem, or a reset/timeout anywhere in its
/// chain of underlying errors.
fn is_network_error(error: &reqwest::Error) -> bool {
    // Timeouts, connection errors (which often happen when switching
    // networks), and request errors (like "connection reset") are all
    // transient and worth retrying.
    if error.is_timeout() || error.is_connect() || error.is_request() {
        return true;
    }

    // Otherwise walk the source chain, looking for an io::Error that
    // indicates a timeout or a reset connection.
    let mut current: Option<&dyn Error> = Some(error);
    while let Some(err) = current {
        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
            if matches!(
                io_err.kind(),
                io::ErrorKind::TimedOut | io::ErrorKind::ConnectionReset
            ) {
                return true;
            }
        }
        current = err.source();
    }
    false
}
/// Receiver of the current progress so far /// Receiver of the current progress so far
pub trait ProgressUpdater: Send + 'static { pub trait ProgressUpdater: Send + 'static {
@ -67,9 +172,26 @@ pub async fn get_to_file(
progress_updater: &mut impl ProgressUpdater, progress_updater: &mut impl ProgressUpdater,
size_hint: SizeHint, size_hint: SizeHint,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let file = create_or_append(file).await?; let file = create_or_append(file)
let file = BufWriter::new(file); .await
get_to_writer(file, url, progress_updater, size_hint).await .map_err(DownloadError::FileOperation)?;
let mut file = BufWriter::new(file);
let mut attempts = 0;
let mut read_timeout = READ_TIMEOUT;
while let Err(err) =
get_to_writer(&mut file, url, progress_updater, size_hint, read_timeout).await
{
if !err.should_retry() {
anyhow::bail!(err);
}
attempts += 1;
read_timeout *= 2;
if attempts >= MAX_RETRY_ATTEMPTS {
anyhow::bail!("Max retry attempts reached: {err}");
}
log::warn!("Download failed: {err}. Retrying in with timeout: {read_timeout:?}");
}
Ok(())
} }
/// Download `url` to `writer`. /// Download `url` to `writer`.
@ -82,41 +204,59 @@ pub async fn get_to_writer(
url: &str, url: &str,
progress_updater: &mut impl ProgressUpdater, progress_updater: &mut impl ProgressUpdater,
size_hint: SizeHint, size_hint: SizeHint,
) -> anyhow::Result<()> { read_timeout: Duration,
let client = reqwest::Client::new(); ) -> Result<(), DownloadError> {
// Create a new client for each download attempt to prevent stale connections
let client = reqwest::Client::builder()
.read_timeout(read_timeout)
.connect_timeout(CONNECT_TIMEOUT)
.build()
.map_err(DownloadError::ClientInitialization)?;
progress_updater.set_url(url); progress_updater.set_url(url);
progress_updater.set_progress(0.);
// Fetch content length first // Fetch content length first
let response = client.head(url).send().await.context("HEAD failed")?; let response = client
.head(url)
.send()
.await
.map_err(DownloadError::HeadRequest)?;
if !response.status().is_success() { if !response.status().is_success() {
return response return Err(DownloadError::HttpStatus(response.status()));
.error_for_status()
.map(|_| ())
.context("Download failed");
} }
let total_size = response let total_size = response
.headers() .headers()
.get(CONTENT_LENGTH) .get(CONTENT_LENGTH)
.context("Missing file size")?; .ok_or_else(|| DownloadError::InvalidContentLength("Missing file size"))?;
let total_size: usize = total_size.to_str()?.parse().context("invalid size")?;
size_hint.check_size(total_size)?; let total_size: usize = total_size
.to_str()
.map_err(|_| DownloadError::InvalidContentLength("Invalid content length header"))?
.parse()
.map_err(|_| DownloadError::InvalidContentLength("Invalid size format"))?;
match size_hint.check_size(total_size) {
Ok(_) => {}
Err(e) => return Err(DownloadError::SizeValidation(e.to_string())),
}
let already_fetched_bytes = writer let already_fetched_bytes = writer
.stream_position() .stream_position()
.await .await
.context("failed to get existing file size")? .map_err(DownloadError::StreamPosition)?
.try_into() .try_into()
.context("invalid size")?; .map_err(|_| DownloadError::Other("Invalid file position"))?;
progress_updater.set_progress(already_fetched_bytes as f32 / total_size as f32);
if total_size == already_fetched_bytes { if total_size == already_fetched_bytes {
progress_updater.set_progress(1.);
return Ok(()); return Ok(());
} }
if already_fetched_bytes > total_size { if already_fetched_bytes > total_size {
anyhow::bail!("Found existing file that was larger"); return Err(DownloadError::SizeValidation(
"Found existing file that was larger".to_string(),
));
} }
// Fetch content, one range at a time // Fetch content, one range at a time
@ -133,32 +273,32 @@ pub async fn get_to_writer(
.header(RANGE, range) .header(RANGE, range)
.send() .send()
.await .await
.context("Failed to retrieve range")?; .map_err(DownloadError::RangeRequest)?;
let status = response.status(); let status = response.status();
if !status.is_success() { if !status.is_success() {
return response return Err(DownloadError::HttpStatus(status));
.error_for_status()
.map(|_| ())
.context("Download failed");
} }
let mut bytes_read = 0; let mut bytes_read = 0;
while let Some(chunk) = response.chunk().await.context("Failed to read chunk")? { while let Some(chunk) = response.chunk().await.map_err(DownloadError::ChunkRead)? {
bytes_read += chunk.len(); bytes_read += chunk.len();
if bytes_read > total_size - already_fetched_bytes { if bytes_read > total_size - already_fetched_bytes {
// Protect against servers responding with more data than expected // Protect against servers responding with more data than expected
anyhow::bail!("Server returned more than requested bytes"); return Err(DownloadError::SizeValidation(
"Server returned more than requested bytes".to_string(),
));
} }
writer writer
.write_all(&chunk) .write_all(&chunk)
.await .await
.context("Failed to write chunk")?; .map_err(DownloadError::ChunkWrite)?;
} }
} }
writer.shutdown().await.context("Failed to flush")?; writer.shutdown().await.map_err(DownloadError::Flush)?;
Ok(()) Ok(())
} }
@ -261,6 +401,7 @@ impl<PU: ProgressUpdater, Writer: AsyncWrite + Unpin> AsyncWrite
mod test { mod test {
use std::io::Cursor; use std::io::Cursor;
use anyhow::Context;
use async_tempfile::TempDir; use async_tempfile::TempDir;
use rand::RngCore; use rand::RngCore;
use tokio::{fs, io::AsyncWriteExt}; use tokio::{fs, io::AsyncWriteExt};
@ -344,6 +485,7 @@ mod test {
&file_url, &file_url,
&mut progress_updater, &mut progress_updater,
SizeHint::Exact(file_data.len()), SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
) )
.await .await
.context("Complete download failed")?; .context("Complete download failed")?;
@ -378,6 +520,7 @@ mod test {
&file_url, &file_url,
&mut progress_updater, &mut progress_updater,
SizeHint::Exact(file_data.len()), SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
) )
.await .await
.expect_err("Expected interrupted download"); .expect_err("Expected interrupted download");
@ -408,6 +551,7 @@ mod test {
&file_url, &file_url,
&mut progress_updater, &mut progress_updater,
SizeHint::Exact(file_data.len()), SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
) )
.await .await
.context("Partial download failed")?; .context("Partial download failed")?;
@ -468,6 +612,7 @@ mod test {
&file_url, &file_url,
&mut FakeProgressUpdater::default(), &mut FakeProgressUpdater::default(),
SizeHint::Exact(1), SizeHint::Exact(1),
READ_TIMEOUT,
) )
.await .await
.expect_err("Reject unexpected content length"); .expect_err("Reject unexpected content length");
@ -492,6 +637,7 @@ mod test {
&file_url, &file_url,
&mut FakeProgressUpdater::default(), &mut FakeProgressUpdater::default(),
SizeHint::Exact(file_data.len()), SizeHint::Exact(file_data.len()),
READ_TIMEOUT,
) )
.await .await
.expect_err("Reject unexpected chunk sizes"); .expect_err("Reject unexpected chunk sizes");

1
test/Cargo.lock generated
View File

@ -2217,6 +2217,7 @@ dependencies = [
"ed25519-dalek", "ed25519-dalek",
"hex", "hex",
"json-canon", "json-canon",
"log",
"mullvad-version", "mullvad-version",
"reqwest", "reqwest",
"serde", "serde",