Add download timeout and retry logic (#8149)
* Add timeout to download
* Retry failed downloads on network errors

  Previously, the download would either fail immediately or hang
  indefinitely when the user e.g. changed their tunnel state.

* Fix progress when resuming download
* Import thiserror on all platforms
* Add to installer downloader changelog
This commit is contained in:
parent 9191354880
commit 10dc35368c
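In outline, the fix wraps each download in a retry loop whose read timeout doubles on every failed attempt. A minimal, self-contained sketch of that pattern (`attempt_download` is a hypothetical stand-in for the real `get_to_writer`; the constants mirror those added in the diff below):

use std::time::Duration;

const READ_TIMEOUT: Duration = Duration::from_secs(1); // start value, doubled per retry
const MAX_RETRY_ATTEMPTS: u32 = 4;

// Hypothetical stand-in for a single download attempt.
async fn attempt_download(read_timeout: Duration) -> Result<(), std::io::Error> {
    let _ = read_timeout; // a real attempt would build an HTTP client with this timeout
    Ok(())
}

async fn download_with_retries() -> anyhow::Result<()> {
    let mut attempts = 0;
    let mut read_timeout = READ_TIMEOUT;
    while let Err(err) = attempt_download(read_timeout).await {
        attempts += 1;
        read_timeout *= 2; // 1s -> 2s -> 4s -> 8s
        if attempts >= MAX_RETRY_ATTEMPTS {
            anyhow::bail!("Max retry attempts reached: {err}");
        }
    }
    Ok(())
}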
Cargo.lock (generated)
@@ -3238,6 +3238,7 @@ dependencies = [
  "hex",
  "insta",
  "json-canon",
+ "log",
  "mockito",
  "mullvad-version",
  "rand 0.8.5",
@@ -21,10 +21,10 @@ Line wrap the file at 100 chars. Th

 ## [Unreleased]
 ### Fixed
+- Fix downloads hanging indefinitely on switching networks
 #### macOS
 - Fix rendering issues on old (unsupported) macOS versions.


 ## [1.0.0] - 2025-05-13
 ### Fixed
 #### Windows
@@ -24,6 +24,7 @@ hex = { version = "0.4" }
 serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true }
 zeroize = { version = "1.8", features = ["zeroize_derive"] }
+log = { workspace = true }

 reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"], optional = true }
 sha2 = { workspace = true, optional = true }
@@ -36,7 +37,6 @@ mullvad-version = { path = "../mullvad-version", features = ["serde"] }
 clap = { workspace = true, optional = true }
 rand = { version = "0.8.5", optional = true }

-[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
 thiserror = { workspace = true, optional = true }

 [dev-dependencies]
@@ -19,7 +19,7 @@ use crate::{
 #[derive(Debug, thiserror::Error)]
 pub enum DownloadError {
     #[error("Failed to download app")]
-    FetchApp(#[source] anyhow::Error),
+    FetchApp(#[from] anyhow::Error),
     #[error("Failed to verify app")]
     Verification(#[source] anyhow::Error),
     #[error("Failed to launch app")]
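The `#[source]` → `#[from]` switch above is what lets fetch errors convert into `DownloadError` with `?`: thiserror's `#[from]` implies `#[source]` and additionally derives `From` for the wrapped type. A minimal illustration (the `fetch` and `download` functions here are made up for the example):

#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
    #[error("Failed to download app")]
    FetchApp(#[from] anyhow::Error),
}

fn fetch() -> anyhow::Result<()> {
    anyhow::bail!("network unreachable")
}

fn download() -> Result<(), DownloadError> {
    // `?` converts anyhow::Error into DownloadError via the derived From impl.
    fetch()?;
    Ok(())
}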
@@ -1,9 +1,11 @@
 //! A downloader that supports HTTP range requests and resuming downloads

 use std::{
+    error::Error,
     path::Path,
     pin::Pin,
     task::{ready, Poll},
+    time::Duration,
 };

 use reqwest::header::{HeaderValue, CONTENT_LENGTH, RANGE};
@@ -12,7 +14,110 @@ use tokio::{
     io::{self, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter},
 };

 use anyhow::Context;
+use thiserror::Error;
+
+/// Start value of the read timeout. This is doubled on each retry.
+const READ_TIMEOUT: Duration = Duration::from_secs(1);
+const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
+// Maximum number of retry attempts for timeouts
+const MAX_RETRY_ATTEMPTS: u32 = 4;
+
+/// Custom error type for download operations
+#[derive(Error, Debug)]
+pub enum DownloadError {
+    /// Failed to initialize client
+    #[error("Failed to initialize HTTP client")]
+    ClientInitialization(#[source] reqwest::Error),
+
+    /// Failed to get content length
+    #[error("Failed to request download")]
+    HeadRequest(#[source] reqwest::Error),
+
+    /// Server returned error status
+    #[error("Download failed: {0}")]
+    HttpStatus(reqwest::StatusCode),
+
+    /// Invalid content length header
+    #[error("Invalid content length header: {0}")]
+    InvalidContentLength(&'static str),
+
+    /// Failed to make range request
+    #[error("Failed to retrieve range")]
+    RangeRequest(#[source] reqwest::Error),
+
+    /// Failed to read chunk
+    #[error("Failed to read chunk")]
+    ChunkRead(#[source] reqwest::Error),
+
+    /// Failed to write chunk
+    #[error("Failed to write chunk")]
+    ChunkWrite(#[source] io::Error),
+
+    /// Failed to get stream position
+    #[error("Failed to get existing file size")]
+    StreamPosition(#[source] io::Error),
+
+    /// Failed to flush writer
+    #[error("Failed to flush writer")]
+    Flush(#[source] io::Error),
+
+    /// Size validation error
+    #[error("Size validation failed: {0}")]
+    SizeValidation(String),
+
+    /// File operation error
+    #[error("File operation failed: {0}")]
+    FileOperation(#[source] io::Error),
+
+    /// Other error
+    #[error("{0}")]
+    Other(&'static str),
+}
+
+impl DownloadError {
+    /// Checks if the error is caused by a timeout or network issue that can be retried
+    pub fn should_retry(&self) -> bool {
+        match self {
+            DownloadError::HeadRequest(e)
+            | DownloadError::RangeRequest(e)
+            | DownloadError::ChunkRead(e)
+            | DownloadError::ClientInitialization(e) => is_network_error(e),
+            DownloadError::HttpStatus(status) => {
+                // Retry server errors and timeout status
+                status.is_server_error() || *status == reqwest::StatusCode::REQUEST_TIMEOUT
+            }
+            // Don't retry other types of errors
+            _ => false,
+        }
+    }
+}
+
+/// Checks if the error is a network-related error that can be retried
+fn is_network_error(error: &reqwest::Error) -> bool {
+    // Retry on timeout errors
+    // Retry on connection errors (which often happen when switching networks)
+    // Retry on request errors (like "connection reset")
+    if error.is_timeout() || error.is_connect() || error.is_request() {
+        return true;
+    }
+
+    let mut error = error as &dyn Error;
+    loop {
+        if let Some(io_err) = error.downcast_ref::<std::io::Error>() {
+            // Check if the error is a timeout or connection error
+            if io_err.kind() == io::ErrorKind::TimedOut
+                || io_err.kind() == io::ErrorKind::ConnectionReset
+            {
+                return true;
+            }
+        }
+        if let Some(source) = error.source() {
+            error = source;
+        } else {
+            break false;
+        }
+    }
+}

 /// Receiver of the current progress so far
 pub trait ProgressUpdater: Send + 'static {
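Note the chain walk in `is_network_error`: reqwest wraps transport failures in several layers, so e.g. a hyper read timeout may only be visible as an `io::Error` somewhere down the `source()` chain. The same traversal in isolation, over a generic error (a hedged sketch, not part of the diff):

use std::error::Error;
use std::io;

/// Walk an error's source chain looking for a retryable io::Error,
/// mirroring the loop in `is_network_error` above.
fn chain_has_io_timeout(top: &(dyn Error + 'static)) -> bool {
    let mut current = top;
    loop {
        if let Some(io_err) = current.downcast_ref::<io::Error>() {
            if matches!(
                io_err.kind(),
                io::ErrorKind::TimedOut | io::ErrorKind::ConnectionReset
            ) {
                return true;
            }
        }
        match current.source() {
            Some(next) => current = next,
            None => return false,
        }
    }
}

`should_retry` treats any such error as transient, so the retry loop in `get_to_file` below gets another attempt.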
@@ -67,9 +172,26 @@ pub async fn get_to_file(
     progress_updater: &mut impl ProgressUpdater,
     size_hint: SizeHint,
 ) -> anyhow::Result<()> {
-    let file = create_or_append(file).await?;
-    let file = BufWriter::new(file);
-    get_to_writer(file, url, progress_updater, size_hint).await
+    let file = create_or_append(file)
+        .await
+        .map_err(DownloadError::FileOperation)?;
+    let mut file = BufWriter::new(file);
+    let mut attempts = 0;
+    let mut read_timeout = READ_TIMEOUT;
+    while let Err(err) =
+        get_to_writer(&mut file, url, progress_updater, size_hint, read_timeout).await
+    {
+        if !err.should_retry() {
+            anyhow::bail!(err);
+        }
+        attempts += 1;
+        read_timeout *= 2;
+        if attempts >= MAX_RETRY_ATTEMPTS {
+            anyhow::bail!("Max retry attempts reached: {err}");
+        }
+        log::warn!("Download failed: {err}. Retrying with timeout: {read_timeout:?}");
+    }
+    Ok(())
 }

 /// Download `url` to `writer`.
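For reference, the schedule this loop yields with READ_TIMEOUT = 1 s and MAX_RETRY_ATTEMPTS = 4: the first try uses a 1 s read timeout and each retry doubles it, so at most four tries run (1 s, 2 s, 4 s, 8 s) before the last error is surfaced. A toy loop reproducing the arithmetic:

use std::time::Duration;

fn main() {
    let mut read_timeout = Duration::from_secs(1); // READ_TIMEOUT
    for attempt in 1..=4u32 {
        // up to MAX_RETRY_ATTEMPTS tries in total
        println!("attempt {attempt} uses a read timeout of {read_timeout:?}");
        read_timeout *= 2; // doubled after every failed attempt
    }
    // attempt 1 uses 1s, attempt 2 uses 2s, attempt 3 uses 4s, attempt 4 uses 8s
}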
@@ -82,41 +204,59 @@ pub async fn get_to_writer(
     url: &str,
     progress_updater: &mut impl ProgressUpdater,
     size_hint: SizeHint,
-) -> anyhow::Result<()> {
-    let client = reqwest::Client::new();
+    read_timeout: Duration,
+) -> Result<(), DownloadError> {
+    // Create a new client for each download attempt to prevent stale connections
+    let client = reqwest::Client::builder()
+        .read_timeout(read_timeout)
+        .connect_timeout(CONNECT_TIMEOUT)
+        .build()
+        .map_err(DownloadError::ClientInitialization)?;

     progress_updater.set_url(url);
     progress_updater.set_progress(0.);

     // Fetch content length first
-    let response = client.head(url).send().await.context("HEAD failed")?;
+    let response = client
+        .head(url)
+        .send()
+        .await
+        .map_err(DownloadError::HeadRequest)?;

     if !response.status().is_success() {
-        return response
-            .error_for_status()
-            .map(|_| ())
-            .context("Download failed");
+        return Err(DownloadError::HttpStatus(response.status()));
     }

     let total_size = response
         .headers()
         .get(CONTENT_LENGTH)
-        .context("Missing file size")?;
-    let total_size: usize = total_size.to_str()?.parse().context("invalid size")?;
-    size_hint.check_size(total_size)?;
+        .ok_or_else(|| DownloadError::InvalidContentLength("Missing file size"))?;
+
+    let total_size: usize = total_size
+        .to_str()
+        .map_err(|_| DownloadError::InvalidContentLength("Invalid content length header"))?
+        .parse()
+        .map_err(|_| DownloadError::InvalidContentLength("Invalid size format"))?;
+
+    match size_hint.check_size(total_size) {
+        Ok(_) => {}
+        Err(e) => return Err(DownloadError::SizeValidation(e.to_string())),
+    }

     let already_fetched_bytes = writer
         .stream_position()
         .await
-        .context("failed to get existing file size")?
+        .map_err(DownloadError::StreamPosition)?
         .try_into()
-        .context("invalid size")?;
+        .map_err(|_| DownloadError::Other("Invalid file position"))?;

     progress_updater.set_progress(already_fetched_bytes as f32 / total_size as f32);
     if total_size == already_fetched_bytes {
         progress_updater.set_progress(1.);
         return Ok(());
     }
     if already_fetched_bytes > total_size {
-        anyhow::bail!("Found existing file that was larger");
+        return Err(DownloadError::SizeValidation(
+            "Found existing file that was larger".to_string(),
+        ));
     }

     // Fetch content, one range at a time
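The resume support above relies on standard HTTP range requests (RFC 7233): the writer's current stream position says how many bytes are already on disk, and each subsequent request asks only for the next slice. A sketch of the header value involved; the chunk size and helper function are illustrative, not taken from this diff:

/// Build a Range header value for the next slice, e.g. "bytes=1000-1511".
/// Offsets are zero-based and the end offset is inclusive (RFC 7233).
fn range_header(already_fetched: usize, chunk_size: usize, total_size: usize) -> String {
    let start = already_fetched;
    let end = (start + chunk_size - 1).min(total_size - 1);
    format!("bytes={start}-{end}")
}

// range_header(1000, 512, 2000) == "bytes=1000-1511"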
@@ -133,32 +273,32 @@ pub async fn get_to_writer(
             .header(RANGE, range)
             .send()
             .await
-            .context("Failed to retrieve range")?;
+            .map_err(DownloadError::RangeRequest)?;

         let status = response.status();
         if !status.is_success() {
-            return response
-                .error_for_status()
-                .map(|_| ())
-                .context("Download failed");
+            return Err(DownloadError::HttpStatus(status));
         }

         let mut bytes_read = 0;

-        while let Some(chunk) = response.chunk().await.context("Failed to read chunk")? {
+        while let Some(chunk) = response.chunk().await.map_err(DownloadError::ChunkRead)? {
             bytes_read += chunk.len();
             if bytes_read > total_size - already_fetched_bytes {
                 // Protect against servers responding with more data than expected
-                anyhow::bail!("Server returned more than requested bytes");
+                return Err(DownloadError::SizeValidation(
+                    "Server returned more than requested bytes".to_string(),
+                ));
             }

             writer
                 .write_all(&chunk)
                 .await
-                .context("Failed to write chunk")?;
+                .map_err(DownloadError::ChunkWrite)?;
         }
     }

-    writer.shutdown().await.context("Failed to flush")?;
+    writer.shutdown().await.map_err(DownloadError::Flush)?;

     Ok(())
 }
@@ -261,6 +401,7 @@ impl<PU: ProgressUpdater, Writer: AsyncWrite + Unpin> AsyncWrite
 mod test {
     use std::io::Cursor;

+    use anyhow::Context;
     use async_tempfile::TempDir;
     use rand::RngCore;
     use tokio::{fs, io::AsyncWriteExt};
@@ -344,6 +485,7 @@ mod test {
         &file_url,
         &mut progress_updater,
         SizeHint::Exact(file_data.len()),
+        READ_TIMEOUT,
     )
     .await
     .context("Complete download failed")?;
@@ -378,6 +520,7 @@ mod test {
         &file_url,
         &mut progress_updater,
         SizeHint::Exact(file_data.len()),
+        READ_TIMEOUT,
     )
     .await
     .expect_err("Expected interrupted download");
@@ -408,6 +551,7 @@ mod test {
         &file_url,
         &mut progress_updater,
         SizeHint::Exact(file_data.len()),
+        READ_TIMEOUT,
     )
     .await
     .context("Partial download failed")?;
@@ -468,6 +612,7 @@ mod test {
         &file_url,
         &mut FakeProgressUpdater::default(),
         SizeHint::Exact(1),
+        READ_TIMEOUT,
     )
     .await
     .expect_err("Reject unexpected content length");
@@ -492,6 +637,7 @@ mod test {
         &file_url,
         &mut FakeProgressUpdater::default(),
         SizeHint::Exact(file_data.len()),
+        READ_TIMEOUT,
     )
     .await
     .expect_err("Reject unexpected chunk sizes");
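Each of these call sites now passes the new `READ_TIMEOUT` argument. For completeness, a hedged sketch of what one such test looks like end to end, assuming the mockito 1.x API these tests appear to use (the path and status code are illustrative, not from the diff):

#[tokio::test]
async fn test_head_error_is_reported() {
    let mut server = mockito::Server::new_async().await;
    // A HEAD endpoint that always fails with 503 should surface as HttpStatus.
    let _mock = server
        .mock("HEAD", "/app.exe")
        .with_status(503)
        .create_async()
        .await;

    let mut writer = Cursor::new(Vec::new());
    let err = get_to_writer(
        &mut writer,
        &format!("{}/app.exe", server.url()),
        &mut FakeProgressUpdater::default(),
        SizeHint::Exact(1),
        READ_TIMEOUT,
    )
    .await
    .expect_err("expected HTTP error");

    assert!(matches!(err, DownloadError::HttpStatus(_)));
}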
test/Cargo.lock (generated)
@@ -2217,6 +2217,7 @@ dependencies = [
  "ed25519-dalek",
  "hex",
  "json-canon",
+ "log",
  "mullvad-version",
  "reqwest",
  "serde",